Skip to content

Commit edf0a94

Browse files
committed
Moving to eager mode
1 parent ffd4084 commit edf0a94

File tree

1 file changed

+121
-46
lines changed

1 file changed

+121
-46
lines changed

src/backends/tensorflow.c

Lines changed: 121 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include "model.h"
77

88
#include "tensorflow/c/c_api.h"
9+
#include "tensorflow/c/eager/c_api.h"
10+
11+
#define RAI_TF_FN_NAME "rai_tf_forward"
912

1013
int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
1114
get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc));
@@ -224,19 +227,15 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
224227
RAI_SetError(error, RAI_EMODELIMPORT, "ERR unsupported device");
225228
}
226229

227-
TF_Graph *model = TF_NewGraph();
230+
TF_Graph *graph = TF_NewGraph();
231+
TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
228232
TF_Status *status = TF_NewStatus();
229233
TF_Buffer *tfbuffer = TF_NewBuffer();
230-
TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
231-
TF_Status *optionsStatus = NULL;
232-
TF_SessionOptions *sessionOptions = NULL;
233-
TF_Status *sessionStatus = NULL;
234-
TF_Session *session = NULL;
235234

236235
tfbuffer->length = modellen;
237236
tfbuffer->data = modeldef;
238237

239-
TF_GraphImportGraphDef(model, tfbuffer, options, status);
238+
TF_GraphImportGraphDef(graph, tfbuffer, options, status);
240239

241240
if (TF_GetCode(status) != TF_OK) {
242241
char *errorMessage = RedisModule_Strdup(TF_Message(status));
@@ -246,26 +245,26 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
246245
}
247246

248247
for (size_t i = 0; i < ninputs; ++i) {
249-
TF_Operation *oper = TF_GraphOperationByName(model, inputs[i]);
248+
TF_Operation *oper = TF_GraphOperationByName(graph, inputs[i]);
250249
if (oper == NULL) {
251250
size_t len = strlen(inputs[i]);
252251
char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
253252
sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]);
254253
RAI_SetError(error, RAI_EMODELIMPORT, msg);
255254
RedisModule_Free(msg);
256-
goto cleanup;
255+
return NULL;
257256
}
258257
}
259258

260259
for (size_t i = 0; i < noutputs; ++i) {
261-
TF_Operation *oper = TF_GraphOperationByName(model, outputs[i]);
260+
TF_Operation *oper = TF_GraphOperationByName(graph, outputs[i]);
262261
if (oper == NULL) {
263262
size_t len = strlen(outputs[i]);
264263
char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
265264
sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", outputs[i]);
266265
RAI_SetError(error, RAI_EMODELIMPORT, msg);
267266
RedisModule_Free(msg);
268-
goto cleanup;
267+
return NULL;
269268
}
270269
}
271270

@@ -276,6 +275,65 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
276275
TF_DeleteStatus(status);
277276
status = NULL;
278277

