Skip to content

Commit 0f86380

Browse files
committed
WIP memory error
1 parent e3ee54c commit 0f86380

File tree

2 files changed

+39
-45
lines changed

2 files changed

+39
-45
lines changed

src/backends/tensorflow.c

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,17 @@ int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
2828
return REDISMODULE_OK;
2929
}
3030

31-
// Managing context for the DLManagedTensor, will manage the lifetime of
32-
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
33-
// original framework of destruction, and this context will be deleted also.
34-
struct TfDlManagedTensorCtx {
31+
struct TFDLManagedTensorCtx {
3532
TFE_TensorHandle *reference;
3633
int64_t ndim;
3734
int64_t *shape;
3835
int64_t *strides;
3936
DLManagedTensor tensor;
4037
};
41-
typedef struct TfDlManagedTensorCtx TfDlManagedTensorCtx;
38+
typedef struct TFDLManagedTensorCtx TFDLManagedTensorCtx;
4239

43-
TfDlManagedTensorCtx *TfDlManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) {
44-
TfDlManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TfDlManagedTensorCtx));
40+
TFDLManagedTensorCtx *TFDLManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) {
41+
TFDLManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TFDLManagedTensorCtx));
4542
ctx->ndim = TFE_TensorHandleNumDims(h, status);
4643
ctx->shape = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
4744
ctx->strides = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
@@ -55,23 +52,19 @@ TfDlManagedTensorCtx *TfDlManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status
5552
return ctx;
5653
}
5754

