Skip to content

Commit c0bce2d

Browse files
authored
[DEV][PROFILER] Add dot overriders to skip tl.dot (#221)
1 parent 36130ad commit c0bce2d

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

triton_viz/clients/profiler/profiler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ...core.client import Client
22
from ...core.callbacks import OpCallbacks, ForLoopCallbacks
3-
from ...core.data import Op, Load, Store, AddPtr
3+
from ...core.data import Op, Load, Store, AddPtr, Dot
44
from .data import LoadStoreBytes
55
from triton.runtime.interpreter import _get_np_dtype, TensorHandle
66
import numpy as np
@@ -203,6 +203,11 @@ def store_overrider(ptr, value, mask, cache_modifier, eviction_policy):
203203
# Skip actual store
204204
pass
205205

206+
def dot_overrider(a, b, d, input_precision, max_num_imprecise_acc):
207+
# Skip actual dot operation, return zeros with same shape as d
208+
# This replaces np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data
209+
return TensorHandle(np.zeros_like(d.data), d.dtype.scalar)
210+
206211
def pre_addptr_callback(ptr, offset):
207212
dtype_tt = ptr.get_element_ty()
208213
element_bitwidth = dtype_tt.primitive_bitwidth
@@ -232,6 +237,11 @@ def pre_addptr_callback(ptr, offset):
232237
return OpCallbacks(
233238
before_callback=pre_store_callback, op_overrider=store_overrider
234239
)
240+
elif op_type is Dot:
241+
if self.disable_load_store_skipping:
242+
return OpCallbacks()
243+
else:
244+
return OpCallbacks(op_overrider=dot_overrider)
235245
elif op_type is AddPtr:
236246
return OpCallbacks(before_callback=pre_addptr_callback)
237247

0 commit comments

Comments
 (0)