278+
TF_Output tf_inputs[ninputs];
279+
TF_Output tf_outputs[noutputs];
280+
281+
for (size_t i = 0; i < ninputs; ++i) {
282+
TF_Output port;
283+
port.oper = TF_GraphOperationByName(graph, inputs[i]);
284+
port.index = 0;
285+
if (port.oper == NULL) {
286+
return NULL;
287+
}
288+
tf_inputs[i] = port;
289+
}
290+
291+
for (size_t i = 0; i < noutputs; ++i) {
292+
TF_Output port;
293+
port.oper = TF_GraphOperationByName(graph, outputs[i]);
294+
port.index = 0;
295+
if (port.oper == NULL) {
296+
return NULL;
297+
}
298+
tf_outputs[i] = port;
299+
}
300+
301+
TF_Function *function = TF_GraphToFunction(
302+
graph, // fn_body
303+
RAI_TF_FN_NAME, 0, // fn_name, append_hash_to_fn_name,
304+
-1, NULL, // num_opers, opers
305+
ninputs, tf_inputs, // ninputs, inputs,
306+
noutputs, tf_outputs, // noutputs, outputs
307+
outputs, // output_names,
308+
NULL, // opts
309+
"", // description
310+
status // status
311+
);
312+
// TODO EAGER
313+
// check status and return error
314+
315+
TFE_ContextOptions *context_opts = TFE_NewContextOptions();
316+
// TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
317+
// TFE_ContextOptionsSetAsync(context_opts, 0);
318+
TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
319+
320+
TFE_Context *context = TFE_NewContext(context_opts, status);
321+
// TODO EAGER
322+
// check status and return error
323+
324+
TFE_ContextAddFunction(context, function, status);
325+
// TODO EAGER
326+
// check status and return error
327+
328+
TFE_DeleteContextOptions(context_opts);
329+
TFE_DeleteContext(context);
330+
331+
#if 0
332+
TF_Status *optionsStatus = NULL;
333+
TF_SessionOptions *sessionOptions = NULL;
334+
TF_Status *sessionStatus = NULL;
335+
TF_Session *session = NULL;
336+
279337
optionsStatus = TF_NewStatus();
280338
sessionOptions = TF_NewSessionOptions();
281339

@@ -341,7 +399,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
341399
optionsStatus = NULL;
342400

343401
sessionStatus = TF_NewStatus();
344-
session = TF_NewSession(model, sessionOptions, sessionStatus);
402+
session = TF_NewSession(graph, sessionOptions, sessionStatus);
345403

346404
TF_Status *deviceListStatus = TF_NewStatus();
347405
TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus);
@@ -371,6 +429,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
371429

372430
TF_DeleteSessionOptions(sessionOptions);
373431
TF_DeleteStatus(sessionStatus);
432+
#endif
374433

375434
char **inputs_ = array_new(char *, ninputs);
376435
for (long long i = 0; i < ninputs; i++) {
@@ -386,8 +445,8 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
386445
memcpy(buffer, modeldef, modellen);
387446

388447
RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
389-
ret->model = model;
390-
ret->session = session;
448+
ret->model = graph;
449+
ret->session = context;
391450
ret->backend = backend;
392451
ret->devicestr = RedisModule_Strdup(devicestr);
393452
ret->ninputs = ninputs;
@@ -402,22 +461,23 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
402461
return ret;
403462

404463
cleanup:
405-
TF_DeleteGraph(model);
464+
TF_DeleteGraph(graph);
406465
if (options)
407466
TF_DeleteImportGraphDefOptions(options);
408467
if (tfbuffer)
409468
TF_DeleteBuffer(tfbuffer);
410469
if (status)
411470
TF_DeleteStatus(status);
412-
if (sessionOptions)
413-
TF_DeleteSessionOptions(sessionOptions);
414-
if (sessionStatus)
415-
TF_DeleteStatus(sessionStatus);
471+
// if (sessionOptions)
472+
// TF_DeleteSessionOptions(sessionOptions);
473+
// if (sessionStatus)
474+
// TF_DeleteStatus(sessionStatus);
416475
return NULL;
417476
}
418477

419478
void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
420479
TF_Status *status = TF_NewStatus();
480+
#if 0
421481
TF_CloseSession(model->session, status);
422482

423483
if (TF_GetCode(status) != TF_OK) {
@@ -426,12 +486,14 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
426486
}
427487

428488
TF_DeleteSession(model->session, status);
489+
#endif
490+
TFE_DeleteContext(model->session);
429491
model->session = NULL;
430492

431-
if (TF_GetCode(status) != TF_OK) {
432-
RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
433-
return;
434-
}
493+
// if (TF_GetCode(status) != TF_OK) {
494+
// RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
495+
// return;
496+
// }
435497

436498
TF_DeleteGraph(model->model);
437499
model->model = NULL;
@@ -458,7 +520,9 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
458520
RedisModule_Free(model->data);
459521
}
460522

