 #include "model.h"

 #include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+
+#define RAI_TF_FN_NAME "rai_tf_forward"

 int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
     get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc));
@@ -224,19 +227,15 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
         RAI_SetError(error, RAI_EMODELIMPORT, "ERR unsupported device");
     }

-    TF_Graph *model = TF_NewGraph();
+    TF_Graph *graph = TF_NewGraph();
+    TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
     TF_Status *status = TF_NewStatus();
     TF_Buffer *tfbuffer = TF_NewBuffer();
-    TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
-    TF_Status *optionsStatus = NULL;
-    TF_SessionOptions *sessionOptions = NULL;
-    TF_Status *sessionStatus = NULL;
-    TF_Session *session = NULL;

     tfbuffer->length = modellen;
     tfbuffer->data = modeldef;

-    TF_GraphImportGraphDef(model, tfbuffer, options, status);
+    TF_GraphImportGraphDef(graph, tfbuffer, options, status);

     if (TF_GetCode(status) != TF_OK) {
         char *errorMessage = RedisModule_Strdup(TF_Message(status));
@@ -246,26 +245,26 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     }

     for (size_t i = 0; i < ninputs; ++i) {
-        TF_Operation *oper = TF_GraphOperationByName(model, inputs[i]);
+        TF_Operation *oper = TF_GraphOperationByName(graph, inputs[i]);
         if (oper == NULL) {
             size_t len = strlen(inputs[i]);
             char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
             sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]);
             RAI_SetError(error, RAI_EMODELIMPORT, msg);
             RedisModule_Free(msg);
-            goto cleanup;
+            return NULL;
         }
     }

     for (size_t i = 0; i < noutputs; ++i) {
-        TF_Operation *oper = TF_GraphOperationByName(model, outputs[i]);
+        TF_Operation *oper = TF_GraphOperationByName(graph, outputs[i]);
         if (oper == NULL) {
             size_t len = strlen(outputs[i]);
             char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
             sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", outputs[i]);
             RAI_SetError(error, RAI_EMODELIMPORT, msg);
             RedisModule_Free(msg);
-            goto cleanup;
+            return NULL;
         }
     }

@@ -276,6 +275,65 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     TF_DeleteStatus(status);
     status = NULL;

+    TF_Output tf_inputs[ninputs];
+    TF_Output tf_outputs[noutputs];
+
+    for (size_t i = 0; i < ninputs; ++i) {
+        TF_Output port;
+        port.oper = TF_GraphOperationByName(graph, inputs[i]);
+        port.index = 0;
+        if (port.oper == NULL) {
+            return NULL;
+        }
+        tf_inputs[i] = port;
+    }
+
+    for (size_t i = 0; i < noutputs; ++i) {
+        TF_Output port;
+        port.oper = TF_GraphOperationByName(graph, outputs[i]);
+        port.index = 0;
+        if (port.oper == NULL) {
+            return NULL;
+        }
+        tf_outputs[i] = port;
+    }
+
+    TF_Function *function = TF_GraphToFunction(
+        graph,                 // fn_body
+        RAI_TF_FN_NAME, 0,     // fn_name, append_hash_to_fn_name
+        -1, NULL,              // num_opers, opers
+        ninputs, tf_inputs,    // ninputs, inputs
+        noutputs, tf_outputs,  // noutputs, outputs
+        outputs,               // output_names
+        NULL,                  // opts
+        "",                    // description
+        status                 // status
+    );
+    // TODO EAGER
+    // check status and return error
+
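One way to fill in the "check status and return error" TODO above is to mirror the handling already used for TF_GraphImportGraphDef earlier in this function. This is only a sketch, not part of the diff: it assumes `status` is recreated before the TF_GraphToFunction call (the code above deletes it and sets it to NULL) and reuses the RAI_EMODELIMPORT error code that this file already uses.

    status = TF_NewStatus();  // `status` was set to NULL above, so a live one is needed first
    /* ... TF_GraphToFunction(..., status) as in the diff ... */
    if (TF_GetCode(status) != TF_OK) {
        char *errorMessage = RedisModule_Strdup(TF_Message(status));
        RAI_SetError(error, RAI_EMODELIMPORT, errorMessage);
        RedisModule_Free(errorMessage);
        TF_DeleteGraph(graph);
        TF_DeleteStatus(status);
        return NULL;
    }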
+    TFE_ContextOptions *context_opts = TFE_NewContextOptions();
+    // TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
+    // TFE_ContextOptionsSetAsync(context_opts, 0);
+    TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
+
+    TFE_Context *context = TFE_NewContext(context_opts, status);
+    // TODO EAGER
+    // check status and return error
+
+    TFE_ContextAddFunction(context, function, status);
+    // TODO EAGER
+    // check status and return error
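The TF_Function handle returned by TF_GraphToFunction is never released in this change. TFE_ContextAddFunction copies the function definition into the context, so once registration succeeds the handle can normally be dropped; a sketch under that assumption:

    TFE_ContextAddFunction(context, function, status);
    if (TF_GetCode(status) == TF_OK) {
        TF_DeleteFunction(function);  // the context keeps its own copy of the FunctionDef
    }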
+
+    TFE_DeleteContextOptions(context_opts);
+    // keep `context` alive: it is stored in ret->session below and used at run time
+
+#if 0
+    TF_Status *optionsStatus = NULL;
+    TF_SessionOptions *sessionOptions = NULL;
+    TF_Status *sessionStatus = NULL;
+    TF_Session *session = NULL;
+
     optionsStatus = TF_NewStatus();
     sessionOptions = TF_NewSessionOptions();

@@ -341,7 +399,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     optionsStatus = NULL;

     sessionStatus = TF_NewStatus();
-    session = TF_NewSession(model, sessionOptions, sessionStatus);
+    session = TF_NewSession(graph, sessionOptions, sessionStatus);

     TF_Status *deviceListStatus = TF_NewStatus();
     TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus);
@@ -371,6 +429,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod

     TF_DeleteSessionOptions(sessionOptions);
     TF_DeleteStatus(sessionStatus);
+#endif

     char **inputs_ = array_new(char *, ninputs);
     for (long long i = 0; i < ninputs; i++) {
@@ -386,8 +445,8 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     memcpy(buffer, modeldef, modellen);

     RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
-    ret->model = model;
-    ret->session = session;
+    ret->model = graph;
+    ret->session = context;
     ret->backend = backend;
     ret->devicestr = RedisModule_Strdup(devicestr);
     ret->ninputs = ninputs;
@@ -402,22 +461,23 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     return ret;

 cleanup:
-    TF_DeleteGraph(model);
+    TF_DeleteGraph(graph);
     if (options)
         TF_DeleteImportGraphDefOptions(options);
     if (tfbuffer)
         TF_DeleteBuffer(tfbuffer);
     if (status)
         TF_DeleteStatus(status);
-    if (sessionOptions)
-        TF_DeleteSessionOptions(sessionOptions);
-    if (sessionStatus)
-        TF_DeleteStatus(sessionStatus);
+    // if (sessionOptions)
+    //     TF_DeleteSessionOptions(sessionOptions);
+    // if (sessionStatus)
+    //     TF_DeleteStatus(sessionStatus);
     return NULL;
 }

 void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
     TF_Status *status = TF_NewStatus();
+#if 0
     TF_CloseSession(model->session, status);

     if (TF_GetCode(status) != TF_OK) {
@@ -426,12 +486,14 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
     }

     TF_DeleteSession(model->session, status);
+#endif
+    TFE_DeleteContext(model->session);
     model->session = NULL;

-    if (TF_GetCode(status) != TF_OK) {
-        RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
-        return;
-    }
+    // if (TF_GetCode(status) != TF_OK) {
+    //     RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
+    //     return;
+    // }

     TF_DeleteGraph(model->model);
     model->model = NULL;
@@ -458,7 +520,9 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
         RedisModule_Free(model->data);
     }

+#if 0
     TF_DeleteStatus(status);
+#endif
 }

 int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
@@ -473,9 +537,9 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
     const size_t ninputs = array_len(mctxs[0]->inputs);
     const size_t noutputs = array_len(mctxs[0]->outputs);
     TF_Tensor *inputTensorsValues[ninputs];
-    TF_Output inputs[ninputs];
     TF_Tensor *outputTensorsValues[noutputs];
-    TF_Output outputs[noutputs];
+    TFE_TensorHandle *inputTensorsHandles[ninputs];
+    TFE_TensorHandle *outputTensorsHandles[noutputs];

     size_t batch_sizes[nbatches];
     size_t batch_offsets[nbatches];
@@ -498,30 +562,28 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
             batched_input_tensors[b] = mctxs[b]->inputs[i].tensor;
         }
         inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches);
-        TF_Output port;
-        port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->inputs[i].name);
-        port.index = 0;
-        if (port.oper == NULL) {
-            return 1;
-        }
-        inputs[i] = port;
+        inputTensorsHandles[i] = TFE_NewTensorHandle(inputTensorsValues[i], status);
+        // TODO EAGER
+        // check status and return error
     }

-    for (size_t i = 0; i < noutputs; ++i) {
-        TF_Output port;
-        port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->outputs[i].name);
-        port.index = 0;
-        if (port.oper == NULL) {
-            return 1;
-        }
-        outputs[i] = port;
-    }
+    TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status);
+    // TODO EAGER
+    // check status and return error

-    TF_SessionRun(mctxs[0]->model->session, NULL /* run_options */, inputs, inputTensorsValues,
-                  ninputs, outputs, outputTensorsValues, noutputs, NULL /* target_opers */,
-                  0 /* ntargets */, NULL /* run_Metadata */, status);
+    TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
+    // TODO EAGER
+    // check status and return error
+
+    // TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
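With TFE_DEVICE_PLACEMENT_EXPLICIT, handles created from host TF_Tensors have to be copied to the target device before they are fed to the op, and the op itself has to be pinned to that device, so the copy would need to happen before the TFE_OpAddInputList call above. A sketch only: the device string is a placeholder, and mapping RedisAI's `devicestr` ("CPU", "GPU:0", ...) to a full TF device name is left out.

    const char *tf_device = "/job:localhost/replica:0/task:0/device:GPU:0";  // placeholder

    TFE_OpSetDevice(fn_op, tf_device, status);
    for (size_t i = 0; i < ninputs; ++i) {
        TFE_TensorHandle *on_device = TFE_TensorHandleCopyToDevice(
            inputTensorsHandles[i], mctxs[0]->model->session, tf_device, status);
        if (TF_GetCode(status) != TF_OK) {
            return 1;  // set RAI_EMODELRUN on `error` first, as in the checks below
        }
        TFE_DeleteTensorHandle(inputTensorsHandles[i]);
        inputTensorsHandles[i] = on_device;
    }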
+
+    int noutputs_ = noutputs;
+    TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
+    // TODO EAGER
+    // check status and return error

     for (size_t i = 0; i < ninputs; ++i) {
+        TFE_DeleteTensorHandle(inputTensorsHandles[i]);
         TF_DeleteTensor(inputTensorsValues[i]);
     }

@@ -533,13 +595,25 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
         return 1;
     }

+    for (size_t i = 0; i < noutputs; ++i) {
+        outputTensorsValues[i] = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
+
+        if (TF_GetCode(status) != TF_OK) {
+            char *errorMessage = RedisModule_Strdup(TF_Message(status));
+            RAI_SetError(error, RAI_EMODELRUN, errorMessage);
+            TF_DeleteStatus(status);
+            RedisModule_Free(errorMessage);
+            return 1;
+        }
+    }
+
     for (size_t i = 0; i < noutputs; ++i) {
         if (nbatches > 1) {
             if (TF_NumDims(outputTensorsValues[i]) == 0) {
                 continue;
             }
             if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) {
-                TF_DeleteTensor(outputTensorsValues[i]);
+                // TF_DeleteTensor(outputTensorsValues[i]);
                 TF_DeleteStatus(status);
                 RAI_SetError(error, RAI_EMODELRUN,
                              "ERR Model did not generate the expected batch size");
@@ -554,7 +628,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
             mctxs[0]->outputs[i].tensor =
                 RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
         }
-        TF_DeleteTensor(outputTensorsValues[i]);
+        // TF_DeleteTensor(outputTensorsValues[i]);
+        TFE_DeleteTensorHandle(outputTensorsHandles[i]);
     }

     TF_DeleteStatus(status);