58-
void TfDlManagedTensorCtx_Free(TfDlManagedTensorCtx *ctx) {
55+
void TFDLManagedTensorCtx_Free(TFDLManagedTensorCtx *ctx) {
5956
RedisModule_Free(ctx->shape);
6057
RedisModule_Free(ctx->strides);
6158
RedisModule_Free(ctx);
6259
}
6360

64-
// Deleter for DLManagedTensor
6561
void DLManagedTensorDeleter(DLManagedTensor *arg) {
66-
TfDlManagedTensorCtx *owner = (TfDlManagedTensorCtx *)(arg->manager_ctx);
67-
68-
// TODO: check if we need to delete the actual tensor as well
62+
TFDLManagedTensorCtx *owner = (TFDLManagedTensorCtx *)(arg->manager_ctx);
6963
TFE_DeleteTensorHandle(owner->reference);
70-
TfDlManagedTensorCtx_Free(owner);
64+
TFDLManagedTensorCtx_Free(owner);
7165
}
7266

73-
// Converts TF_DataType to DLPack data type.
74-
DLDataType GetDlDataType(TF_DataType data_type, TF_Status *status) {
67+
DLDataType GetDLDataType(TF_DataType data_type, TF_Status *status) {
7568
DLDataType dtype;
7669
dtype.lanes = 1;
7770
dtype.bits = TF_DataTypeSize(data_type) * 8;
@@ -104,8 +97,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status *status) {
10497
return dtype;
10598
}
10699

107-
// Gets DLPack's DLDevice from eager tensor handle.
108-
DLDevice GetDlDevice(TFE_TensorHandle *h, TF_Status *status) {
100+
DLDevice GetDLDevice(TFE_TensorHandle *h, TF_Status *status) {
109101
DLDevice device;
110102
const char *device_name = TFE_TensorHandleBackingDeviceName(h, status);
111103

@@ -135,8 +127,7 @@ DLDevice GetDlDevice(TFE_TensorHandle *h, TF_Status *status) {
135127
return device;
136128
}
137129

138-
// Converts DLContext to TF device name.
139-
int DeviceNameFromDlContext(const DLDevice *device, char device_name[64]) {
130+
int DeviceNameFromDLContext(const DLDevice *device, char device_name[64]) {
140131
switch (device->device_type) {
141132
case kDLCPU:
142133
strcpy(device_name, "CPU:0");
@@ -148,8 +139,7 @@ int DeviceNameFromDlContext(const DLDevice *device, char device_name[64]) {
148139
return 1;
149140
}
150141

151-
// Converts DLPack data type to TF_DataType.
152-
int TfDataTypeFromDlDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
142+
int TFDataTypeFromDLDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
153143
switch (dtype->code) {
154144
case kDLUInt:
155145
switch (dtype->bits) {
@@ -216,14 +206,10 @@ int TfDataTypeFromDlDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
216206
}
217207
}
218208

219-
// Wraps the deleter function of DLManagedTensor to match the function signature
220-
// TFE_NewTensorHandleFromDeviceMemory.
221209
void DeallocatorWrapperFunc(void *data, size_t len, void *dlmt_vptr) {
222210
TFE_CallDLManagedTensorDeleter(dlmt_vptr);
223211
}
224212

225-
// Checks whether the stride array matches the layout of compact, row-major
226-
// data.
227213
bool IsValidStrideCompactRowMajorData(int64_t *shape_arr, int64_t *stride_arr, int ndim) {
228214
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
229215
return false;
@@ -244,7 +230,7 @@ void TFE_CallDLManagedTensorDeleter(void *dlm_ptr) {
244230
}
245231

246232
void *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status) {
247-
DLDevice tf_dlm_device = GetDlDevice(h, status);
233+
DLDevice tf_dlm_device = GetDLDevice(h, status);
248234
if (TF_GetCode(status) != TF_OK) {
249235
return NULL;
250236
}
@@ -256,12 +242,12 @@ void *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status) {
256242

257243
TF_DataType data_type = TFE_TensorHandleDataType(h);
258244

259-
DLDataType tf_dlm_type = GetDlDataType(data_type, status);
245+
DLDataType tf_dlm_type = GetDLDataType(data_type, status);
260246
if (TF_GetCode(status) != TF_OK) {
261247
return NULL;
262248
}
263249

264-
TfDlManagedTensorCtx *tf_dlm_tensor_ctx = TfDlManagedTensorCtx_Create(h, status);
250+
TFDLManagedTensorCtx *tf_dlm_tensor_ctx = TFDLManagedTensorCtx_Create(h, status);
265251

266252
DLManagedTensor *dlm_tensor = &tf_dlm_tensor_ctx->tensor;
267253
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
@@ -287,15 +273,15 @@ TFE_TensorHandle *TFE_HandleFromDLPack(void *dlm, TF_Status *status, TFE_Context
287273
DLManagedTensor *dlmt = (DLManagedTensor *)dlm;
288274
DLTensor *dl_tensor = &dlmt->dl_tensor;
289275
char device_name[64];
290-
int ret = DeviceNameFromDlContext(&dl_tensor->device, device_name);
276+
int ret = DeviceNameFromDLContext(&dl_tensor->device, device_name);
291277
if (ret != 0) {
292-
// tensorflow::errors::InvalidArgument("Unsupported Device Type");
278+
// TODO Unsupported device type
293279
return NULL;
294280
}
295281
TF_DataType dtype;
296-
ret = TfDataTypeFromDlDataType(&dl_tensor->dtype, &dtype);
282+
ret = TFDataTypeFromDLDataType(&dl_tensor->dtype, &dtype);
297283
if (ret != 0) {
298-
// status->status = std::move(s);
284+
// TODO Unsupported data type
299285
return NULL;
300286
}
301287
int num_dims = dl_tensor->ndim;
@@ -421,8 +407,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
421407
uint8_t config[4] = {0x32, 0x02, 0x20, 0x01};
422408
TFE_ContextOptionsSetConfig(context_opts, (void *)config, 4, status);
423409

424-
// TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
425-
// TFE_ContextOptionsSetAsync(context_opts, 0);
410+
TFE_ContextOptionsSetAsync(context_opts, 0);
426411
TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
427412

428413
TFE_Context *context = TFE_NewContext(context_opts, status);
@@ -605,6 +590,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
605590
const size_t noutputs = array_len(mctxs[0]->outputs);
606591
TFE_TensorHandle *inputTensorsHandles[ninputs];
607592
TFE_TensorHandle *outputTensorsHandles[noutputs];
593+
TFE_TensorHandle *deviceInputTensorsHandles[ninputs];
594+
TFE_TensorHandle *deviceOutputTensorsHandles[noutputs];
608595

609596
size_t batch_sizes[nbatches];
610597
size_t batch_offsets[nbatches];
@@ -655,7 +642,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
655642
return 1;
656643
}
657644

658-
inputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
645+
deviceInputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
659646
inputTensorsHandles[i], mctxs[0]->model->session, tf_devicestr, status);
660647

661648
if (TF_GetCode(status) != TF_OK) {
@@ -676,7 +663,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
676663
return 1;
677664
}
678665

679-
TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
666+
TFE_OpAddInputList(fn_op, deviceInputTensorsHandles, ninputs, status);
680667
if (TF_GetCode(status) != TF_OK) {
681668
char *errorMessage = RedisModule_Strdup(TF_Message(status));
682669
RAI_SetError(error, RAI_EMODELRUN, errorMessage);
@@ -686,7 +673,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
686673
}
687674

688675
int noutputs_ = noutputs;
689-
TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
676+
TFE_Execute(fn_op, deviceOutputTensorsHandles, &noutputs_, status);
690677
if (TF_GetCode(status) != TF_OK) {
691678
char *errorMessage = RedisModule_Strdup(TF_Message(status));
692679
RAI_SetError(error, RAI_EMODELRUN, errorMessage);
@@ -697,6 +684,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
697684

698685
for (size_t i = 0; i < ninputs; ++i) {
699686
TFE_DeleteTensorHandle(inputTensorsHandles[i]);
687+
TFE_DeleteTensorHandle(deviceInputTensorsHandles[i]);
700688
}
701689

702690
if (TF_GetCode(status) != TF_OK) {
@@ -709,9 +697,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
709697

710698
for (size_t i = 0; i < noutputs; ++i) {
711699
outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
712-
outputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);
700+
deviceOutputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);
713701

714-
// TF_Tensor* outputTensor = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
715702
RAI_Tensor *outputTensor =
716703
RAI_TensorCreateFromDLTensor(TFE_HandleToDLPack(outputTensorsHandles[i], status));
717704

@@ -728,7 +715,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
728715
continue;
729716
}
730717
if (RAI_TensorDim(outputTensor, 0) != total_batch_size) {
731-
// TF_DeleteTensor(outputTensor);
718+
RAI_TensorFree(outputTensor);
732719
TF_DeleteStatus(status);
733720
RAI_SetError(error, RAI_EMODELRUN,
734721
"ERR Model did not generate the expected batch size");
@@ -743,7 +730,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
743730
mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(outputTensor);
744731
}
745732
RAI_TensorFree(outputTensor);
746-
TFE_DeleteTensorHandle(outputTensorsHandles[i]);
733+
TFE_DeleteTensorHandle(deviceOutputTensorsHandles[i]);
747734
}
748735

749736
TF_DeleteStatus(status);

tests/flow/tests_dag.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -814,12 +814,15 @@ def test_dag_modelrun_financialNet_no_writes(env):
814814
env)
815815
model_name = 'financialNet{{hhh}}'
816816

817-
ret = con.execute_command('AI.MODELSET', model_name, 'TF', "CPU",
817+
ret = con.execute_command('AI.MODELSET', model_name, 'TF', "GPU",
818818
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', 'BLOB', model_pb)
819819
env.assertEqual(ret, b'OK')
820820

821-
for tensor_number in range(1,MAX_TRANSACTIONS):
822-
for repetition in range(1,10):
821+
MAX_TRANSACTIONS = 2
822+
823+
for tensor_number in range(1, MAX_TRANSACTIONS):
824+
# for repetition in range(1, 10):
825+
for repetition in range(1, 2):
823826
reference_tensor = creditcard_referencedata[tensor_number]
824827
transaction_tensor = creditcard_transactions[tensor_number]
825828
result_tensor_keyname = 'resultTensor{{hhh}}{}'.format(tensor_number)
@@ -833,20 +836,24 @@ def test_dag_modelrun_financialNet_no_writes(env):
833836
ret = con.execute_command("EXISTS {}".format(reference_tensor_keyname))
834837
env.assertEqual(ret, 1)
835838

839+
# print(reference_tensor)
840+
print(transaction_tensor)
841+
836842
ret = con.execute_command(
837843
'AI.DAGRUN', 'LOAD', '1', reference_tensor_keyname, '|>',
838844
'AI.TENSORSET', transaction_tensor_keyname, 'FLOAT', 1, 30,'BLOB', transaction_tensor.tobytes(), '|>',
839845
'AI.MODELRUN', model_name,
840846
'INPUTS', transaction_tensor_keyname, reference_tensor_keyname,
841847
'OUTPUTS', result_tensor_keyname, '|>',
842-
'AI.TENSORGET',result_tensor_keyname, 'META', '|>',
848+
'AI.TENSORGET', result_tensor_keyname, 'META', '|>',
843849
'AI.TENSORGET', result_tensor_keyname, 'VALUES'
844850
)
845851
env.assertEqual(4, len(ret))
846852
env.assertEqual([b'OK', b'OK'], ret[:2])
847853
env.assertEqual([b'dtype', b'FLOAT', b'shape', [1, 2]], ret[2])
848854
values = ret[3]
849855
# Assert that resulting classification is within [0,1]
856+
print(values)
850857
env.assertEqual(True, 0 <= float(values[0]) <= 1)
851858
env.assertEqual(True, 0 <= float(values[1]) <= 1)
852859

0 commit comments

Comments
 (0)