Skip to content

Commit bc6624a

Browse files
committed
Moving to eager mode
1 parent adcd882 commit bc6624a

File tree

1 file changed

+121
-46
lines changed

1 file changed

+121
-46
lines changed

src/backends/tensorflow.c

Lines changed: 121 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,9 @@
66
#include "model.h"
77

88
#include "tensorflow/c/c_api.h"
9+
#include "tensorflow/c/eager/c_api.h"
10+
11+
#define RAI_TF_FN_NAME "rai_tf_forward"
912

1013
int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
1114
get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc));
@@ -223,19 +226,15 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
223226
RAI_SetError(error, RAI_EMODELIMPORT, "ERR unsupported device");
224227
}
225228

226-
TF_Graph *model = TF_NewGraph();
229+
TF_Graph *graph = TF_NewGraph();
230+
TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
227231
TF_Status *status = TF_NewStatus();
228232
TF_Buffer *tfbuffer = TF_NewBuffer();
229-
TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
230-
TF_Status *optionsStatus = NULL;
231-
TF_SessionOptions *sessionOptions = NULL;
232-
TF_Status *sessionStatus = NULL;
233-
TF_Session *session = NULL;
234233

235234
tfbuffer->length = modellen;
236235
tfbuffer->data = modeldef;
237236

238-
TF_GraphImportGraphDef(model, tfbuffer, options, status);
237+
TF_GraphImportGraphDef(graph, tfbuffer, options, status);
239238

240239
if (TF_GetCode(status) != TF_OK) {
241240
char *errorMessage = RedisModule_Strdup(TF_Message(status));
@@ -245,26 +244,26 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
245244
}
246245

247246
for (size_t i = 0; i < ninputs; ++i) {
248-
TF_Operation *oper = TF_GraphOperationByName(model, inputs[i]);
247+
TF_Operation *oper = TF_GraphOperationByName(graph, inputs[i]);
249248
if (oper == NULL || strcmp(TF_OperationOpType(oper), "Placeholder") != 0) {
250249
size_t len = strlen(inputs[i]);
251250
char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
252251
sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]);
253252
RAI_SetError(error, RAI_EMODELIMPORT, msg);
254253
RedisModule_Free(msg);
255-
goto cleanup;
254+
return NULL;
256255
}
257256
}
258257

259258
for (size_t i = 0; i < noutputs; ++i) {
260-
TF_Operation *oper = TF_GraphOperationByName(model, outputs[i]);
259+
TF_Operation *oper = TF_GraphOperationByName(graph, outputs[i]);
261260
if (oper == NULL) {
262261
size_t len = strlen(outputs[i]);
263262
char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
264263
sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", outputs[i]);
265264
RAI_SetError(error, RAI_EMODELIMPORT, msg);
266265
RedisModule_Free(msg);
267-
goto cleanup;
266+
return NULL;
268267
}
269268
}
270269

@@ -275,6 +274,65 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
275274
TF_DeleteStatus(status);
276275
status = NULL;
277276

