Add more tests and fix qmm gradient

This commit is contained in:
Angelos Katharopoulos
2025-07-05 02:41:39 -07:00
parent 3d4174cd37
commit 9e5bb5295a
2 changed files with 56 additions and 8 deletions

View File

@@ -3233,8 +3233,9 @@ std::vector<array> QuantizedMatmul::vjp(
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
} else {
if (!dsb) {
auto fc = flatten(cotangents[0], 0, -2, stream());
auto fx = flatten(primals[0], 0, -2, stream());
int ndim = primals[1].ndim();
auto fc = flatten(cotangents[0], 0, -ndim, stream());
auto fx = flatten(primals[0], 0, -ndim, stream());
auto dw = transpose_
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
@@ -3388,12 +3389,16 @@ std::vector<array> GatherQMM::vjp(
vjps.push_back(
sum(multiply(
*dsb,
dequantize(
w,
ones_like(scales, stream()),
zeros_like(biases, stream()),
group_size_,
bits_,
unflatten(
dequantize(
w,
ones_like(scales, stream()),
zeros_like(biases, stream()),
group_size_,
bits_,
stream()),
-1,
{-1, group_size_},
stream()),
stream()),
-1,

View File

@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
def test_gather_qmm_grad(self):
    """Check gather_qmm outputs and gradients against a gather + quantized_matmul reference.

    Both paths are differentiated with ``mx.vjp`` with respect to the input
    ``x``, the scales ``s`` and the biases ``b`` (the quantized weights ``w``
    are captured by the lambdas and are not differentiated). The forward
    output and all three cotangent products must agree within tolerance.
    """

    def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
        # Reference path: gather explicitly by indexing, then run a plain
        # quantized matmul. `sort` is unused here; it is kept so both
        # helpers share the same signature.
        if lhs is not None:
            x = x[lhs]
        if rhs is not None:
            w = w[rhs]
            s = s[rhs]
            b = b[rhs]
        return mx.quantized_matmul(x, w, s, b, transpose=trans)

    def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
        # Path under test: the gather is performed inside mx.gather_qmm.
        return mx.gather_qmm(
            x,
            w,
            s,
            b,
            transpose=trans,
            lhs_indices=lhs,
            rhs_indices=rhs,
            sorted_indices=sort,
        )

    # Seed the global RNG so that a tolerance failure is reproducible,
    # as the neighbouring test_vjp_scales_biases does.
    mx.random.seed(0)

    x = mx.random.normal((128, 1, 1024))
    w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024)))
    # The indices are sorted because sorted_indices=True is passed below.
    indices = mx.sort(mx.random.randint(0, 8, shape=(128,)))
    cotan = mx.random.normal((128, 1, 1024))

    (o1,), (dx1, ds1, db1) = mx.vjp(
        lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
        [x, s, b],
        [cotan],
    )
    (o2,), (dx2, ds2, db2) = mx.vjp(
        lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
        [x, s, b],
        [cotan],
    )

    self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
    self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
    # Scale and bias gradients are compared with a looser tolerance.
    self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
    self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
def test_vjp_scales_biases(self):
mx.random.seed(0)
x = mx.random.normal(shape=(2, 2, 512))