|
10 | 10 | import onnxruntime as _ort |
11 | 11 |
|
12 | 12 |
|
def has_cuda():
    """Return True when onnxruntime exposes the CUDA execution provider."""
    available = _ort.get_available_providers()
    return "CUDAExecutionProvider" in available
| 16 | + |
13 | 17 | class NegXPlus1(OpRun): |
14 | 18 | op_domain = "ai.onnx.contrib" |
15 | 19 |
|
@@ -101,8 +105,6 @@ def test_cuda_fastgelu_f16(self): |
101 | 105 | print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.") |
102 | 106 |
|
103 | 107 | def _negxplus1_cuda(self, itype): |
104 | | - import onnxruntime |
105 | | - |
106 | 108 | dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 |
107 | 109 | model1 = helper.make_model( |
108 | 110 | helper.make_graph( |
@@ -137,17 +139,128 @@ def _negxplus1_cuda(self, itype): |
137 | 139 | ref = ReferenceEvaluator(model1, new_ops=[NegXPlus1]) |
138 | 140 | expected = ref.run(None, feeds1)[0] |
139 | 141 |
|
140 | | - opts = onnxruntime.SessionOptions() |
| 142 | + opts = _ort.SessionOptions() |
141 | 143 | opts.register_custom_ops_library(_get_library_path()) |
142 | | - sess = onnxruntime.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) |
| 144 | + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) |
143 | 145 | got = sess.run(None, feeds1)[0] |
144 | 146 | assert_almost_equal(expected, got, decimal=5) |
145 | 147 |
|
| 148 | + @unittest.skipIf(not has_cuda(), reason="CUDA is missing") |
146 | 149 | def test_cuda_negxplus1(self): |
147 | | - eps = _ort.get_available_providers() |
148 | | - if "CUDAExecutionProvider" in eps: |
149 | | - self._negxplus1_cuda(TensorProto.FLOAT) |
150 | | - self._negxplus1_cuda(TensorProto.FLOAT16) |
| 150 | + self._negxplus1_cuda(TensorProto.FLOAT) |
| 151 | + self._negxplus1_cuda(TensorProto.FLOAT16) |
| 152 | + |
| 153 | + def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)): |
| 154 | + from onnx_extended.ortops.optim.cuda import get_ort_ext_libs |
| 155 | + |
| 156 | + model1 = helper.make_model( |
| 157 | + helper.make_graph( |
| 158 | + [ |
| 159 | + helper.make_node(op_type, ["X", "Y"], ["XY"]), |
| 160 | + helper.make_node(op_type, ["X", "Z"], ["XZ"]), |
| 161 | + ], |
| 162 | + "nd", |
| 163 | + [ |
| 164 | + helper.make_tensor_value_info("X", itype, [None, None, None]), |
| 165 | + helper.make_tensor_value_info("Y", itype, [None, None, None]), |
| 166 | + helper.make_tensor_value_info("Z", itype, [None, None, None]), |
| 167 | + ], |
| 168 | + [ |
| 169 | + helper.make_tensor_value_info("XY", itype, [None, None, None]), |
| 170 | + helper.make_tensor_value_info("XZ", itype, [None, None, None]), |
| 171 | + ], |
| 172 | + ), |
| 173 | + opset_imports=[helper.make_opsetid("", 18)], |
| 174 | + ir_version=9, |
| 175 | + ) |
| 176 | + |
| 177 | + model2 = helper.make_model( |
| 178 | + helper.make_graph( |
| 179 | + [ |
| 180 | + helper.make_node( |
| 181 | + f"{op_type}SharedInput", |
| 182 | + ["X", "Y", "Z"], |
| 183 | + ["XY", "XZ"], |
| 184 | + domain="onnx_extended.ortops.optim.cuda", |
| 185 | + ) |
| 186 | + ], |
| 187 | + "nd", |
| 188 | + [ |
| 189 | + helper.make_tensor_value_info("X", itype, [None, None, None]), |
| 190 | + helper.make_tensor_value_info("Y", itype, [None, None, None]), |
| 191 | + helper.make_tensor_value_info("Z", itype, [None, None, None]), |
| 192 | + ], |
| 193 | + [ |
| 194 | + helper.make_tensor_value_info("XY", itype, [None, None, None]), |
| 195 | + helper.make_tensor_value_info("XZ", itype, [None, None, None]), |
| 196 | + ], |
| 197 | + ), |
| 198 | + opset_imports=[ |
| 199 | + helper.make_opsetid("", 18), |
| 200 | + helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1), |
| 201 | + ], |
| 202 | + ir_version=9, |
| 203 | + ) |
| 204 | + |
| 205 | + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 |
| 206 | + x = (np.arange(np.prod(shapea)) + 1).reshape((shapea)).astype(dtype) |
| 207 | + y = (np.arange(np.prod(shapeb)) + 2).reshape((shapeb)).astype(dtype) |
| 208 | + z = (np.arange(np.prod(shapec)) + 3).reshape((shapec)).astype(dtype) |
| 209 | + |
| 210 | + feeds1 = dict(X=x, Y=y, Z=z) |
| 211 | + ref = ReferenceEvaluator(model1) |
| 212 | + expected = ref.run(None, feeds1) |
| 213 | + |
| 214 | + opts = _ort.SessionOptions() |
| 215 | + opts.register_custom_ops_library(get_ort_ext_libs()[0]) |
| 216 | + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) |
| 217 | + got = sess.run(None, feeds1) |
| 218 | + for i in range(2): |
| 219 | + assert_almost_equal(expected[i], got[i]) |
| 220 | + |
| 221 | + @unittest.skipIf(not has_cuda(), reason="CUDA is missing") |
| 222 | + def test_add_shared_input_cuda(self): |
| 223 | + self._addmul_shared_input_cuda(TensorProto.FLOAT, "Add") |
| 224 | + self._addmul_shared_input_cuda(TensorProto.FLOAT16, "Add") |
| 225 | + |
| 226 | + @unittest.skipIf(not has_cuda(), reason="CUDA is missing") |
| 227 | + def test_mul_shared_input_cuda(self): |
| 228 | + self._addmul_shared_input_cuda(TensorProto.FLOAT, "Mul") |
| 229 | + self._addmul_shared_input_cuda(TensorProto.FLOAT16, "Mul") |
| 230 | + |
| 231 | + @unittest.skipIf(not has_cuda(), reason="CUDA is missing") |
| 232 | + def test_add_shared_input_cuda_broadcast1(self): |
| 233 | + self._addmul_shared_input_cuda( |
| 234 | + TensorProto.FLOAT, |
| 235 | + "Add", |
| 236 | + shapea=(3, 2, 3), |
| 237 | + shapeb=(1, 2, 3), |
| 238 | + shapec=(1, 2, 3), |
| 239 | + ) |
| 240 | + self._addmul_shared_input_cuda( |
| 241 | + TensorProto.FLOAT16, |
| 242 | + "Add", |
| 243 | + shapea=(3, 2, 3), |
| 244 | + shapeb=(1, 2, 3), |
| 245 | + shapec=(1, 2, 3), |
| 246 | + ) |
| 247 | + |
| 248 | + @unittest.skipIf(not has_cuda(), reason="CUDA is missing") |
| 249 | + def test_add_shared_input_cuda_broadcast2(self): |
| 250 | + self._addmul_shared_input_cuda( |
| 251 | + TensorProto.FLOAT, |
| 252 | + "Add", |
| 253 | + shapea=(1, 2, 3), |
| 254 | + shapeb=(3, 2, 3), |
| 255 | + shapec=(3, 2, 3), |
| 256 | + ) |
| 257 | + self._addmul_shared_input_cuda( |
| 258 | + TensorProto.FLOAT16, |
| 259 | + "Add", |
| 260 | + shapea=(1, 2, 3), |
| 261 | + shapeb=(3, 2, 3), |
| 262 | + shapec=(3, 2, 3), |
| 263 | + ) |
151 | 264 |
|
152 | 265 |
|
153 | 266 | if __name__ == "__main__": |
|
0 commit comments