277+
TF_Output tf_inputs[ninputs];
278+
TF_Output tf_outputs[noutputs];
279+
280+
for (size_t i = 0; i < ninputs; ++i) {
281+
TF_Output port;
282+
port.oper = TF_GraphOperationByName(graph, inputs[i]);
283+
port.index = 0;
284+
if (port.oper == NULL) {
285+
return NULL;
286+
}
287+
tf_inputs[i] = port;
288+
}
289+
290+
for (size_t i = 0; i < noutputs; ++i) {
291+
TF_Output port;
292+
port.oper = TF_GraphOperationByName(graph, outputs[i]);
293+
port.index = 0;
294+
if (port.oper == NULL) {
295+
return NULL;
296+
}
297+
tf_outputs[i] = port;
298+
}
299+
300+
TF_Function *function = TF_GraphToFunction(
301+
graph, // fn_body
302+
RAI_TF_FN_NAME, 0, // fn_name, append_hash_to_fn_name,
303+
-1, NULL, // num_opers, opers
304+
ninputs, tf_inputs, // ninputs, inputs,
305+
noutputs, tf_outputs, // noutputs, outputs
306+
outputs, // output_names,
307+
NULL, // opts
308+
"", // description
309+
status // status
310+
);
311+
// TODO EAGER
312+
// check status and return error
313+
314+
TFE_ContextOptions *context_opts = TFE_NewContextOptions();
315+
// TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
316+
// TFE_ContextOptionsSetAsync(context_opts, 0);
317+
TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
318+
319+
TFE_Context *context = TFE_NewContext(context_opts, status);
320+
// TODO EAGER
321+
// check status and return error
322+
323+
TFE_ContextAddFunction(context, function, status);
324+
// TODO EAGER
325+
// check status and return error
326+
327+
TFE_DeleteContextOptions(context_opts);
328+
TFE_DeleteContext(context);
329+
330+
#if 0
331+
TF_Status *optionsStatus = NULL;
332+
TF_SessionOptions *sessionOptions = NULL;
333+
TF_Status *sessionStatus = NULL;
334+
TF_Session *session = NULL;
335+
278336
optionsStatus = TF_NewStatus();
279337
sessionOptions = TF_NewSessionOptions();
280338

@@ -340,7 +398,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
340398
optionsStatus = NULL;
341399

342400
sessionStatus = TF_NewStatus();
343-
session = TF_NewSession(model, sessionOptions, sessionStatus);
401+
session = TF_NewSession(graph, sessionOptions, sessionStatus);
344402

345403
TF_Status *deviceListStatus = TF_NewStatus();
346404
TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus);
@@ -370,6 +428,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
370428

371429
TF_DeleteSessionOptions(sessionOptions);
372430
TF_DeleteStatus(sessionStatus);
431+
#endif
373432

374433
char **inputs_ = array_new(char *, ninputs);
375434
for (long long i = 0; i < ninputs; i++) {
@@ -385,8 +444,8 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
385444
memcpy(buffer, modeldef, modellen);
386445

387446
RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
388-
ret->model = model;
389-
ret->session = session;
447+
ret->model = graph;
448+
ret->session = context;
390449
ret->backend = backend;
391450
ret->devicestr = RedisModule_Strdup(devicestr);
392451
ret->ninputs = ninputs;
@@ -401,22 +460,23 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
401460
return ret;
402461

403462
cleanup:
404-
TF_DeleteGraph(model);
463+
TF_DeleteGraph(graph);
405464
if (options)
406465
TF_DeleteImportGraphDefOptions(options);
407466
if (tfbuffer)
408467
TF_DeleteBuffer(tfbuffer);
409468
if (status)
410469
TF_DeleteStatus(status);
411-
if (sessionOptions)
412-
TF_DeleteSessionOptions(sessionOptions);
413-
if (sessionStatus)
414-
TF_DeleteStatus(sessionStatus);
470+
// if (sessionOptions)
471+
// TF_DeleteSessionOptions(sessionOptions);
472+
// if (sessionStatus)
473+
// TF_DeleteStatus(sessionStatus);
415474
return NULL;
416475
}
417476

418477
void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
419478
TF_Status *status = TF_NewStatus();
479+
#if 0
420480
TF_CloseSession(model->session, status);
421481

422482
if (TF_GetCode(status) != TF_OK) {
@@ -425,12 +485,14 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
425485
}
426486

427487
TF_DeleteSession(model->session, status);
488+
#endif
489+
TFE_DeleteContext(model->session);
428490
model->session = NULL;
429491

430-
if (TF_GetCode(status) != TF_OK) {
431-
RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
432-
return;
433-
}
492+
// if (TF_GetCode(status) != TF_OK) {
493+
// RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
494+
// return;
495+
// }
434496

435497
TF_DeleteGraph(model->model);
436498
model->model = NULL;
@@ -457,7 +519,9 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
457519
RedisModule_Free(model->data);
458520
}
459521

