Add more tests and fix qmm gradient

Angelos Katharopoulos
2025-07-05 02:41:39 -07:00
parent 3d4174cd37
commit 9e5bb5295a
2 changed files with 56 additions and 8 deletions


@@ -3233,8 +3233,9 @@ std::vector<array> QuantizedMatmul::vjp(
         "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
   } else {
     if (!dsb) {
-      auto fc = flatten(cotangents[0], 0, -2, stream());
-      auto fx = flatten(primals[0], 0, -2, stream());
+      int ndim = primals[1].ndim();
+      auto fc = flatten(cotangents[0], 0, -ndim, stream());
+      auto fx = flatten(primals[0], 0, -ndim, stream());
       auto dw = transpose_
           ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
           : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
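What this hunk fixes, in a hedged sketch (not part of the commit): flattening the input and cotangent down to the weight's ndim, instead of a hard-coded -2, makes the scales/biases gradient respect batch dimensions on the quantized weight. Assuming quantized_matmul accepts a batched weight (which the ndim-based flatten implies), one way to sanity-check the batched VJP against the plain 2D path:

import mlx.core as mx

x = mx.random.normal((8, 4, 512))
w, s, b = mx.quantize(mx.random.normal((8, 512, 512)))
cotan = mx.random.normal((8, 4, 512))

# VJP through the batched path.
f = lambda x, s, b: mx.quantized_matmul(x, w, s, b, transpose=True)
_, (dx, ds, db) = mx.vjp(f, [x, s, b], [cotan])

# Reference: run each batch element through the unbatched 2D path.
for i in range(8):
    g = lambda x, s, b: mx.quantized_matmul(x, w[i], s, b, transpose=True)
    _, (dxi, dsi, dbi) = mx.vjp(g, [x[i], s[i], b[i]], [cotan[i]])
    assert mx.allclose(ds[i], dsi, atol=1e-3)
    assert mx.allclose(db[i], dbi, atol=1e-3)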
@@ -3388,6 +3389,7 @@ std::vector<array> GatherQMM::vjp(
       vjps.push_back(
           sum(multiply(
                   *dsb,
+                  unflatten(
                       dequantize(
                           w,
                           ones_like(scales, stream()),
@@ -3395,6 +3397,9 @@ std::vector<array> GatherQMM::vjp(
                           group_size_,
                           bits_,
                           stream()),
+                      -1,
+                      {-1, group_size_},
+                      stream()),
                   stream()),
               -1,
               false,
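The added unflatten implements the grouped reduction behind the scales gradient. A hedged sketch of the underlying math, outside the gather path and with assumed shapes: each group of group_size elements satisfies w_hat = s * q + b, so ds is the within-group sum of dw_hat * q, where q is recovered by dequantizing with unit scales and zero biases.

import mlx.core as mx

group_size = 64
w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024)), group_size=group_size)
dw_hat = mx.random.normal((8, 1024, 1024))  # upstream gradient wrt dequantize(w, s, b)

# q: the quantization codes, recovered with unit scales and zero biases.
q = mx.dequantize(w, mx.ones_like(s), mx.zeros_like(b), group_size=group_size)

# Group the last axis as unflatten(..., -1, {-1, group_size_}) does,
# then reduce within each group.
ds = mx.sum(mx.unflatten(dw_hat * q, -1, (-1, group_size)), axis=-1)
db = mx.sum(mx.unflatten(dw_hat, -1, (-1, group_size)), axis=-1)
assert ds.shape == s.shape and db.shape == b.shape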


@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
         self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
 
+    def test_gather_qmm_grad(self):
+        def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
+            if lhs is not None:
+                x = x[lhs]
+            if rhs is not None:
+                w = w[rhs]
+                s = s[rhs]
+                b = b[rhs]
+            return mx.quantized_matmul(x, w, s, b, transpose=trans)
+
+        def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
+            return mx.gather_qmm(
+                x,
+                w,
+                s,
+                b,
+                transpose=trans,
+                lhs_indices=lhs,
+                rhs_indices=rhs,
+                sorted_indices=sort,
+            )
+
+        x = mx.random.normal((128, 1, 1024))
+        w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024)))
+        indices = mx.sort(mx.random.randint(0, 8, shape=(128,)))
+        cotan = mx.random.normal((128, 1, 1024))
+
+        (o1,), (dx1, ds1, db1) = mx.vjp(
+            lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
+            [x, s, b],
+            [cotan],
+        )
+        (o2,), (dx2, ds2, db2) = mx.vjp(
+            lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
+            [x, s, b],
+            [cotan],
+        )
+
+        self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
+        self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
+        self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
+        self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
+
     def test_vjp_scales_biases(self):
         mx.random.seed(0)
         x = mx.random.normal(shape=(2, 2, 512))
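A note on the new test's design: the row indices are pre-sorted with mx.sort and sorted_indices=True is passed through, so the test presumably exercises the faster kernel path gather_qmm may take when it is told the indices arrive sorted. The reference gathers w, s, and b explicitly and falls back to plain quantized_matmul, keeping the two computations independent so the gradient comparison is meaningful; the looser atol on the scales and biases gradients likely reflects the extra grouped reduction in that path.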