Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Add more tests and fix qmm gradient
@@ -3233,8 +3233,9 @@ std::vector<array> QuantizedMatmul::vjp(
           "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
     } else {
       if (!dsb) {
-        auto fc = flatten(cotangents[0], 0, -2, stream());
-        auto fx = flatten(primals[0], 0, -2, stream());
+        int ndim = primals[1].ndim();
+        auto fc = flatten(cotangents[0], 0, -ndim, stream());
+        auto fx = flatten(primals[0], 0, -ndim, stream());
         auto dw = transpose_
             ? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
             : matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
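
The hunk above flattens the cotangent and the input down to the weight's dimensionality before forming the weight-space gradient dw. As a rough illustration (not part of the commit), the sketch below spells out the math the scales/biases gradient reduces to for the plain 2D, transpose=True case and checks it against mx.vjp. Shapes are illustrative; group_size/bits use 64 and 4 (mlx's defaults), and the reshape into (n_groups, group_size) plays the role of the C++ unflatten(..., {-1, group_size_}) seen in the next hunk.

# Sketch only (not part of the commit): scales/biases gradient of
# mx.quantized_matmul with transpose=True, checked against mx.vjp.
import mlx.core as mx

group_size, bits = 64, 4
x = mx.random.normal(shape=(16, 512))
w, s, b = mx.quantize(mx.random.normal(shape=(256, 512)), group_size, bits)
cotan = mx.random.normal(shape=(16, 256))

# vjp computed by the library
_, (ds, db) = mx.vjp(
    lambda s, b: mx.quantized_matmul(
        x, w, s, b, transpose=True, group_size=group_size, bits=bits
    ),
    [s, b],
    [cotan],
)

# Manual reference: dW = cotan^T @ x, then reduce within each quantization
# group (the reshape mirrors the C++ unflatten(..., {-1, group_size_})).
w_hat = mx.dequantize(w, mx.ones_like(s), mx.zeros_like(b), group_size, bits)
dw = cotan.T @ x                               # (256, 512)
dw_g = dw.reshape(256, -1, group_size)         # (256, 8, 64)
ds_ref = (dw_g * w_hat.reshape(256, -1, group_size)).sum(-1)
db_ref = dw_g.sum(-1)

print(mx.allclose(ds, ds_ref, atol=1e-2), mx.allclose(db, db_ref, atol=1e-2))
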
@@ -3388,12 +3389,16 @@ std::vector<array> GatherQMM::vjp(
       vjps.push_back(
           sum(multiply(
                   *dsb,
-                  dequantize(
-                      w,
-                      ones_like(scales, stream()),
-                      zeros_like(biases, stream()),
-                      group_size_,
-                      bits_,
+                  unflatten(
+                      dequantize(
+                          w,
+                          ones_like(scales, stream()),
+                          zeros_like(biases, stream()),
+                          group_size_,
+                          bits_,
+                          stream()),
+                      -1,
+                      {-1, group_size_},
                       stream()),
                   stream()),
               -1,
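
The GatherQMM::vjp hunk above groups the dequantized weights into (n_groups, group_size) blocks with unflatten before the reduction that produces the scales/biases gradient. As a small usage sketch (not part of the commit), gradients with respect to the scales and biases can be taken directly through mx.gather_qmm, the path this change affects; the shapes and indices below are illustrative.

# Sketch only (not part of the commit): scales/biases gradients through
# mx.gather_qmm.
import mlx.core as mx

x = mx.random.normal(shape=(4, 1, 256))
w, s, b = mx.quantize(mx.random.normal(shape=(2, 256, 256)))
rhs = mx.array([0, 1, 0, 1])          # which weight matrix each row of x uses

def loss(s, b):
    y = mx.gather_qmm(x, w, s, b, transpose=True, rhs_indices=rhs)
    return y.sum()

ds, db = mx.grad(loss, argnums=(0, 1))(s, b)
print(ds.shape, db.shape)             # gradients have the shapes of s and b
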
@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
         self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
 
+    def test_gather_qmm_grad(self):
+        def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
+            if lhs is not None:
+                x = x[lhs]
+            if rhs is not None:
+                w = w[rhs]
+                s = s[rhs]
+                b = b[rhs]
+            return mx.quantized_matmul(x, w, s, b, transpose=trans)
+
+        def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
+            return mx.gather_qmm(
+                x,
+                w,
+                s,
+                b,
+                transpose=trans,
+                lhs_indices=lhs,
+                rhs_indices=rhs,
+                sorted_indices=sort,
+            )
+
+        x = mx.random.normal((128, 1, 1024))
+        w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024)))
+        indices = mx.sort(mx.random.randint(0, 8, shape=(128,)))
+        cotan = mx.random.normal((128, 1, 1024))
+
+        (o1,), (dx1, ds1, db1) = mx.vjp(
+            lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
+            [x, s, b],
+            [cotan],
+        )
+        (o2,), (dx2, ds2, db2) = mx.vjp(
+            lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
+            [x, s, b],
+            [cotan],
+        )
+
+        self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
+        self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
+        self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
+        self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
+
     def test_vjp_scales_biases(self):
         mx.random.seed(0)
         x = mx.random.normal(shape=(2, 2, 512))
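
The new test compares mx.gather_qmm outputs and gradients against an index-then-mx.quantized_matmul reference. To run just this test, something like the following should work (the python/tests path is an assumption based on where the quantization tests live in the repo):

cd python/tests
python -m unittest test_quantized.TestQuantized.test_gather_qmm_grad
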