522+
#if 0
460523
TF_DeleteStatus(status);
524+
#endif
461525
}
462526

463527
int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
@@ -472,9 +536,9 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
472536
const size_t ninputs = array_len(mctxs[0]->inputs);
473537
const size_t noutputs = array_len(mctxs[0]->outputs);
474538
TF_Tensor *inputTensorsValues[ninputs];
475-
TF_Output inputs[ninputs];
476539
TF_Tensor *outputTensorsValues[noutputs];
477-
TF_Output outputs[noutputs];
540+
TFE_TensorHandle *inputTensorsHandles[ninputs];
541+
TFE_TensorHandle *outputTensorsHandles[noutputs];
478542

479543
size_t batch_sizes[nbatches];
480544
size_t batch_offsets[nbatches];
@@ -497,30 +561,28 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
497561
batched_input_tensors[b] = mctxs[b]->inputs[i].tensor;
498562
}
499563
inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches);
500-
TF_Output port;
501-
port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->inputs[i].name);
502-
port.index = 0;
503-
if (port.oper == NULL) {
504-
return 1;
505-
}
506-
inputs[i] = port;
564+
inputTensorsHandles[i] = TFE_NewTensorHandle(inputTensorsValues[i], status);
565+
// TODO EAGER
566+
// check status and return error
507567
}
508568

509-
for (size_t i = 0; i < noutputs; ++i) {
510-
TF_Output port;
511-
port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->outputs[i].name);
512-
port.index = 0;
513-
if (port.oper == NULL) {
514-
return 1;
515-
}
516-
outputs[i] = port;
517-
}
569+
TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status);
570+
// TODO EAGER
571+
// check status and return error
518572

519-
TF_SessionRun(mctxs[0]->model->session, NULL /* run_options */, inputs, inputTensorsValues,
520-
ninputs, outputs, outputTensorsValues, noutputs, NULL /* target_opers */,
521-
0 /* ntargets */, NULL /* run_Metadata */, status);
573+
TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
574+
// TODO EAGER
575+
// check status and return error
576+
577+
// TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
578+
579+
int noutputs_ = noutputs;
580+
TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
581+
// TODO EAGER
582+
// check status and return error
522583

523584
for (size_t i = 0; i < ninputs; ++i) {
585+
TFE_DeleteTensorHandle(inputTensorsHandles[i]);
524586
TF_DeleteTensor(inputTensorsValues[i]);
525587
}
526588

@@ -532,13 +594,25 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
532594
return 1;
533595
}
534596

597+
for (size_t i = 0; i < noutputs; ++i) {
598+
outputTensorsValues[i] = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
599+
600+
if (TF_GetCode(status) != TF_OK) {
601+
char *errorMessage = RedisModule_Strdup(TF_Message(status));
602+
RAI_SetError(error, RAI_EMODELRUN, errorMessage);
603+
TF_DeleteStatus(status);
604+
RedisModule_Free(errorMessage);
605+
return 1;
606+
}
607+
}
608+
535609
for (size_t i = 0; i < noutputs; ++i) {
536610
if (nbatches > 1) {
537611
if (TF_NumDims(outputTensorsValues[i]) == 0) {
538612
continue;
539613
}
540614
if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) {
541-
TF_DeleteTensor(outputTensorsValues[i]);
615+
// TF_DeleteTensor(outputTensorsValues[i]);
542616
TF_DeleteStatus(status);
543617
RAI_SetError(error, RAI_EMODELRUN,
544618
"ERR Model did not generate the expected batch size");
@@ -553,7 +627,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
553627
mctxs[0]->outputs[i].tensor =
554628
RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
555629
}
556-
TF_DeleteTensor(outputTensorsValues[i]);
630+
// TF_DeleteTensor(outputTensorsValues[i]);
631+
TFE_DeleteTensorHandle(outputTensorsHandles[i]);
557632
}
558633

559634
TF_DeleteStatus(status);

0 commit comments

Comments
 (0)