Skip to content

Commit c452a86

Browse files
TorchDynamo Compatability
1 parent 84cfdb2 commit c452a86

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,14 @@ def apply_weights(self,
5959
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
6060
fn: Callable) -> None:
6161
if name is not None and getattr(layer, name, None) is not None:
62-
replace_parameter(layer, name, fn(getattr(layer, name)))
62+
63+
old_param = getattr(layer, name)
64+
new_param = fn(old_param)
65+
# replace the parameter with torch.nn.Parameter for TorchDynamo
66+
# compatibility
67+
replace_parameter(
68+
layer, name,
69+
torch.nn.Parameter(new_param.data, requires_grad=False))
6370

6471
def _get_weight_params(
6572
self, layer: torch.nn.Module

0 commit comments

Comments
 (0)