mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix output type of mxfp4 matmuls
This commit is contained in:

 mlx/ops.cpp | 12 ++++++++++--
@@ -4111,7 +4111,11 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  dtype = promote_types(x.dtype(), dtype);
+  if (qmode == QuantizationMode::Affine) {
+    dtype = promote_types(x.dtype(), dtype);
+  } else {
+    dtype = x.dtype();
+  }
 
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
@@ -4695,7 +4699,11 @@ array gather_qmm(
       quantization_params_from_mode(qmode, group_size_, bits_);
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
-  out_type = promote_types(x.dtype(), out_type);
+  if (qmode == QuantizationMode::Affine) {
+    out_type = promote_types(x.dtype(), out_type);
+  } else {
+    out_type = x.dtype();
+  }
 
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
@@ -745,6 +745,25 @@ class TestQuantized(mlx_tests.MLXTestCase):
         test_shape(32, 512, 32, transpose=False, **kwargs)
         test_shape(1, 512, 32, transpose=False, **kwargs)
 
+    def test_qmm_mxfp4_type(self):
+        indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
+
+        for t in [mx.bfloat16, mx.float16, mx.float32]:
+            x = mx.random.normal((32, 256)).astype(t)
+
+            w = mx.random.normal((32, 256))
+            wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32)
+            out = mx.quantized_matmul(x, wq, s, mode="mxfp4", group_size=32, bits=4)
+            self.assertEqual(out.dtype, t)
+
+            w = mx.random.normal((4, 32, 256))
+            wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32)
+
+            out = mx.gather_qmm(
+                x, wq, s, rhs_indices=indices, mode="mxfp4", group_size=32, bits=4
+            )
+            self.assertEqual(out.dtype, t)
+
     def test_gather_matmul_grad(self):
         def quantize(w, transpose=True, group_size=64, bits=4):
             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
Reference in New Issue
Block a user