mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add more tests and fix qmm gradient
This commit is contained in:
@@ -3233,8 +3233,9 @@ std::vector<array> QuantizedMatmul::vjp(
|
|||||||
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
|
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
|
||||||
} else {
|
} else {
|
||||||
if (!dsb) {
|
if (!dsb) {
|
||||||
auto fc = flatten(cotangents[0], 0, -2, stream());
|
int ndim = primals[1].ndim();
|
||||||
auto fx = flatten(primals[0], 0, -2, stream());
|
auto fc = flatten(cotangents[0], 0, -ndim, stream());
|
||||||
|
auto fx = flatten(primals[0], 0, -ndim, stream());
|
||||||
auto dw = transpose_
|
auto dw = transpose_
|
||||||
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
|
? matmul(swapaxes(fc, -1, -2, stream()), fx, stream())
|
||||||
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
|
: matmul(swapaxes(fx, -1, -2, stream()), fc, stream());
|
||||||
@@ -3388,12 +3389,16 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
vjps.push_back(
|
vjps.push_back(
|
||||||
sum(multiply(
|
sum(multiply(
|
||||||
*dsb,
|
*dsb,
|
||||||
dequantize(
|
unflatten(
|
||||||
w,
|
dequantize(
|
||||||
ones_like(scales, stream()),
|
w,
|
||||||
zeros_like(biases, stream()),
|
ones_like(scales, stream()),
|
||||||
group_size_,
|
zeros_like(biases, stream()),
|
||||||
bits_,
|
group_size_,
|
||||||
|
bits_,
|
||||||
|
stream()),
|
||||||
|
-1,
|
||||||
|
{-1, group_size_},
|
||||||
stream()),
|
stream()),
|
||||||
stream()),
|
stream()),
|
||||||
-1,
|
-1,
|
||||||
|
|||||||
@@ -549,6 +549,49 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
||||||
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
||||||
|
|
||||||
|
def test_gather_qmm_grad(self):
|
||||||
|
def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
|
||||||
|
if lhs is not None:
|
||||||
|
x = x[lhs]
|
||||||
|
if rhs is not None:
|
||||||
|
w = w[rhs]
|
||||||
|
s = s[rhs]
|
||||||
|
b = b[rhs]
|
||||||
|
return mx.quantized_matmul(x, w, s, b, transpose=trans)
|
||||||
|
|
||||||
|
def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
|
||||||
|
return mx.gather_qmm(
|
||||||
|
x,
|
||||||
|
w,
|
||||||
|
s,
|
||||||
|
b,
|
||||||
|
transpose=trans,
|
||||||
|
lhs_indices=lhs,
|
||||||
|
rhs_indices=rhs,
|
||||||
|
sorted_indices=sort,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = mx.random.normal((128, 1, 1024))
|
||||||
|
w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024)))
|
||||||
|
indices = mx.sort(mx.random.randint(0, 8, shape=(128,)))
|
||||||
|
cotan = mx.random.normal((128, 1, 1024))
|
||||||
|
|
||||||
|
(o1,), (dx1, ds1, db1) = mx.vjp(
|
||||||
|
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
|
||||||
|
[x, s, b],
|
||||||
|
[cotan],
|
||||||
|
)
|
||||||
|
(o2,), (dx2, ds2, db2) = mx.vjp(
|
||||||
|
lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
|
||||||
|
[x, s, b],
|
||||||
|
[cotan],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
|
||||||
|
self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
|
||||||
|
self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
|
||||||
|
self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
|
||||||
|
|
||||||
def test_vjp_scales_biases(self):
|
def test_vjp_scales_biases(self):
|
||||||
mx.random.seed(0)
|
mx.random.seed(0)
|
||||||
x = mx.random.normal(shape=(2, 2, 512))
|
x = mx.random.normal(shape=(2, 2, 512))
|
||||||
|
|||||||
Reference in New Issue
Block a user