#include "model.h"

#include "tensorflow/c/c_api.h"
+ #include "tensorflow/c/eager/c_api.h"
+
+ #define RAI_TF_FN_NAME "rai_tf_forward"

int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
    get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc));
@@ -223,19 +226,15 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
        RAI_SetError(error, RAI_EMODELIMPORT, "ERR unsupported device");
    }

-     TF_Graph *model = TF_NewGraph();
+     TF_Graph *graph = TF_NewGraph();
+     TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
    TF_Status *status = TF_NewStatus();
    TF_Buffer *tfbuffer = TF_NewBuffer();
-     TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
-     TF_Status *optionsStatus = NULL;
-     TF_SessionOptions *sessionOptions = NULL;
-     TF_Status *sessionStatus = NULL;
-     TF_Session *session = NULL;

    tfbuffer->length = modellen;
    tfbuffer->data = modeldef;

-     TF_GraphImportGraphDef(model, tfbuffer, options, status);
+     TF_GraphImportGraphDef(graph, tfbuffer, options, status);

    if (TF_GetCode(status) != TF_OK) {
        char *errorMessage = RedisModule_Strdup(TF_Message(status));
@@ -245,26 +244,26 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
    }

    for (size_t i = 0; i < ninputs; ++i) {
-         TF_Operation *oper = TF_GraphOperationByName(model, inputs[i]);
+         TF_Operation *oper = TF_GraphOperationByName(graph, inputs[i]);
        if (oper == NULL || strcmp(TF_OperationOpType(oper), "Placeholder") != 0) {
            size_t len = strlen(inputs[i]);
            char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
            sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]);
            RAI_SetError(error, RAI_EMODELIMPORT, msg);
            RedisModule_Free(msg);
-             goto cleanup;
+             return NULL;
        }
    }

    for (size_t i = 0; i < noutputs; ++i) {
-         TF_Operation *oper = TF_GraphOperationByName(model, outputs[i]);
+         TF_Operation *oper = TF_GraphOperationByName(graph, outputs[i]);
        if (oper == NULL) {
            size_t len = strlen(outputs[i]);
            char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
            sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", outputs[i]);
            RAI_SetError(error, RAI_EMODELIMPORT, msg);
            RedisModule_Free(msg);
-             goto cleanup;
+             return NULL;
        }
    }

@@ -275,6 +274,65 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
    TF_DeleteStatus(status);
    status = NULL;

+     TF_Output tf_inputs[ninputs];
+     TF_Output tf_outputs[noutputs];
+
+     for (size_t i = 0; i < ninputs; ++i) {
+         TF_Output port;
+         port.oper = TF_GraphOperationByName(graph, inputs[i]);
+         port.index = 0;
+         if (port.oper == NULL) {
+             return NULL;
+         }
+         tf_inputs[i] = port;
+     }
+
+     for (size_t i = 0; i < noutputs; ++i) {
+         TF_Output port;
+         port.oper = TF_GraphOperationByName(graph, outputs[i]);
+         port.index = 0;
+         if (port.oper == NULL) {
+             return NULL;
+         }
+         tf_outputs[i] = port;
+     }
+
+     TF_Function *function = TF_GraphToFunction(
+         graph,                // fn_body
+         RAI_TF_FN_NAME, 0,    // fn_name, append_hash_to_fn_name,
+         -1, NULL,             // num_opers, opers
+         ninputs, tf_inputs,   // ninputs, inputs,
+         noutputs, tf_outputs, // noutputs, outputs
+         outputs,              // output_names,
+         NULL,                 // opts
+         "",                   // description
+         status                // status
+     );
+     // TODO EAGER
+     // check status and return error
+
+     TFE_ContextOptions *context_opts = TFE_NewContextOptions();
+     // TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
+     // TFE_ContextOptionsSetAsync(context_opts, 0);
+     TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
+
+     TFE_Context *context = TFE_NewContext(context_opts, status);
+     // TODO EAGER
+     // check status and return error
+
+     TFE_ContextAddFunction(context, function, status);
+     // TODO EAGER
+     // check status and return error
+
+     TFE_DeleteContextOptions(context_opts);
+     TFE_DeleteContext(context);
+
+ #if 0
+     TF_Status *optionsStatus = NULL;
+     TF_SessionOptions *sessionOptions = NULL;
+     TF_Status *sessionStatus = NULL;
+     TF_Session *session = NULL;
+
    optionsStatus = TF_NewStatus();
    sessionOptions = TF_NewSessionOptions();

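The "TODO EAGER" placeholders above all ask for the same thing: check status after each eager call and bail out with an error. A minimal sketch of that check, reusing the RedisModule_Strdup / RAI_SetError pattern already used after TF_GraphImportGraphDef earlier in this function; the helper name is hypothetical, not part of this change:

// Hypothetical helper: report a failed TF call through RAI_Error, release
// the status, and tell the caller to abort. Mirrors the error handling used
// for the graph import above.
static int RAI_TFCheckStatus(TF_Status *status, RAI_Error *error) {
    if (TF_GetCode(status) != TF_OK) {
        char *errorMessage = RedisModule_Strdup(TF_Message(status));
        RAI_SetError(error, RAI_EMODELIMPORT, errorMessage);
        RedisModule_Free(errorMessage);
        TF_DeleteStatus(status);
        return 1;
    }
    return 0;
}

// Possible use after each eager call in the create path:
//     TFE_Context *context = TFE_NewContext(context_opts, status);
//     if (RAI_TFCheckStatus(status, error)) return NULL;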
@@ -340,7 +398,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
    optionsStatus = NULL;

    sessionStatus = TF_NewStatus();
-     session = TF_NewSession(model, sessionOptions, sessionStatus);
+     session = TF_NewSession(graph, sessionOptions, sessionStatus);

    TF_Status *deviceListStatus = TF_NewStatus();
    TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus);
@@ -370,6 +428,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod

    TF_DeleteSessionOptions(sessionOptions);
    TF_DeleteStatus(sessionStatus);
+ #endif

    char **inputs_ = array_new(char *, ninputs);
    for (long long i = 0; i < ninputs; i++) {
@@ -385,8 +444,8 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
    memcpy(buffer, modeldef, modellen);

    RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
-     ret->model = model;
-     ret->session = session;
+     ret->model = graph;
+     ret->session = context;
    ret->backend = backend;
    ret->devicestr = RedisModule_Strdup(devicestr);
    ret->ninputs = ninputs;
@@ -401,22 +460,23 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
    return ret;

cleanup:
-     TF_DeleteGraph(model);
+     TF_DeleteGraph(graph);
    if (options)
        TF_DeleteImportGraphDefOptions(options);
    if (tfbuffer)
        TF_DeleteBuffer(tfbuffer);
    if (status)
        TF_DeleteStatus(status);
-     if (sessionOptions)
-         TF_DeleteSessionOptions(sessionOptions);
-     if (sessionStatus)
-         TF_DeleteStatus(sessionStatus);
+     // if (sessionOptions)
+     //     TF_DeleteSessionOptions(sessionOptions);
+     // if (sessionStatus)
+     //     TF_DeleteStatus(sessionStatus);
    return NULL;
}

void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
    TF_Status *status = TF_NewStatus();
+ #if 0
    TF_CloseSession(model->session, status);

    if (TF_GetCode(status) != TF_OK) {
@@ -425,12 +485,14 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
    }

    TF_DeleteSession(model->session, status);
+ #endif
+     TFE_DeleteContext(model->session);
    model->session = NULL;

-     if (TF_GetCode(status) != TF_OK) {
-         RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
-         return;
-     }
+     // if (TF_GetCode(status) != TF_OK) {
+     //     RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
+     //     return;
+     // }

    TF_DeleteGraph(model->model);
    model->model = NULL;
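Since RAI_ModelCreateTF registers the graph function on the context via TFE_ContextAddFunction, the free path could also unregister it before dropping the context. A sketch of that shape, assuming model->session still holds the live TFE_Context; the explicit unregister step is an assumption, not something this change does:

// Sketch: drop the registered function, then the eager context itself.
TFE_ContextRemoveFunction((TFE_Context *)model->session, RAI_TF_FN_NAME, status);
TFE_DeleteContext((TFE_Context *)model->session);
model->session = NULL;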
@@ -457,7 +519,9 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
        RedisModule_Free(model->data);
    }

+ #if 0
    TF_DeleteStatus(status);
+ #endif
}

int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
@@ -472,9 +536,9 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
    const size_t ninputs = array_len(mctxs[0]->inputs);
    const size_t noutputs = array_len(mctxs[0]->outputs);
    TF_Tensor *inputTensorsValues[ninputs];
-     TF_Output inputs[ninputs];
    TF_Tensor *outputTensorsValues[noutputs];
-     TF_Output outputs[noutputs];
+     TFE_TensorHandle *inputTensorsHandles[ninputs];
+     TFE_TensorHandle *outputTensorsHandles[noutputs];

    size_t batch_sizes[nbatches];
    size_t batch_offsets[nbatches];
@@ -497,30 +561,28 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
            batched_input_tensors[b] = mctxs[b]->inputs[i].tensor;
        }
        inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches);
-         TF_Output port;
-         port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->inputs[i].name);
-         port.index = 0;
-         if (port.oper == NULL) {
-             return 1;
-         }
-         inputs[i] = port;
+         inputTensorsHandles[i] = TFE_NewTensorHandle(inputTensorsValues[i], status);
+         // TODO EAGER
+         // check status and return error
    }

-     for (size_t i = 0; i < noutputs; ++i) {
-         TF_Output port;
-         port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->outputs[i].name);
-         port.index = 0;
-         if (port.oper == NULL) {
-             return 1;
-         }
-         outputs[i] = port;
-     }
+     TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status);
+     // TODO EAGER
+     // check status and return error

-     TF_SessionRun(mctxs[0]->model->session, NULL /* run_options */, inputs, inputTensorsValues,
-                   ninputs, outputs, outputTensorsValues, noutputs, NULL /* target_opers */,
-                   0 /* ntargets */, NULL /* run_Metadata */, status);
+     TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
+     // TODO EAGER
+     // check status and return error
+
+     // TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
+
+     int noutputs_ = noutputs;
+     TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
+     // TODO EAGER
+     // check status and return error

    for (size_t i = 0; i < ninputs; ++i) {
+         TFE_DeleteTensorHandle(inputTensorsHandles[i]);
        TF_DeleteTensor(inputTensorsValues[i]);
    }

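The device-placement TODO above is the main gap left in this run path: with TFE_DEVICE_PLACEMENT_EXPLICIT, handles created from host TF_Tensors are not moved automatically, so the op and its inputs have to be pinned by hand. A sketch of how that could look with the eager C API; the device name string and the abbreviated error handling are assumptions, not part of this change:

// Sketch: pin the function op to a device and copy each input handle there
// before TFE_Execute, as explicit placement requires.
const char *device_name = "/device:GPU:0"; // assumed; would be derived from devicestr
TFE_OpSetDevice(fn_op, device_name, status);
// check status as above

for (size_t i = 0; i < ninputs; ++i) {
    TFE_TensorHandle *moved = TFE_TensorHandleCopyToDevice(
        inputTensorsHandles[i], mctxs[0]->model->session, device_name, status);
    // check status as above
    TFE_DeleteTensorHandle(inputTensorsHandles[i]);
    inputTensorsHandles[i] = moved;
}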
@@ -532,13 +594,25 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
        return 1;
    }

+     for (size_t i = 0; i < noutputs; ++i) {
+         outputTensorsValues[i] = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
+
+         if (TF_GetCode(status) != TF_OK) {
+             char *errorMessage = RedisModule_Strdup(TF_Message(status));
+             RAI_SetError(error, RAI_EMODELRUN, errorMessage);
+             TF_DeleteStatus(status);
+             RedisModule_Free(errorMessage);
+             return 1;
+         }
+     }
+
    for (size_t i = 0; i < noutputs; ++i) {
        if (nbatches > 1) {
            if (TF_NumDims(outputTensorsValues[i]) == 0) {
                continue;
            }
            if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) {
-                 TF_DeleteTensor(outputTensorsValues[i]);
+                 // TF_DeleteTensor(outputTensorsValues[i]);
                TF_DeleteStatus(status);
                RAI_SetError(error, RAI_EMODELRUN,
                             "ERR Model did not generate the expected batch size");
@@ -553,7 +627,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
            mctxs[0]->outputs[i].tensor =
                RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
        }
-         TF_DeleteTensor(outputTensorsValues[i]);
+         // TF_DeleteTensor(outputTensorsValues[i]);
+         TFE_DeleteTensorHandle(outputTensorsHandles[i]);
    }

    TF_DeleteStatus(status);