fix output type of mxfp4 matmuls

This commit is contained in:
Awni Hannun
2025-10-28 13:31:17 -07:00
parent 5a043fd793
commit 83062b70e4
2 changed files with 29 additions and 2 deletions

View File

@@ -4111,7 +4111,11 @@ array quantized_matmul(
  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
  if (qmode == QuantizationMode::Affine) {
    dtype = promote_types(x.dtype(), dtype);
  } else {
    dtype = x.dtype();
  }
  if (!issubdtype(dtype, floating)) {
    std::ostringstream msg;
@@ -4695,7 +4699,11 @@ array gather_qmm(
      quantization_params_from_mode(qmode, group_size_, bits_);
  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
  if (qmode == QuantizationMode::Affine) {
    out_type = promote_types(x.dtype(), out_type);
  } else {
    out_type = x.dtype();
  }
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;

View File

@@ -745,6 +745,25 @@ class TestQuantized(mlx_tests.MLXTestCase):
test_shape(32, 512, 32, transpose=False, **kwargs)
test_shape(1, 512, 32, transpose=False, **kwargs)
def test_qmm_mxfp4_type(self):
    """mxfp4 quantized matmuls must produce outputs in the activation dtype."""
    rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
    for dtype in (mx.bfloat16, mx.float16, mx.float32):
        x = mx.random.normal((32, 256)).astype(dtype)

        # 2-D weights: plain quantized matmul.
        wq, scales = mx.quantize(
            mx.random.normal((32, 256)), mode="mxfp4", bits=4, group_size=32
        )
        result = mx.quantized_matmul(
            x, wq, scales, mode="mxfp4", group_size=32, bits=4
        )
        self.assertEqual(result.dtype, dtype)

        # Batched weights: gather_qmm selecting experts via rhs_indices.
        wq, scales = mx.quantize(
            mx.random.normal((4, 32, 256)), mode="mxfp4", bits=4, group_size=32
        )
        result = mx.gather_qmm(
            x,
            wq,
            scales,
            rhs_indices=rhs_indices,
            mode="mxfp4",
            group_size=32,
            bits=4,
        )
        self.assertEqual(result.dtype, dtype)
def test_gather_matmul_grad(self):
    def quantize(w, transpose=True, group_size=64, bits=4):
        qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)