mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias
* higher tol
This commit is contained in:
@@ -549,6 +549,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
|
||||
|
||||
def test_vjp_scales_biases(self):
    """Check quantized_matmul gradients for scales and biases numerically.

    The gradient returned by ``mx.grad`` for the ``(scales, biases)`` pair
    is compared against a central finite difference of the summed
    quantized matmul output at a few selected entries.
    """
    mx.random.seed(0)
    inputs = mx.random.normal(shape=(2, 2, 512))
    weights = mx.random.normal(shape=(512, 512))
    w_quant, scales, biases = mx.quantize(weights, bits=4, group_size=64)

    def summed_qmm(scales_biases, inp, wq):
        # Reduce the product to a single value so it can be differentiated.
        product = mx.quantized_matmul(
            inp, wq, *scales_biases, bits=4, group_size=64
        )
        return product.sum()

    params = (scales, biases)
    analytic = mx.grad(summed_qmm)((scales, biases), inputs, w_quant)

    step = 8e-3
    # Central-difference check at a handful of entries only.
    probes = [(0, 0), (11, 4), (22, 7)]
    for entry in probes:
        for which in (0, 1):
            target = params[which]

            # Evaluate at +step, then at -step relative to the start value.
            target[entry] += step
            upper = summed_qmm(params, inputs, w_quant)
            target[entry] -= 2 * step
            lower = summed_qmm(params, inputs, w_quant)

            # Undo the perturbation before probing the next entry.
            target[entry] += step

            numeric = (upper - lower) / (2 * step)
            self.assertAlmostEqual(analytic[which][entry], numeric, delta=2e-2)
|
||||
|
||||
|
||||
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
Reference in New Issue
Block a user