Skip to content

Commit 96d5372

Browse files
committed
fix load_and_run.py
1 parent b2274a5 commit 96d5372

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,18 @@ def __init__(self, graph_id, device_id):
209209
self.output_dtypes = []
210210
self.output_datasize = []
211211
for item in shapes:
212+
if item == '':
213+
self.output_shapes.append([])
214+
continue
212215
elems = item.split(',')
213216
elems = [int(x) for x in elems]
214217
self.output_shapes.append(elems)
215218
for item in dtypes:
216219
elem = int(item)
217220
self.output_dtypes.append(elem)
218221
for i in range(len(shapes)):
219-
elem_size = math.prod(self.output_shapes[i])
222+
elem_size = math.prod(self.output_shapes[i]) if len(
223+
self.output_shapes[i]) > 0 else 1
220224
self.output_datasize.append(
221225
elem_size * acl.data_type_size(self.output_dtypes[i]))
222226
self.output_datasize_c = (
@@ -242,14 +246,18 @@ def __init__(self, graph_id, device_id):
242246
self.input_datasize = []
243247

244248
for item in shapes:
249+
if item == '':
250+
self.input_shapes.append([])
251+
continue
245252
elems = item.split(',')
246253
elems = [int(x) for x in elems]
247254
self.input_shapes.append(elems)
248255
for item in dtypes:
249256
elem = int(item)
250257
self.input_dtypes.append(elem)
251258
for i in range(len(shapes)):
252-
elem_size = math.prod(self.input_shapes[i])
259+
elem_size = math.prod(self.input_shapes[i]) if len(
260+
self.input_shapes[i]) > 0 else 1
253261
self.input_datasize.append(
254262
elem_size * acl.data_type_size(self.input_dtypes[i]))
255263
self.input_datasize_c = (

dicp/dicp/vendor/AscendGraph/compile_job.py

-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def __init__(self, source_code) -> None:
122122
str(self.device_id) + code_hash(compile_file_code)
123123
)
124124
self._output_graph_path = self._input_path[:-5] + '/graph'
125-
# print('output_path: ', self._output_graph_path)
126125
self._model_path = [f'{self._output_graph_path}.om',
127126
f'{self._output_graph_path}_linux_x86_64.om']
128127
self._lib_path = "/tmp/dicp_ascend/ge_graph.so"

0 commit comments

Comments
 (0)