523+
#if 0
461524
TF_DeleteStatus(status);
525+
#endif
462526
}
463527

464528
int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
@@ -473,9 +537,9 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
473537
const size_t ninputs = array_len(mctxs[0]->inputs);
474538
const size_t noutputs = array_len(mctxs[0]->outputs);
475539
TF_Tensor *inputTensorsValues[ninputs];
476-
TF_Output inputs[ninputs];
477540
TF_Tensor *outputTensorsValues[noutputs];
478-
TF_Output outputs[noutputs];
541+
TFE_TensorHandle *inputTensorsHandles[ninputs];
542+
TFE_TensorHandle *outputTensorsHandles[noutputs];
479543

480544
size_t batch_sizes[nbatches];
481545
size_t batch_offsets[nbatches];
@@ -498,30 +562,28 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
498562
batched_input_tensors[b] = mctxs[b]->inputs[i].tensor;
499563
}
500564
inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches);
501-
TF_Output port;
502-
port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->inputs[i].name);
503-
port.index = 0;
504-
if (port.oper == NULL) {
505-
return 1;
506-
}
507-
inputs[i] = port;
565+
inputTensorsHandles[i] = TFE_NewTensorHandle(inputTensorsValues[i], status);
566+
// TODO EAGER
567+
// check status and return error
508568
}
509569

510-
for (size_t i = 0; i < noutputs; ++i) {
511-
TF_Output port;
512-
port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->outputs[i].name);
513-
port.index = 0;
514-
if (port.oper == NULL) {
515-
return 1;
516-
}
517-
outputs[i] = port;
518-
}
570+
TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status);
571+
// TODO EAGER
572+
// check status and return error
519573

520-
TF_SessionRun(mctxs[0]->model->session, NULL /* run_options */, inputs, inputTensorsValues,
521-
ninputs, outputs, outputTensorsValues, noutputs, NULL /* target_opers */,
522-
0 /* ntargets */, NULL /* run_Metadata */, status);
574+
TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
575+
// TODO EAGER
576+
// check status and return error
577+
578+
// TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
579+
580+
int noutputs_ = noutputs;
581+
TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
582+
// TODO EAGER
583+
// check status and return error
523584

524585
for (size_t i = 0; i < ninputs; ++i) {
586+
TFE_DeleteTensorHandle(inputTensorsHandles[i]);
525587
TF_DeleteTensor(inputTensorsValues[i]);
526588
}
527589

@@ -533,13 +595,25 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
533595
return 1;
534596
}
535597

598+
for (size_t i = 0; i < noutputs; ++i) {
599+
outputTensorsValues[i] = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
600+
601+
if (TF_GetCode(status) != TF_OK) {
602+
char *errorMessage = RedisModule_Strdup(TF_Message(status));
603+
RAI_SetError(error, RAI_EMODELRUN, errorMessage);
604+
TF_DeleteStatus(status);
605+
RedisModule_Free(errorMessage);
606+
return 1;
607+
}
608+
}
609+
536610
for (size_t i = 0; i < noutputs; ++i) {
537611
if (nbatches > 1) {
538612
if (TF_NumDims(outputTensorsValues[i]) == 0) {
539613
continue;
540614
}
541615
if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) {
542-
TF_DeleteTensor(outputTensorsValues[i]);
616+
// TF_DeleteTensor(outputTensorsValues[i]);
543617
TF_DeleteStatus(status);
544618
RAI_SetError(error, RAI_EMODELRUN,
545619
"ERR Model did not generate the expected batch size");
@@ -554,7 +628,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
554628
mctxs[0]->outputs[i].tensor =
555629
RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
556630
}
557-
TF_DeleteTensor(outputTensorsValues[i]);
631+
// TF_DeleteTensor(outputTensorsValues[i]);
632+
TFE_DeleteTensorHandle(outputTensorsHandles[i]);
558633
}
559634

560635
TF_DeleteStatus(status);

0 commit comments

Comments
 (0)