package com.redislabs.redisai;

import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
import java.util.ArrayList;
import java.util.List;
import redis.clients.jedis.Protocol;
import redis.clients.jedis.util.SafeEncoder;

/** Direct mapping to RedisAI Model */
public class Model {
  private Backend backend;
  private Device device;
  private String[] inputs;
  private String[] outputs;
  private byte[] blob;
  private String tag;
  private long batchSize;
  private long minBatchSize;

  /**
   * @param backend - the backend for the model; can be one of TF, TFLITE, TORCH or ONNX
   * @param device - the device that will execute the model; can be one of CPU or GPU
   * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow
   *     models)
   * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow
   *     models)
   * @param blob - the Protobuf-serialized model
   */
  public Model(Backend backend, Device device, String[] inputs, String[] outputs, byte[] blob) {
    this(backend, device, inputs, outputs, blob, 0, 0);
  }
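
  // Illustrative usage sketch (not from the original file; the file path and node names are
  // hypothetical):
  //
  //   byte[] protobufBlob = Files.readAllBytes(Paths.get("graph.pb")); // java.nio.file
  //   Model model =
  //       new Model(
  //           Backend.TF, Device.CPU, new String[] {"input"}, new String[] {"output"}, protobufBlob);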

  /**
   * @param backend - the backend for the model; can be one of TF, TFLITE, TORCH or ONNX
   * @param device - the device that will execute the model; can be one of CPU or GPU
   * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow
   *     models)
   * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow
   *     models)
   * @param blob - the Protobuf-serialized model
   * @param batchSize - when greater than 0, the engine will batch incoming requests from multiple
   *     clients that use the model with input tensors of the same shape
   * @param minBatchSize - when greater than 0, the engine will postpone calls to AI.MODELRUN until
   *     the batch size has reached minBatchSize
   */
  public Model(
      Backend backend,
      Device device,
      String[] inputs,
      String[] outputs,
      byte[] blob,
      long batchSize,
      long minBatchSize) {
    this.backend = backend;
    this.device = device;
    this.inputs = inputs;
    this.outputs = outputs;
    this.blob = blob;
    this.tag = null;
    this.batchSize = batchSize;
    this.minBatchSize = minBatchSize;
  }
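
  // Illustrative usage sketch (not from the original file): the same constructor with server-side
  // batching enabled; the batch sizes below are hypothetical.
  //
  //   Model batched =
  //       new Model(
  //           Backend.TF, Device.GPU, new String[] {"input"}, new String[] {"output"},
  //           protobufBlob, 32, 8);
  //   batched.setTag("v1.0");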

  @SuppressWarnings("unchecked")
  public static Model createModelFromRespReply(List<?> reply) {
    Backend backend = null;
    Device device = null;
    String tag = null;
    byte[] blob = null;
    long batchsize = 0;
    long minbatchsize = 0;
    String[] inputs = new String[0];
    String[] outputs = new String[0];
    for (int i = 0; i < reply.size(); i += 2) {
      String arrayKey = SafeEncoder.encode((byte[]) reply.get(i));
      switch (arrayKey) {
        case "backend":
          String backendString = SafeEncoder.encode((byte[]) reply.get(i + 1));
          try {
            backend = Backend.valueOf(backendString);
          } catch (IllegalArgumentException e) {
            // Enum.valueOf never returns null; it throws for unrecognized constants.
            throw new JRedisAIRunTimeException("Unrecognized backend: " + backendString);
          }
          break;
        case "device":
          String deviceString = SafeEncoder.encode((byte[]) reply.get(i + 1));
          try {
            device = Device.valueOf(deviceString);
          } catch (IllegalArgumentException e) {
            throw new JRedisAIRunTimeException("Unrecognized device: " + deviceString);
          }
          break;
        case "tag":
          tag = SafeEncoder.encode((byte[]) reply.get(i + 1));
          break;
        case "blob":
          blob = (byte[]) reply.get(i + 1);
          break;
        case "batchsize":
          batchsize = (Long) reply.get(i + 1);
          break;
        case "minbatchsize":
          minbatchsize = (Long) reply.get(i + 1);
          break;
        case "inputs":
          List<byte[]> inputsEncoded = (List<byte[]>) reply.get(i + 1);
          if (!inputsEncoded.isEmpty()) {
            inputs = new String[inputsEncoded.size()];
            for (int j = 0; j < inputsEncoded.size(); j++) {
              inputs[j] = SafeEncoder.encode(inputsEncoded.get(j));
            }
          }
          break;
        case "outputs":
          List<byte[]> outputsEncoded = (List<byte[]>) reply.get(i + 1);
          if (!outputsEncoded.isEmpty()) {
            outputs = new String[outputsEncoded.size()];
            for (int j = 0; j < outputsEncoded.size(); j++) {
              outputs[j] = SafeEncoder.encode(outputsEncoded.get(j));
            }
          }
          break;
        default:
          break;
      }
    }
    if (backend == null || device == null || blob == null) {
      throw new JRedisAIRunTimeException(
          "AI.MODELGET reply did not contain all elements required to build the model");
    }
    Model model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
    if (tag != null) {
      model.setTag(tag);
    }
    return model;
  }
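
  // Illustrative sketch (assumption, not from the original file): the flat key/value reply layout
  // this parser expects from AI.MODELGET, shown with hypothetical values.
  //
  //   List<Object> reply =
  //       Arrays.asList(
  //           "backend".getBytes(), "TF".getBytes(),
  //           "device".getBytes(), "CPU".getBytes(),
  //           "tag".getBytes(), "v1.0".getBytes(),
  //           "batchsize".getBytes(), 0L,
  //           "minbatchsize".getBytes(), 0L,
  //           "inputs".getBytes(), Collections.singletonList("input".getBytes()),
  //           "outputs".getBytes(), Collections.singletonList("output".getBytes()),
  //           "blob".getBytes(), protobufBlob);
  //   Model parsed = Model.createModelFromRespReply(reply);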

  public String getTag() {
    return tag;
  }

  public void setTag(String tag) {
    this.tag = tag;
  }

  public byte[] getBlob() {
    return blob;
  }

  public void setBlob(byte[] blob) {
    this.blob = blob;
  }

  public String[] getOutputs() {
    return outputs;
  }

  public void setOutputs(String[] outputs) {
    this.outputs = outputs;
  }

  public String[] getInputs() {
    return inputs;
  }

  public void setInputs(String[] inputs) {
    this.inputs = inputs;
  }

  public Device getDevice() {
    return device;
  }

  public void setDevice(Device device) {
    this.device = device;
  }

  public Backend getBackend() {
    return backend;
  }

  public void setBackend(Backend backend) {
    this.backend = backend;
  }

  public long getBatchSize() {
    return batchSize;
  }

  public void setBatchSize(long batchSize) {
    this.batchSize = batchSize;
  }

  public long getMinBatchSize() {
    return minBatchSize;
  }

  public void setMinBatchSize(long minBatchSize) {
    this.minBatchSize = minBatchSize;
  }

  /**
   * Encodes the current model properties into an AI.MODELSET command, to be stored in the RedisAI
   * server.
   *
   * @param key name of the key to store the Model under
   * @return the encoded AI.MODELSET command arguments
   */
  protected List<byte[]> getModelSetCommandBytes(String key) {
    List<byte[]> args = new ArrayList<>();
    args.add(SafeEncoder.encode(key));
    args.add(backend.getRaw());
    args.add(device.getRaw());
    if (tag != null) {
      args.add(Keyword.TAG.getRaw());
      args.add(SafeEncoder.encode(tag));
    }
    if (batchSize > 0) {
      args.add(Keyword.BATCHSIZE.getRaw());
      args.add(Protocol.toByteArray(batchSize));
      if (minBatchSize > 0) {
        args.add(Keyword.MINBATCHSIZE.getRaw());
        args.add(Protocol.toByteArray(minBatchSize));
      }
    }
    args.add(Keyword.INPUTS.getRaw());
    for (String input : inputs) {
      args.add(SafeEncoder.encode(input));
    }
    args.add(Keyword.OUTPUTS.getRaw());
    for (String output : outputs) {
      args.add(SafeEncoder.encode(output));
    }
    args.add(Keyword.BLOB.getRaw());
    args.add(blob);
    return args;
  }
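
  // For reference, the argument layout produced above (sent by the client as AI.MODELSET) is:
  //   <key> <backend> <device> [TAG <tag>] [BATCHSIZE <n> [MINBATCHSIZE <m>]]
  //       INPUTS <input>... OUTPUTS <output>... BLOB <blob>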
}