Enable vjp for quantized scale and bias (#2129)

* Enable vjp for quantized scale and bias

* higher tol
This commit is contained in:
Awni Hannun
2025-04-29 13:03:09 -07:00
committed by GitHub
parent b36dd472bb
commit 7bb063bcb3
2 changed files with 53 additions and 2 deletions

View File

@@ -549,6 +549,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
def test_vjp_scales_biases(self):
mx.random.seed(0)
x = mx.random.normal(shape=(2, 2, 512))
w = mx.random.normal(shape=(512, 512))
wq, s, b = mx.quantize(w, bits=4, group_size=64)
def mm(sb, x, wq):
return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()
params = (s, b)
dparams = mx.grad(mm)((s, b), x, wq)
eps = 8e-3
# numerical grad check with a few indices
indices = [(0, 0), (11, 4), (22, 7)]
for idx in indices:
for p in [0, 1]:
params[p][idx] += eps
out_up = mm(params, x, wq)
params[p][idx] -= 2 * eps
out_down = mm(params, x, wq)
params[p][idx] += eps
num_ds = (out_up - out_down) / (2 * eps)
self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
if __name__ == "__main__":
unittest.main()