@@ -271,8 +271,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
271
271
options = NULL ;
272
272
TF_DeleteBuffer (tfbuffer );
273
273
tfbuffer = NULL ;
274
- TF_DeleteStatus (status );
275
- status = NULL ;
276
274
277
275
TF_Output tf_inputs [ninputs ];
278
276
TF_Output tf_outputs [noutputs ];
@@ -305,37 +303,37 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
305
303
noutputs , tf_outputs , // noutputs, outputs
306
304
outputs , // output_names,
307
305
NULL , // opts
308
- "" , // description
306
+ NULL , // description
309
307
status // status
310
308
);
311
- // TODO EAGER
312
- // check status and return error
309
+
310
+ if (TF_GetCode (status ) != TF_OK ) {
311
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
312
+ goto cleanup ;
313
+ }
313
314
314
315
TFE_ContextOptions * context_opts = TFE_NewContextOptions ();
315
316
// TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
316
317
// TFE_ContextOptionsSetAsync(context_opts, 0);
317
- TFE_ContextOptionsSetDevicePlacementPolicy (context_opts , TFE_DEVICE_PLACEMENT_EXPLICIT );
318
+ // TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
318
319
319
320
TFE_Context * context = TFE_NewContext (context_opts , status );
320
- // TODO EAGER
321
- // check status and return error
321
+ if (TF_GetCode (status ) != TF_OK ) {
322
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
323
+ goto cleanup ;
324
+ }
322
325
323
326
TFE_ContextAddFunction (context , function , status );
324
- // TODO EAGER
325
- // check status and return error
327
+ if (TF_GetCode (status ) != TF_OK ) {
328
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
329
+ goto cleanup ;
330
+ }
326
331
327
332
TFE_DeleteContextOptions (context_opts );
328
- TFE_DeleteContext (context );
329
333
330
- #if 0
331
- TF_Status * optionsStatus = NULL ;
332
- TF_SessionOptions * sessionOptions = NULL ;
333
- TF_Status * sessionStatus = NULL ;
334
- TF_Session * session = NULL ;
335
-
336
- optionsStatus = TF_NewStatus ();
337
- sessionOptions = TF_NewSessionOptions ();
334
+ TF_DeleteStatus (status );
338
335
336
+ #if 0
339
337
// For setting config options in session from the C API see:
340
338
// https://github.com/tensorflow/tensorflow/issues/13853
341
339
// import tensorflow as tf
@@ -390,16 +388,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
390
388
}
391
389
}
392
390
393
- if (TF_GetCode (optionsStatus ) != TF_OK ) {
394
- RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (optionsStatus )));
395
- goto cleanup ;
396
- }
397
- TF_DeleteStatus (optionsStatus );
398
- optionsStatus = NULL ;
399
-
400
- sessionStatus = TF_NewStatus ();
401
- session = TF_NewSession (graph , sessionOptions , sessionStatus );
402
-
403
391
TF_Status * deviceListStatus = TF_NewStatus ();
404
392
TF_DeviceList * deviceList = TF_SessionListDevices (session , deviceListStatus );
405
393
const int num_devices = TF_DeviceListCount (deviceList );
@@ -425,9 +413,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
425
413
RAI_SetError (error , RAI_EMODELCREATE , RedisModule_Strdup (TF_Message (status )));
426
414
goto cleanup ;
427
415
}
428
-
429
- TF_DeleteSessionOptions (sessionOptions );
430
- TF_DeleteStatus (sessionStatus );
431
416
#endif
432
417
433
418
char * * inputs_ = array_new (char * , ninputs );
@@ -467,33 +452,13 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
467
452
TF_DeleteBuffer (tfbuffer );
468
453
if (status )
469
454
TF_DeleteStatus (status );
470
- // if (sessionOptions)
471
- // TF_DeleteSessionOptions(sessionOptions);
472
- // if (sessionStatus)
473
- // TF_DeleteStatus(sessionStatus);
474
455
return NULL ;
475
456
}
476
457
477
458
void RAI_ModelFreeTF (RAI_Model * model , RAI_Error * error ) {
478
- TF_Status * status = TF_NewStatus ();
479
- #if 0
480
- TF_CloseSession (model -> session , status );
481
-
482
- if (TF_GetCode (status ) != TF_OK ) {
483
- RAI_SetError (error , RAI_EMODELFREE , RedisModule_Strdup (TF_Message (status )));
484
- return ;
485
- }
486
-
487
- TF_DeleteSession (model -> session , status );
488
- #endif
489
459
TFE_DeleteContext (model -> session );
490
460
model -> session = NULL ;
491
461
492
- // if (TF_GetCode(status) != TF_OK) {
493
- // RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
494
- // return;
495
- // }
496
-
497
462
TF_DeleteGraph (model -> model );
498
463
model -> model = NULL ;
499
464
@@ -518,10 +483,6 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
518
483
if (model -> data ) {
519
484
RedisModule_Free (model -> data );
520
485
}
521
-
522
- #if 0
523
- TF_DeleteStatus (status );
524
- #endif
525
486
}
526
487
527
488
int RAI_ModelRunTF (RAI_ModelRunCtx * * mctxs , RAI_Error * error ) {
@@ -562,24 +523,44 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
562
523
}
563
524
inputTensorsValues [i ] = RAI_TFTensorFromTensors (batched_input_tensors , nbatches );
564
525
inputTensorsHandles [i ] = TFE_NewTensorHandle (inputTensorsValues [i ], status );
565
- // TODO EAGER
566
- // check status and return error
526
+ if (TF_GetCode (status ) != TF_OK ) {
527
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
528
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
529
+ TF_DeleteStatus (status );
530
+ RedisModule_Free (errorMessage );
531
+ return 1 ;
532
+ }
567
533
}
568
534
569
535
TFE_Op * fn_op = TFE_NewOp (mctxs [0 ]-> model -> session , RAI_TF_FN_NAME , status );
570
- // TODO EAGER
571
- // check status and return error
536
+ if (TF_GetCode (status ) != TF_OK ) {
537
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
538
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
539
+ TF_DeleteStatus (status );
540
+ RedisModule_Free (errorMessage );
541
+ return 1 ;
542
+ }
572
543
573
544
TFE_OpAddInputList (fn_op , inputTensorsHandles , ninputs , status );
574
- // TODO EAGER
575
- // check status and return error
545
+ if (TF_GetCode (status ) != TF_OK ) {
546
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
547
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
548
+ TF_DeleteStatus (status );
549
+ RedisModule_Free (errorMessage );
550
+ return 1 ;
551
+ }
576
552
577
553
// TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
578
554
579
555
int noutputs_ = noutputs ;
580
556
TFE_Execute (fn_op , outputTensorsHandles , & noutputs_ , status );
581
- // TODO EAGER
582
- // check status and return error
557
+ if (TF_GetCode (status ) != TF_OK ) {
558
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
559
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
560
+ TF_DeleteStatus (status );
561
+ RedisModule_Free (errorMessage );
562
+ return 1 ;
563
+ }
583
564
584
565
for (size_t i = 0 ; i < ninputs ; ++ i ) {
585
566
TFE_DeleteTensorHandle (inputTensorsHandles [i ]);
0 commit comments