diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 3380fa08b..b2b7306dd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3233,8 +3233,9 @@ std::vector QuantizedMatmul::vjp( "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); } else { if (!dsb) { - auto fc = flatten(cotangents[0], 0, -2, stream()); - auto fx = flatten(primals[0], 0, -2, stream()); + int ndim = primals[1].ndim(); + auto fc = flatten(cotangents[0], 0, -ndim, stream()); + auto fx = flatten(primals[0], 0, -ndim, stream()); auto dw = transpose_ ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream()) : matmul(swapaxes(fx, -1, -2, stream()), fc, stream()); @@ -3388,12 +3389,16 @@ std::vector GatherQMM::vjp( vjps.push_back( sum(multiply( *dsb, - dequantize( - w, - ones_like(scales, stream()), - zeros_like(biases, stream()), - group_size_, - bits_, + unflatten( + dequantize( + w, + ones_like(scales, stream()), + zeros_like(biases, stream()), + group_size_, + bits_, + stream()), + -1, + {-1, group_size_}, stream()), stream()), -1, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index f402bd444..f1a051665 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + def test_gather_qmm_grad(self): + def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort): + if lhs is not None: + x = x[lhs] + if rhs is not None: + w = w[rhs] + s = s[rhs] + b = b[rhs] + return mx.quantized_matmul(x, w, s, b, transpose=trans) + + def gather_qmm(x, w, s, b, lhs, rhs, trans, sort): + return mx.gather_qmm( + x, + w, + s, + b, + transpose=trans, + lhs_indices=lhs, + rhs_indices=rhs, + sorted_indices=sort, + ) + + x = mx.random.normal((128, 1, 1024)) + w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024))) + indices = mx.sort(mx.random.randint(0, 8, shape=(128,))) + cotan = mx.random.normal((128, 1, 1024)) + + (o1,), (dx1, ds1, db1) = mx.vjp( + lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + (o2,), (dx2, ds2, db2) = mx.vjp( + lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True), + [x, s, b], + [cotan], + ) + + self.assertTrue(mx.allclose(o1, o2, atol=1e-4)) + self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4)) + self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3)) + self.assertTrue(mx.allclose(db1, db2, atol=1e-3)) + def test_vjp_scales_biases(self): mx.random.seed(0) x = mx.random.normal(shape=(2, 2, 512))