@@ -28,20 +28,17 @@ int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
     return REDISMODULE_OK;
 }
 
-// Managing context for the DLManagedTensor, will manage the lifetime of
-// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
-// original framework of destruction, and this context will be deleted also.
-struct TfDlManagedTensorCtx {
+struct TFDLManagedTensorCtx {
     TFE_TensorHandle *reference;
     int64_t ndim;
     int64_t *shape;
     int64_t *strides;
     DLManagedTensor tensor;
 };
-typedef struct TfDlManagedTensorCtx TfDlManagedTensorCtx;
+typedef struct TFDLManagedTensorCtx TFDLManagedTensorCtx;
 
-TfDlManagedTensorCtx *TfDlManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) {
-    TfDlManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TfDlManagedTensorCtx));
+TFDLManagedTensorCtx *TFDLManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) {
+    TFDLManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TFDLManagedTensorCtx));
     ctx->ndim = TFE_TensorHandleNumDims(h, status);
     ctx->shape = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
     ctx->strides = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
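
Note: the struct above is the DLPack manager context. Per the comment this hunk removes, the embedded DLManagedTensor points back to it through manager_ctx, and the consumer's single call to the tensor's deleter is what releases both the TFE handle and the context. A minimal sketch of the consumer side, using a hypothetical consume_and_release helper that is not part of this patch:

void consume_and_release(DLManagedTensor *dlmt) {
    /* ... read dlmt->dl_tensor.data, shape and strides here ... */
    if (dlmt->deleter != NULL) {
        /* Invokes DLManagedTensorDeleter below, which deletes the
         * TFE_TensorHandle held in ctx->reference and frees the ctx. */
        dlmt->deleter(dlmt);
    }
}
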
@@ -55,23 +52,19 @@ TfDlManagedTensorCtx *TfDlManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status
     return ctx;
 }
 
-void TfDlManagedTensorCtx_Free(TfDlManagedTensorCtx *ctx) {
+void TFDLManagedTensorCtx_Free(TFDLManagedTensorCtx *ctx) {
     RedisModule_Free(ctx->shape);
     RedisModule_Free(ctx->strides);
     RedisModule_Free(ctx);
 }
 
-// Deleter for DLManagedTensor
 void DLManagedTensorDeleter(DLManagedTensor *arg) {
-    TfDlManagedTensorCtx *owner = (TfDlManagedTensorCtx *)(arg->manager_ctx);
-
-    // TODO: check if we need to deleted the actual tensor as well
+    TFDLManagedTensorCtx *owner = (TFDLManagedTensorCtx *)(arg->manager_ctx);
     TFE_DeleteTensorHandle(owner->reference);
-    TfDlManagedTensorCtx_Free(owner);
+    TFDLManagedTensorCtx_Free(owner);
 }
 
-// Converts TF_DATAType to DLPack data type.
-DLDataType GetDlDataType(TF_DataType data_type, TF_Status *status) {
+DLDataType GetDLDataType(TF_DataType data_type, TF_Status *status) {
     DLDataType dtype;
     dtype.lanes = 1;
     dtype.bits = TF_DataTypeSize(data_type) * 8;
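
For orientation: assuming the switch over data_type (in the unchanged lines below this hunk) follows TensorFlow's upstream DLPack implementation, a TF_FLOAT handle yields a 32-bit, single-lane float dtype, since TF_DataTypeSize(TF_FLOAT) is 4. An illustrative call:

TF_Status *s = TF_NewStatus();
DLDataType dt = GetDLDataType(TF_FLOAT, s);
/* expected: dt.bits == 32, dt.lanes == 1, dt.code == kDLFloat */
TF_DeleteStatus(s);
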
@@ -104,8 +97,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status *status) {
     return dtype;
 }
 
-// Gets DLPack's DLDevice from eager tensor handle.
-DLDevice GetDlDevice(TFE_TensorHandle *h, TF_Status *status) {
+DLDevice GetDLDevice(TFE_TensorHandle *h, TF_Status *status) {
     DLDevice device;
     const char *device_name = TFE_TensorHandleBackingDeviceName(h, status);
 
@@ -135,8 +127,7 @@ DLDevice GetDlDevice(TFE_TensorHandle *h, TF_Status *status) {
     return device;
 }
 
-// Converts DLContext to TF device name.
-int DeviceNameFromDlContext(const DLDevice *device, char device_name[64]) {
+int DeviceNameFromDLContext(const DLDevice *device, char device_name[64]) {
     switch (device->device_type) {
     case kDLCPU:
         strcpy(device_name, "CPU:0");
@@ -148,8 +139,7 @@ int DeviceNameFromDlContext(const DLDevice *device, char device_name[64]) {
     return 1;
 }
 
-// Converts DLPack data type to TF_DATATYPE.
-int TfDataTypeFromDlDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
+int TFDataTypeFromDLDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
     switch (dtype->code) {
     case kDLUInt:
         switch (dtype->bits) {
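
The reverse mapping picks a TF_DataType for a DLPack dtype and reports unsupported combinations through its return code. A hedged example, assuming kDLFloat/32 is handled the same way as the kDLUInt cases visible above:

DLDataType dl_dtype = {.code = kDLFloat, .bits = 32, .lanes = 1};
TF_DataType tf_dtype;
int rc = TFDataTypeFromDLDataType(&dl_dtype, &tf_dtype);
/* rc == 0 and tf_dtype == TF_FLOAT on success; rc != 0 means the DLPack
 * dtype has no TF_DataType equivalent. */
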
@@ -216,14 +206,10 @@ int TfDataTypeFromDlDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
     }
 }
 
-// Wraps the deleter function of DLManagedTensor to match the function signature
-// TFE_NewTensorHandleFromDeviceMemory.
 void DeallocatorWrapperFunc(void *data, size_t len, void *dlmt_vptr) {
     TFE_CallDLManagedTensorDeleter(dlmt_vptr);
 }
 
-// Checks whether the stride array matches the layout of compact, row-majored
-// data.
 bool IsValidStrideCompactRowMajorData(int64_t *shape_arr, int64_t *stride_arr, int ndim) {
     if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
         return false;
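
The check above (its body continues past this hunk) accepts only strides describing a compact row-major layout: the innermost stride must be 1, and each outer stride must equal the product of the inner dimensions. Illustrative values, not taken from the patch:

int64_t shape[3] = {2, 3, 4};
int64_t compact[3] = {12, 4, 1}; /* 3*4, 4, 1: valid compact row-major */
int64_t sliced[3] = {24, 8, 2};  /* innermost stride != 1: rejected */
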
@@ -244,7 +230,7 @@ void TFE_CallDLManagedTensorDeleter(void *dlm_ptr) {
 }
 
 void *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status) {
-    DLDevice tf_dlm_device = GetDlDevice(h, status);
+    DLDevice tf_dlm_device = GetDLDevice(h, status);
     if (TF_GetCode(status) != TF_OK) {
         return NULL;
     }
@@ -256,12 +242,12 @@ void *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status) {
 
     TF_DataType data_type = TFE_TensorHandleDataType(h);
 
-    DLDataType tf_dlm_type = GetDlDataType(data_type, status);
+    DLDataType tf_dlm_type = GetDLDataType(data_type, status);
     if (TF_GetCode(status) != TF_OK) {
         return NULL;
     }
 
-    TfDlManagedTensorCtx *tf_dlm_tensor_ctx = TfDlManagedTensorCtx_Create(h, status);
+    TFDLManagedTensorCtx *tf_dlm_tensor_ctx = TFDLManagedTensorCtx_Create(h, status);
 
     DLManagedTensor *dlm_tensor = &tf_dlm_tensor_ctx->tensor;
     dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
@@ -287,15 +273,15 @@ TFE_TensorHandle *TFE_HandleFromDLPack(void *dlm, TF_Status *status, TFE_Context
     DLManagedTensor *dlmt = (DLManagedTensor *)dlm;
     DLTensor *dl_tensor = &dlmt->dl_tensor;
     char device_name[64];
-    int ret = DeviceNameFromDlContext(&dl_tensor->device, device_name);
+    int ret = DeviceNameFromDLContext(&dl_tensor->device, device_name);
     if (ret != 0) {
-        // tensorflow::errors::InvalidArgument(" Unsupported Device Type");
+        // TODO Unsupported device type
         return NULL;
     }
     TF_DataType dtype;
-    ret = TfDataTypeFromDlDataType(&dl_tensor->dtype, &dtype);
+    ret = TFDataTypeFromDLDataType(&dl_tensor->dtype, &dtype);
     if (ret != 0) {
-        // status->status = std::move(s);
+        // TODO Unsupported data type
         return NULL;
     }
     int num_dims = dl_tensor->ndim;
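
Together, the two entry points allow a round trip between eager tensor handles and DLPack capsules. A sketch, where h is an existing TFE_TensorHandle and context a live TFE_Context (both assumed; error handling elided):

TF_Status *status = TF_NewStatus();
void *dlm = TFE_HandleToDLPack(h, status);                         /* TF -> DLPack */
TFE_TensorHandle *h2 = TFE_HandleFromDLPack(dlm, status, context); /* DLPack -> TF */
TF_DeleteStatus(status);
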
@@ -421,8 +407,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     uint8_t config[4] = {0x32, 0x02, 0x20, 0x01};
     TFE_ContextOptionsSetConfig(context_opts, (void *)config, 4, status);
 
-    // TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
-    // TFE_ContextOptionsSetAsync(context_opts, 0);
+    TFE_ContextOptionsSetAsync(context_opts, 0);
     TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
 
     TFE_Context *context = TFE_NewContext(context_opts, status);
@@ -605,6 +590,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
     const size_t noutputs = array_len(mctxs[0]->outputs);
     TFE_TensorHandle *inputTensorsHandles[ninputs];
     TFE_TensorHandle *outputTensorsHandles[noutputs];
+    TFE_TensorHandle *deviceInputTensorsHandles[ninputs];
+    TFE_TensorHandle *deviceOutputTensorsHandles[noutputs];
 
     size_t batch_sizes[nbatches];
     size_t batch_offsets[nbatches];
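
The two added arrays split ownership between the host-side handles and their on-device copies, so the later hunks can delete each handle exactly once. A sketch of the copy half of that pairing, with error handling elided (the helper name is illustrative, not from the patch):

static void copy_inputs_to_device(TFE_Context *ctx, const char *device_name,
                                  TFE_TensorHandle **host,
                                  TFE_TensorHandle **device, size_t n,
                                  TF_Status *status) {
    for (size_t i = 0; i < n; ++i) {
        /* Same per-input call the patch makes below. */
        device[i] = TFE_TensorHandleCopyToDevice(host[i], ctx, device_name, status);
    }
}
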
@@ -655,7 +642,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
             return 1;
         }
 
-        inputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
+        deviceInputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
             inputTensorsHandles[i], mctxs[0]->model->session, tf_devicestr, status);
 
         if (TF_GetCode(status) != TF_OK) {
@@ -676,7 +663,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
         return 1;
     }
 
-    TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
+    TFE_OpAddInputList(fn_op, deviceInputTensorsHandles, ninputs, status);
     if (TF_GetCode(status) != TF_OK) {
         char *errorMessage = RedisModule_Strdup(TF_Message(status));
         RAI_SetError(error, RAI_EMODELRUN, errorMessage);
@@ -686,7 +673,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
     }
 
     int noutputs_ = noutputs;
-    TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
+    TFE_Execute(fn_op, deviceOutputTensorsHandles, &noutputs_, status);
     if (TF_GetCode(status) != TF_OK) {
         char *errorMessage = RedisModule_Strdup(TF_Message(status));
         RAI_SetError(error, RAI_EMODELRUN, errorMessage);
@@ -697,6 +684,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
 
     for (size_t i = 0; i < ninputs; ++i) {
         TFE_DeleteTensorHandle(inputTensorsHandles[i]);
+        TFE_DeleteTensorHandle(deviceInputTensorsHandles[i]);
     }
 
     if (TF_GetCode(status) != TF_OK) {
@@ -709,9 +697,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
 
     for (size_t i = 0; i < noutputs; ++i) {
         outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
-            outputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);
+            deviceOutputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);
 
-        // TF_Tensor* outputTensor = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
         RAI_Tensor *outputTensor =
             RAI_TensorCreateFromDLTensor(TFE_HandleToDLPack(outputTensorsHandles[i], status));
 
@@ -728,7 +715,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
             continue;
         }
         if (RAI_TensorDim(outputTensor, 0) != total_batch_size) {
-            // TF_DeleteTensor(outputTensor);
+            RAI_TensorFree(outputTensor);
             TF_DeleteStatus(status);
             RAI_SetError(error, RAI_EMODELRUN,
                          "ERR Model did not generate the expected batch size");
@@ -743,7 +730,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
             mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(outputTensor);
         }
         RAI_TensorFree(outputTensor);
-        TFE_DeleteTensorHandle(outputTensorsHandles[i]);
+        TFE_DeleteTensorHandle(deviceOutputTensorsHandles[i]);
     }
 
     TF_DeleteStatus(status);