add sliding window support for Gemma3 #3742

Open: wants to merge 10 commits into base: main

10 changes: 8 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -414,13 +414,19 @@ def index_dtype_validator(
for ind in index:
if ind is not None:
val = ind.meta.get("val")
if val is not None and val.dtype not in (torch.int32, torch.int64):
if val is not None and val.dtype not in (
torch.int32,
torch.int64,
torch.bool,
):
return False
return True


@dynamo_tensorrt_converter(
torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
torch.ops.aten.index.Tensor,
capability_validator=index_dtype_validator,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
64 changes: 62 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -51,6 +51,67 @@ def select(
return layer.get_output(0)


def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
if isinstance(tensor, TRTTensor):
if getattr(tensor, "meta", None) is None:
return tensor.dtype == trt.DataType.BOOL
val = tensor.meta.get("val")
if val is not None and val.dtype is torch.bool:
return True
return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool


def expand_boolean_indices(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
# materialize the indices as a list so converted boolean indices can be written back
indices = list(indices)
for i, ind in enumerate(indices):
if ind is not None and is_boolean_tensor(ind):
_LOGGER.debug(
f"Boolean index detected at position {i}, converting with nonzero()"
)

mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

nonzero_layer = ctx.net.add_non_zero(mask_tensor)
set_layer_name(
nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir
)
nonzero_indices = nonzero_layer.get_output(0)

# nonzero returns indices of shape [N, dims]; extract the column for dimension i
if len(indices) == 1:
# x[mask] — 1D mask
squeeze_layer = ctx.net.add_shuffle(nonzero_indices)
squeeze_layer.reshape_dims = (-1,)
set_layer_name(
squeeze_layer,
target,
name + f"_bool_nonzero_squeeze_{i}",
source_ir,
)
squeezed_index = squeeze_layer.get_output(0)
indices[i] = squeezed_index
else:
# Advanced multi-axis mask: extract index i from shape [N, D]
gather_axis = 1 # dim index
gather_layer = ctx.net.add_gather(
nonzero_indices,
get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
gather_axis,
)
set_layer_name(
gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
)
extracted_index = gather_layer.get_output(0)
indices[i] = extracted_index
return indices
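
For reference, a minimal eager-mode sketch (plain PyTorch, not the converter API; tensor values are illustrative) of the equivalence that expand_boolean_indices relies on: a boolean mask index behaves like the integer indices of its True entries, which is exactly what the nonzero layer produces.

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Single mask over the whole tensor: x[mask] == x[mask.nonzero(as_tuple=True)]
mask = x > 5.0
assert torch.equal(x[mask], x[mask.nonzero(as_tuple=True)])

# Per-dimension mask: the mask contributes one integer index tensor for its axis,
# analogous to squeezing the [N, 1] nonzero() result in the single-index branch above.
row_mask = torch.tensor([True, False, True])
rows = row_mask.nonzero().squeeze(-1)  # tensor([0, 2])
assert torch.equal(x[row_mask], x[rows])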


def index(
ctx: ConversionContext,
target: Target,
@@ -61,8 +122,6 @@ def index(
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)
# is_numpy is a flag indicating whether all the indices are numpy arrays (rather than torch.Tensors).
# If any index is not a numpy array, the flag is set to False.
_LOGGER.debug(
@@ -76,6 +135,7 @@
# check whether all the indices are broadcastable
# if not, they need to be broadcast
last_index = None
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
for i, ind in enumerate(indices):
if ind is not None:
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")
26 changes: 25 additions & 1 deletion tests/py/dynamo/conversion/test_index_aten.py
@@ -168,7 +168,31 @@ def forward(self, input):
dtype=torch.float32,
),
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)
self.run_test_with_dynamic_shape(
TestModule(), input_specs, use_dynamo_tracer=True
)


class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
def test_index_input_non_dynamic_index_dynamic(self):
class TestIndexWithRuntimeIndex(torch.nn.Module):
def forward(self, x):
mask = x > 0
idx = torch.nonzero(mask, as_tuple=True)
return torch.ops.aten.index.Tensor(x, idx)

input_specs = [
Input(
min_shape=(2, 2),
opt_shape=(2, 2),
max_shape=(8, 8),
dtype=torch.float32,
),
]
# In this case args[1] (the index) is itself converted to a list of TRTTensors when use_dynamo_tracer=True
self.run_test_with_dynamic_shape(
TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True
)


if __name__ == "__main__":
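A brief eager-mode illustration (not part of the test suite; the helper name is made up) of the behavior TestIndexWithRuntimeIndex exercises: the index tensors are produced by nonzero at runtime, so the gathered output has a data-dependent shape even for a fixed input shape.

import torch

def gather_positive(x):
    # index tensors are computed from the input itself at runtime
    idx = torch.nonzero(x > 0, as_tuple=True)
    return torch.ops.aten.index.Tensor(x, idx)

print(gather_positive(torch.tensor([[1.0, -2.0], [3.0, 0.5]])).shape)   # torch.Size([3])
print(gather_positive(torch.tensor([[-1.0, -2.0], [3.0, 0.5]])).shape)  # torch.Size([2])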
2 changes: 1 addition & 1 deletion tools/llm/run_llm.py
@@ -116,7 +116,7 @@ def compile_torchtrt(model, input_ids, args):
use_fp32_acc=use_fp32_acc,
device=DEVICE,
disable_tf32=True,
use_python_runtime=True,
use_python_runtime=False,
debug=args.debug,
offload_module_to_cpu=True,
min_block_size=args.min_block_size,
5 changes: 1 addition & 4 deletions tools/llm/torchtrt_ext/register_sdpa.py
@@ -86,10 +86,7 @@ def replace_variants_of_sdpa(
f"Unexpected number of arguments for {node.target} in the graph"
)

logger.warning(
f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
)
modified_input_args = (query, key, value, None, dropout_p, True)
modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal)
# Create a new node with torch.nn.functional.scaled_dot_product_attention
# The input args is (query, key, value, is_causal). kwargs has scale
with gm.graph.inserting_after(node):
113 changes: 73 additions & 40 deletions tools/llm/torchtrt_ext/sdpa_converter.py
@@ -27,24 +27,51 @@ def tril(
name: str,
row: TRTTensor,
col: TRTTensor,
sliding_window_size: Optional[int] = None,
) -> TRTTensor:

row_arange_tensor = impl.arange.arange(
ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
)
row_reshape_tensor = impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
)

col_arange_tensor = impl.arange.arange(
ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
)
col_reshape_tensor = impl.shuffle.reshape(
ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
row_arange_tensor = impl.unsqueeze.unsqueeze(
ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1
)

mask = impl.elementwise.ge(
ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
col_arange_tensor = impl.unsqueeze.unsqueeze(
ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0
)
# sub will return the following mask tensor:
# [[0, -1, -2, -3],
# [1, 0, -1, -2],
# [2, 1, 0, -1],
# [3, 2, 1, 0]]
mask = impl.elementwise.sub(
ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor
)
ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0)
if sliding_window_size is None:
# return the following lower-triangular mask, which includes the main diagonal:
# 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False],
# 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False],
# 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False],
# 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False],
# 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]])
return ge_0_mask

lt_window_mask = impl.elementwise.lt(
ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size
)
mask = impl.elementwise.logical_and(
ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask
)
# return the following mask if sliding_window_size is 3:
# 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False],
# 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False],
# 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False],
# 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False],
# 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True, True]]]])
return mask
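
For reference, a plain-PyTorch sketch of the same mask construction (the helper name and printed sizes are illustrative, not part of the PR): the row/column difference is non-negative on and below the diagonal, and bounding it by the window size yields the banded sliding-window mask shown in the comments above.

import torch

def causal_sliding_mask(L, S, sliding_window_size=None):
    diff = torch.arange(L).unsqueeze(-1) - torch.arange(S).unsqueeze(0)  # row - col
    mask = diff >= 0                         # causal: key position <= query position
    if sliding_window_size is not None:
        mask &= diff < sliding_window_size   # keep only the most recent keys
    return mask

print(causal_sliding_mask(5, 5))                         # lower-triangular mask
print(causal_sliding_mask(5, 5, sliding_window_size=3))  # banded mask from the comment above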


@@ -66,7 +93,7 @@ def scaled_dot_product_attention(
# TODO: remove this once we have a better way to handle the causal mask
scale = kwargs.get("scale", None)
source_ir = SourceIR.ATEN
is_causal = True

# implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
use_fp32_acc = kwargs.get("use_fp32_acc", False)
query_dtype = query.dtype
@@ -134,37 +161,43 @@
L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2)
if S < 0:
S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)

# generate the mask tensor
tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)

temp_mask = impl.unary.logical_not(
ctx, target, source_ir, name + "_logical_not", tril_tensor
)

# This need_mask determines if we want to use the causal mask or not
# When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
# So need_mask will be all False values in this case.
# TODO: Implement more general case where L != 1 and S != L
need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
temp_mask = impl.elementwise.logical_and(
ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
)
temp_mask_casted = cast_trt_tensor(
ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
)

one_minus_temp_mask = impl.elementwise.sub(
ctx,
target,
source_ir,
name + "_one_minus_temp_mask",
1.0,
temp_mask_casted,
)
attn_bias = impl.unary.log(
ctx, target, source_ir, name + "_log", one_minus_temp_mask
)
if is_causal:
tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
temp_mask = impl.unary.logical_not(
ctx, target, source_ir, name + "_logical_not", tril_tensor
)

# need_mask determines whether the causal mask should be applied.
# When KV caching is enabled, L == 1 and L != S; in that case the causal mask should not be used,
# so need_mask will be all False values.
# TODO: Implement the more general case where L != 1 and S != L
need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
temp_mask = impl.elementwise.logical_and(
ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
)
temp_mask_casted = cast_trt_tensor(
ctx,
temp_mask,
query_dtype,
name + "_casted_bool",
target,
source_ir,
)

one_minus_temp_mask = impl.elementwise.sub(
ctx,
target,
source_ir,
name + "_one_minus_temp_mask",
1.0,
temp_mask_casted,
)
attn_bias = impl.unary.log(
ctx, target, source_ir, name + "_log", one_minus_temp_mask
)
else:
attn_bias = attn_mask

scaled_add_attn_bias = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
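To summarize the new is_causal branch, here is a plain-PyTorch sketch (illustrative helper, not the converter API) of how the boolean keep-mask becomes an additive bias via log(1 - mask), and how the need_mask guard disables the causal bias in the KV-cache decode step where L == 1 but S > 1.

import torch

def causal_attn_bias(L, S, dtype=torch.float32):
    keep = torch.arange(L).unsqueeze(-1) >= torch.arange(S).unsqueeze(0)  # tril keep-mask
    blocked = ~keep
    # prefill (L == S): apply the causal mask; cached decode (L == 1, S > 1): skip it
    blocked = blocked & torch.tensor(L == S)
    return torch.log(1.0 - blocked.to(dtype))  # 0 where kept, -inf where blocked

print(causal_attn_bias(4, 4))  # upper triangle is -inf
print(causal_attn_bias(1, 8))  # all zeros: no causal bias during cached decode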
1 change: 0 additions & 1 deletion tools/llm/utils.py
@@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok
num_tokens_generated = 0
kv_cache = get_zeroed_dynamic_cache_inputs(model)
last_position_id = position_ids[-1, -1].item()
breakpoint()
while num_tokens_generated < num_output_tokens:
is_generate = False if input_seq.shape[1] > 1 else True
position_ids = (