mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix output type of mxfp4 matmuls
This commit is contained in:
@@ -4111,7 +4111,11 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  dtype = promote_types(x.dtype(), dtype);
+  if (qmode == QuantizationMode::Affine) {
+    dtype = promote_types(x.dtype(), dtype);
+  } else {
+    dtype = x.dtype();
+  }
 
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
@@ -4695,7 +4699,11 @@ array gather_qmm(
       quantization_params_from_mode(qmode, group_size_, bits_);
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
-  out_type = promote_types(x.dtype(), out_type);
+  if (qmode == QuantizationMode::Affine) {
+    out_type = promote_types(x.dtype(), out_type);
+  } else {
+    out_type = x.dtype();
+  }
 
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
|
|||||||
@@ -745,6 +745,25 @@ class TestQuantized(mlx_tests.MLXTestCase):
             test_shape(32, 512, 32, transpose=False, **kwargs)
             test_shape(1, 512, 32, transpose=False, **kwargs)
 
+    def test_qmm_mxfp4_type(self):
+        indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
+
+        for t in [mx.bfloat16, mx.float16, mx.float32]:
+            x = mx.random.normal((32, 256)).astype(t)
+
+            w = mx.random.normal((32, 256))
+            wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32)
+            out = mx.quantized_matmul(x, wq, s, mode="mxfp4", group_size=32, bits=4)
+            self.assertEqual(out.dtype, t)
+
+            w = mx.random.normal((4, 32, 256))
+            wq, s = mx.quantize(w, mode="mxfp4", bits=4, group_size=32)
+
+            out = mx.gather_qmm(
+                x, wq, s, rhs_indices=indices, mode="mxfp4", group_size=32, bits=4
+            )
+            self.assertEqual(out.dtype, t)
+
     def test_gather_matmul_grad(self):
         def quantize(w, transpose=True, group_size=64, bits=4):
             qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
|
|||||||
Reference in New Issue
Block a user