6 changes: 6 additions & 0 deletions mlx/ops.cpp
@@ -4653,6 +4653,12 @@ array qqmm(
// validate inputs
validate_qqmm_inputs(
x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode);
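  // Tensor-scale (global scale) nvfp4 is not implemented in the Metal qqmm
  // kernels; reject it here rather than silently dropping the global scales
  // in the gemv path and producing incorrect results.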
if (stream.device == Device::gpu && metal::is_available() &&
(global_scale_x.has_value() || global_scale_w.has_value())) {
throw std::invalid_argument(
"[qqmm] Global scale (tensor-scale nvfp4) is not supported "
"on the Metal backend.");
}
// validate and extract shapes
auto [w_inner_dims, w_outer_dims] =
extract_qqmm_dims(x, w, scales_w, group_size, bits);
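For context, a minimal sketch (not part of the diff) of how the new guard surfaces from Python. It mirrors the calls used in the test below, and it assumes the tensor-scale nvfp4 path remains usable when the call is routed explicitly to the CPU stream:

import mlx.core as mx

w = mx.random.normal(shape=(64, 64))
w_q, scales = mx.quantize(w, mode="nvfp4")
x = mx.random.normal(shape=(1, 64))
gx = mx.array(1.0, dtype=mx.float32)

# On Metal, passing a global scale now raises ValueError instead of the
# scales being silently ignored in the gemv path.
try:
    y = mx.qqmm(x, w_q, scales, mode="nvfp4", global_scale_x=gx, stream=mx.gpu)
    mx.eval(y)
except ValueError as e:
    print(e)

# Assumption: routing to the CPU stream keeps tensor-scale nvfp4 available.
y = mx.qqmm(x, w_q, scales, mode="nvfp4", global_scale_x=gx, stream=mx.cpu)
mx.eval(y)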
26 changes: 26 additions & 0 deletions python/tests/test_quantized.py
@@ -205,6 +205,32 @@ def test_qqmv(self):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)

def test_qqmm_metal_global_scale_rejected(self):
# Tensor-scale nvfp4 (global_scale_x / global_scale_w) is not
# implemented in the Metal qqmm kernels. mx.qqmm must reject the
# request on Metal rather than silently dropping the global scales
# in the gemv path and producing incorrect results.
if not mx.metal.is_available():
return

w = mx.random.normal(shape=(64, 64))
w_q, scales = mx.quantize(w, mode="nvfp4")
x = mx.random.normal(shape=(1, 64))
gx = mx.array(1.0, dtype=mx.float32)
gw = mx.array(1.0, dtype=mx.float32)

with self.assertRaises(ValueError):
y = mx.qqmm(
x,
w_q,
scales,
mode="nvfp4",
global_scale_x=gx,
global_scale_w=gw,
stream=mx.gpu,
)
mx.eval(y)

def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)