Skip to content

Commit 90c8731

Browse files
Added support for MODELGET, MODELDEL, SCRIPTGEL, SCRIPTDEL (#15)
* [add] support for MODELGET, MODELDEL, SCRIPTGEL, SCRIPTDEL * [add] added test for setTensor variations * [add] extended negative testing * [fix] following lowerCamelCase per review * [fix] fixing Variable script may be null here because of assignement on Model and Script. * [add] moving from ArrayList<byte[]>... signature to generic List<byte[]>... one * [add] Support for BATCHSIZE, MINBATCHSIZE, INPUTS and OUTPUTS on AI.MODELGET. [fix] fixes per PR review * [fix] fixed long dim to int dim per PR review * [add] added a Model constructor that accepts batchsize and minbatchsize as parameters * [add] Added script constructor given the device string and the Path containing the script. ensuring same behaviour across different command method wrappers for the same redisai command * [add] throwing an error on model/tensor/script creation if reply does not contain all elements. Added negative testing for resp parsing * [fix] per PR review
1 parent 71c110c commit 90c8731

34 files changed

+15409
-204
lines changed

src/main/java/com/redislabs/redisai/Backend.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
public enum Backend implements ProtocolCommand {
77
TF,
8-
TORCH;
8+
TORCH,
9+
TFLITE,
10+
ONNX;
911

1012
private final byte[] raw;
1113

src/main/java/com/redislabs/redisai/Command.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@ public enum Command implements ProtocolCommand {
88
TENSOR_SET("AI.TENSORSET"),
99
MODEL_GET("AI.MODELGET"),
1010
MODEL_SET("AI.MODELSET"),
11+
MODEL_DEL("AI.MODELDEL"),
1112
MODEL_RUN("AI.MODELRUN"),
12-
SCRIPT_GET("AI.SCRIPTGET"),
1313
SCRIPT_SET("AI.SCRIPTSET"),
14+
SCRIPT_GET("AI.SCRIPTGET"),
15+
SCRIPT_DEL("AI.SCRIPTDEL"),
1416
SCRIPT_RUN("AI.SCRIPTRUN"),
17+
// TODO: support AI.DAGRUN
18+
DAGRUN("AI.DAGRUN"),
19+
// TODO: support AI.DAGRUN_RO
20+
DAGRUN_RO("AI.DAGRUN_RO"),
1521
INFO("AI.INFO"),
1622
CONFIG("AI.CONFIG");
1723

src/main/java/com/redislabs/redisai/DataType.java

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -88,43 +88,8 @@ protected Object toObject(List<byte[]> data) {
8888
}
8989
return values;
9090
}
91-
},
92-
STRING {
93-
@Override
94-
public List<byte[]> toByteArray(Object obj) {
95-
byte[] values = (byte[]) obj;
96-
List<byte[]> res = new ArrayList<>(values.length);
97-
for (byte value : values) {
98-
res.add(Protocol.toByteArray(value));
99-
}
100-
return res;
101-
}
102-
103-
@Override
104-
protected Object toObject(List<byte[]> data) {
105-
return data;
106-
}
107-
},
108-
BOOL {
109-
@Override
110-
public List<byte[]> toByteArray(Object obj) {
111-
boolean[] values = (boolean[]) obj;
112-
List<byte[]> res = new ArrayList<>(values.length);
113-
for (boolean value : values) {
114-
res.add(Protocol.toByteArray(value));
115-
}
116-
return res;
117-
}
118-
119-
@Override
120-
protected Object toObject(List<byte[]> data) {
121-
// TODO Auto-generated method stub
122-
return null;
123-
}
12491
};
12592

126-
private final byte[] raw;
127-
12893
private static final HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
12994

13095
static {
@@ -136,12 +101,10 @@ protected Object toObject(List<byte[]> data) {
136101
classDataTypes.put(Float.class, DataType.FLOAT);
137102
classDataTypes.put(double.class, DataType.DOUBLE);
138103
classDataTypes.put(Double.class, DataType.DOUBLE);
139-
classDataTypes.put(byte.class, DataType.STRING);
140-
classDataTypes.put(Byte.class, DataType.STRING);
141-
classDataTypes.put(boolean.class, DataType.BOOL);
142-
classDataTypes.put(Boolean.class, DataType.BOOL);
143104
}
144105

106+
private final byte[] raw;
107+
145108
DataType() {
146109
raw = SafeEncoder.encode(this.name());
147110
}
@@ -163,18 +126,6 @@ static DataType getDataTypefromString(String dtypeRaw) {
163126
return dt;
164127
}
165128

166-
protected abstract List<byte[]> toByteArray(Object obj);
167-
168-
protected abstract Object toObject(List<byte[]> data);
169-
170-
public byte[] getRaw() {
171-
return raw;
172-
}
173-
174-
public List<byte[]> toByteArray(Object obj, int[] dimensions) {
175-
return toByteArray(obj, dimensions, 0, this);
176-
}
177-
178129
/** The class for the data type to which Java object o corresponds. */
179130
public static DataType baseObjType(Object o) {
180131
Class<?> c = o.getClass();
@@ -188,10 +139,10 @@ public static DataType baseObjType(Object o) {
188139
throw new IllegalArgumentException("cannot create Tensors of type " + c.getName());
189140
}
190141

191-
private static List<byte[]> toByteArray(Object obj, int[] dimensions, int dim, DataType type) {
142+
private static List<byte[]> toByteArray(Object obj, long[] dimensions, int dim, DataType type) {
192143
ArrayList<byte[]> res = new ArrayList<>();
193-
if (dimensions.length > dim + 1) {
194-
int dimension = dimensions[dim++];
144+
if (dimensions.length - 1 > dim) {
145+
long dimension = dimensions[dim++];
195146
for (int i = 0; i < dimension; ++i) {
196147
Object value = Array.get(obj, i);
197148
res.addAll(toByteArray(value, dimensions, dim, type));
@@ -201,4 +152,16 @@ private static List<byte[]> toByteArray(Object obj, int[] dimensions, int dim, D
201152
}
202153
return res;
203154
}
155+
156+
protected abstract List<byte[]> toByteArray(Object obj);
157+
158+
protected abstract Object toObject(List<byte[]> data);
159+
160+
public byte[] getRaw() {
161+
return raw;
162+
}
163+
164+
public List<byte[]> toByteArray(Object obj, long[] dimensions) {
165+
return toByteArray(obj, dimensions, 0, this);
166+
}
204167
}

src/main/java/com/redislabs/redisai/Keyword.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ public enum Keyword implements ProtocolCommand {
1212
BLOB,
1313
SOURCE,
1414
RESETSTAT,
15+
TAG,
16+
BATCHSIZE,
17+
MINBATCHSIZE,
1518
BACKENDSPATH,
1619
LOADBACKEND;
1720

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package com.redislabs.redisai;
2+
3+
import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import redis.clients.jedis.Protocol;
7+
import redis.clients.jedis.util.SafeEncoder;
8+
9+
/** Direct mapping to RedisAI Model */
10+
public class Model {
11+
private Backend backend;
12+
private Device device;
13+
private String[] inputs;
14+
private String[] outputs;
15+
private byte[] blob;
16+
private String tag;
17+
private long batchSize;
18+
private long minBatchSize;
19+
20+
/**
21+
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
22+
* @param device - the device that will execute the model. can be of CPU or GPU
23+
* @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow
24+
* models)
25+
* @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow
26+
* models)
27+
* @param blob - the Protobuf-serialized model
28+
*/
29+
public Model(Backend backend, Device device, String[] inputs, String[] outputs, byte[] blob) {
30+
this(backend, device, inputs, outputs, blob, 0, 0);
31+
}
32+
33+
/**
34+
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
35+
* @param device - the device that will execute the model. can be of CPU or GPU
36+
* @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow
37+
* models)
38+
* @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow
39+
* models)
40+
* @param blob - the Protobuf-serialized model
41+
* @param batchSize - when provided with an batchsize that is greater than 0, the engine will
42+
* batch incoming requests from multiple clients that use the model with input tensors of the
43+
* same shape.
44+
* @param minBatchSize - when provided with an minbatchsize that is greater than 0, the engine
45+
* will postpone calls to AI.MODELRUN until the batch's size had reached minbatchsize
46+
*/
47+
public Model(
48+
Backend backend,
49+
Device device,
50+
String[] inputs,
51+
String[] outputs,
52+
byte[] blob,
53+
long batchSize,
54+
long minBatchSize) {
55+
this.backend = backend;
56+
this.device = device;
57+
this.inputs = inputs;
58+
this.outputs = outputs;
59+
this.blob = blob;
60+
this.tag = null;
61+
this.batchSize = batchSize;
62+
this.minBatchSize = minBatchSize;
63+
}
64+
65+
public static Model createModelFromRespReply(List<?> reply) {
66+
Model model = null;
67+
Backend backend = null;
68+
Device device = null;
69+
String tag = null;
70+
byte[] blob = null;
71+
long batchsize = 0;
72+
long minbatchsize = 0;
73+
String[] inputs = new String[0];
74+
String[] outputs = new String[0];
75+
for (int i = 0; i < reply.size(); i += 2) {
76+
String arrayKey = SafeEncoder.encode((byte[]) reply.get(i));
77+
switch (arrayKey) {
78+
case "backend":
79+
String backendString = SafeEncoder.encode((byte[]) reply.get(i + 1));
80+
backend = Backend.valueOf(backendString);
81+
if (backend == null) {
82+
throw new JRedisAIRunTimeException("Unrecognized backend: " + backendString);
83+
}
84+
break;
85+
case "device":
86+
String deviceString = SafeEncoder.encode((byte[]) reply.get(i + 1));
87+
device = Device.valueOf(deviceString);
88+
if (device == null) {
89+
throw new JRedisAIRunTimeException("Unrecognized device: " + deviceString);
90+
}
91+
break;
92+
case "tag":
93+
tag = SafeEncoder.encode((byte[]) reply.get(i + 1));
94+
break;
95+
case "blob":
96+
blob = (byte[]) reply.get(i + 1);
97+
break;
98+
case "batchsize":
99+
batchsize = (Long) reply.get(i + 1);
100+
break;
101+
case "minbatchsize":
102+
minbatchsize = (Long) reply.get(i + 1);
103+
break;
104+
case "inputs":
105+
List<byte[]> inputsEncoded = (List<byte[]>) reply.get(i + 1);
106+
if (inputsEncoded.size() > 0) {
107+
inputs = new String[inputsEncoded.size()];
108+
for (int j = 0; j < inputsEncoded.size(); j++) {
109+
inputs[j] = SafeEncoder.encode(inputsEncoded.get(j));
110+
}
111+
}
112+
break;
113+
case "outputs":
114+
List<byte[]> outputsEncoded = (List<byte[]>) reply.get(i + 1);
115+
if (outputsEncoded.size() > 0) {
116+
outputs = new String[outputsEncoded.size()];
117+
for (int j = 0; j < outputsEncoded.size(); j++) {
118+
outputs[j] = SafeEncoder.encode(outputsEncoded.get(j));
119+
}
120+
}
121+
break;
122+
default:
123+
break;
124+
}
125+
}
126+
if (backend == null || device == null || blob == null) {
127+
throw new JRedisAIRunTimeException(
128+
"AI.MODELGET reply did not contained all elements to build the model");
129+
}
130+
model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
131+
if (tag != null) {
132+
model.setTag(tag);
133+
}
134+
return model;
135+
}
136+
137+
public String getTag() {
138+
return tag;
139+
}
140+
141+
public void setTag(String tag) {
142+
this.tag = tag;
143+
}
144+
145+
public byte[] getBlob() {
146+
return blob;
147+
}
148+
149+
public void setBlob(byte[] blob) {
150+
this.blob = blob;
151+
}
152+
153+
public String[] getOutputs() {
154+
return outputs;
155+
}
156+
157+
public void setOutputs(String[] outputs) {
158+
this.outputs = outputs;
159+
}
160+
161+
public String[] getInputs() {
162+
return inputs;
163+
}
164+
165+
public void setInputs(String[] inputs) {
166+
this.inputs = inputs;
167+
}
168+
169+
public Device getDevice() {
170+
return device;
171+
}
172+
173+
public void setDevice(Device device) {
174+
this.device = device;
175+
}
176+
177+
public Backend getBackend() {
178+
return backend;
179+
}
180+
181+
public void setBackend(Backend backend) {
182+
this.backend = backend;
183+
}
184+
185+
public long getBatchSize() {
186+
return batchSize;
187+
}
188+
189+
public void setBatchSize(long batchsize) {
190+
this.batchSize = batchsize;
191+
}
192+
193+
public long getMinBatchSize() {
194+
return minBatchSize;
195+
}
196+
197+
public void setMinBatchSize(long minbatchsize) {
198+
this.minBatchSize = minbatchsize;
199+
}
200+
201+
/**
202+
* Encodes the current model properties into an AI.MODELSET command to be store in RedisAI Server
203+
*
204+
* @param key name of key to store the Model
205+
* @return
206+
*/
207+
protected List<byte[]> getModelSetCommandBytes(String key) {
208+
List<byte[]> args = new ArrayList<>();
209+
args.add(SafeEncoder.encode(key));
210+
args.add(backend.getRaw());
211+
args.add(device.getRaw());
212+
if (tag != null) {
213+
args.add(Keyword.TAG.getRaw());
214+
args.add(SafeEncoder.encode(tag));
215+
}
216+
if (batchSize > 0) {
217+
args.add(Keyword.BATCHSIZE.getRaw());
218+
args.add(Protocol.toByteArray(batchSize));
219+
if (minBatchSize > 0) {
220+
args.add(Keyword.MINBATCHSIZE.getRaw());
221+
args.add(Protocol.toByteArray(minBatchSize));
222+
}
223+
}
224+
args.add(Keyword.INPUTS.getRaw());
225+
for (String input : inputs) {
226+
args.add(SafeEncoder.encode(input));
227+
}
228+
args.add(Keyword.OUTPUTS.getRaw());
229+
for (String output : outputs) {
230+
args.add(SafeEncoder.encode(output));
231+
}
232+
args.add(Keyword.BLOB.getRaw());
233+
args.add(blob);
234+
return args;
235+
}
236+
}

0 commit comments

Comments
 (0)