@@ -114,36 +114,41 @@ def create_return_code_frome_schema(schema, allow_return_ref=True):
114
114
def create_transform_input_to_cpu_code (fun_config ):
115
115
input_process_code = ""
116
116
schema = fun_config ["schema" ]
117
+ opname = get_op_name_from_schema (schema )
117
118
inputs = re .findall ("Tensor +([\w\d_]+)" , schema [: schema .find ("->" )])
118
119
for input in inputs :
119
120
input_process_code += (
120
- f"at::Tensor { input } _cpu = to_cpu_without_diopi ({ input } );\n "
121
+ f"at::Tensor { input } _cpu = toCpuTensorWithoutDiopiCopy ({ input } );\n "
121
122
)
122
123
123
124
optional_inputs = re .findall ("Tensor *\? +([\w\d_]+)" , schema [: schema .find ("->" )])
124
125
for input in optional_inputs :
125
- input_process_code += f"\n c10::optional<at::Tensor> { input } _cpu = { input } .has_value() && { input } .value().defined() ? c10::make_optional<at::Tensor>(to_cpu_without_diopi ({ input } .value())) : { input } ;\n "
126
+ input_process_code += f"\n c10::optional<at::Tensor> { input } _cpu = { input } .has_value() && { input } .value().defined() ? c10::make_optional<at::Tensor>(toCpuTensorWithoutDiopiCopy ({ input } .value())) : { input } ;\n "
126
127
127
128
optional_tensor_list_inputs = re .findall (
128
129
"Tensor *\? *\[ *\] +([\w\d_]+)" , schema [: schema .find ("->" )]
129
130
)
130
131
for input in optional_tensor_list_inputs :
131
132
input_process_code += f"\n c10::List<c10::optional<at::Tensor>> { input } _cpu;\n "
132
133
input_process_code += f"for (int i = 0; i < { input } .size();++i)" + " {\n "
133
- input_process_code += f" { input } _cpu.push_back({ input } [i].has_value() && { input } [i].value().defined() ? c10::make_optional<at::Tensor>(({ input } [i].value())) : { input } [i]);\n "
134
+ input_process_code += f" { input } _cpu.push_back({ input } [i].has_value() && { input } [i].value().defined() ? c10::make_optional<at::Tensor>(toCpuTensorWithoutDiopiCopy ({ input } [i].value())) : { input } [i]);\n "
134
135
input_process_code += "}\n "
135
136
136
137
outputs = re .findall (
137
138
"Tensor\([a-z]!\)[ ]+([\w\d_]+){1}" , schema [: schema .find ("->" )]
138
139
)
139
140
for output in outputs :
140
- if output .strip ().endswith ("?" ):
141
- output = output .replace ("?" , "" )
142
- input_process_code += f"\n c10::optional<at::Tensor> { output } _cpu = { output } .has_value() && { output } .value().defined() ? c10::make_optional<at::Tensor>(to_cpu_without_diopi({ output } .value()) : { output } ;\n "
143
- else :
144
- input_process_code += (
145
- f"at::Tensor { output } _cpu = to_cpu_without_diopi({ output } );\n "
146
- )
141
+ input_process_code += (
142
+ f"at::Tensor { output } _cpu = toCpuTensorWithoutDiopiCopy({ output } );\n "
143
+ )
144
+ if ".out" in opname or "_out" in opname :
145
+ for i in range (len (inputs )):
146
+ input_process_code += (
147
+ f"if (({ inputs [i ]} .data_ptr()) == { output } .data_ptr())"
148
+ )
149
+ input_process_code += "{\n \t "
150
+ input_process_code += f"{ inputs [i ]} _cpu = { output } _cpu;\n \t "
151
+ input_process_code += "}\n "
147
152
148
153
tensors_arrays = re .findall (
149
154
"Tensor *\[ *\] * +([\w\d_]+)" , schema [: schema .find ("->" )]
@@ -161,9 +166,8 @@ def create_transform_input_to_cpu_code(fun_config):
161
166
)
162
167
input_process_code += (
163
168
f"std::transform({ tensors_arg } .begin(), { tensors_arg } .end(), { tensors_arg } _cpu.begin(), [](const at::Tensor& tensor)"
164
- + "{return to_cpu_without_diopi (tensor);});\n "
169
+ + "{return toCpuTensorWithoutDiopiCopy (tensor);});\n "
165
170
)
166
-
167
171
return input_process_code
168
172
169
173
@@ -487,6 +491,9 @@ def create_call_aten_cpu_cpp_function_code_from_config(fun_config):
487
491
code ,
488
492
)
489
493
494
+ if "device" in code :
495
+ code = code .replace ("device" , "at::kCPU" )
496
+
490
497
inputs = re .findall ("Tensor +([\w\d_]+)" , schema [: schema .find ("->" )])
491
498
optional_inputs = re .findall ("Tensor *\? +([\w\d_]+)" , schema [: schema .find ("->" )])
492
499
outputs = re .findall (
@@ -550,7 +557,6 @@ def create_result_compare_code(fun_config):
550
557
for i in range (len (inputs )):
551
558
code += separator_code
552
559
code += f'std::cout << "autocompare:\t { op_name } \t { inputs [i ]} : " << std::endl << allclose_autocompare({ inputs [i ]} _cpu, { inputs [i ]} ) << std::endl;\n '
553
-
554
560
return code
555
561
556
562
@@ -972,9 +978,12 @@ def functions_code_gen(fun_config):
972
978
973
979
974
980
def boolean_string(s):
    """Convert a textual on/off switch into a ``bool``.

    Accepted spellings are compared case-insensitively and with surrounding
    whitespace ignored: ``"true"`` / ``"on"`` map to ``True`` and
    ``"false"`` / ``"off"`` map to ``False``. A value that is already a
    ``bool`` is returned unchanged.

    Args:
        s: The value to interpret, normally a string.

    Returns:
        The boolean the input spells.

    Raises:
        ValueError: If ``s`` is not a bool and is not one of the accepted
            spellings. Non-string inputs also raise ``ValueError`` (rather
            than an ``AttributeError`` from calling ``.lower()``), so a
            caller that only handles ``ValueError`` sees a single failure
            mode.
    """
    if isinstance(s, bool):
        return s
    if not isinstance(s, str):
        raise ValueError("Not a valid boolean string.")

    # Normalize once so both membership tests share the same value.
    normalized = s.strip().lower()
    if normalized in ("true", "on"):
        return True
    if normalized in ("false", "off"):
        return False
    raise ValueError("Not a valid boolean string.")
978
987
979
988
980
989
def parse_args ():
0 commit comments