|
1 | 1 | from ...core.client import Client |
2 | 2 | from ...core.callbacks import OpCallbacks, ForLoopCallbacks |
3 | | -from ...core.data import Op, Load, Store, AddPtr |
| 3 | +from ...core.data import Op, Load, Store, AddPtr, Dot |
4 | 4 | from .data import LoadStoreBytes |
5 | 5 | from triton.runtime.interpreter import _get_np_dtype, TensorHandle |
6 | 6 | import numpy as np |
@@ -203,6 +203,11 @@ def store_overrider(ptr, value, mask, cache_modifier, eviction_policy): |
203 | 203 | # Skip actual store |
204 | 204 | pass |
205 | 205 |
|
| 206 | + def dot_overrider(a, b, d, input_precision, max_num_imprecise_acc): |
| 207 | + # Skip actual dot operation, return zeros with same shape as d |
| 208 | + # This replaces np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data |
| 209 | + return TensorHandle(np.zeros_like(d.data), d.dtype.scalar) |
| 210 | + |
206 | 211 | def pre_addptr_callback(ptr, offset): |
207 | 212 | dtype_tt = ptr.get_element_ty() |
208 | 213 | element_bitwidth = dtype_tt.primitive_bitwidth |
@@ -232,6 +237,11 @@ def pre_addptr_callback(ptr, offset): |
232 | 237 | return OpCallbacks( |
233 | 238 | before_callback=pre_store_callback, op_overrider=store_overrider |
234 | 239 | ) |
| 240 | + elif op_type is Dot: |
| 241 | + if self.disable_load_store_skipping: |
| 242 | + return OpCallbacks() |
| 243 | + else: |
| 244 | + return OpCallbacks(op_overrider=dot_overrider) |
235 | 245 | elif op_type is AddPtr: |
236 | 246 | return OpCallbacks(before_callback=pre_addptr_callback) |
237 | 247 |
|
|
0 commit comments