mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
fix quantized vjp for mxfp4 (#2555)
This commit is contained in:
@@ -842,6 +842,37 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
num_ds = (out_up - out_down) / (2 * eps)
|
||||
self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
|
||||
|
||||
def test_mxfp4_vjp_scales_throws(self):
    """Check that differentiating w.r.t. mxfp4 scales raises ``ValueError``.

    Both ``mx.quantized_matmul`` and ``mx.gather_qmm`` are exercised. In each
    case the scales are the first argument of the differentiated function,
    i.e. the argument ``mx.grad`` differentiates when no ``argnums`` is given.
    """
    # One set of quantization settings shared by every call below, so the
    # quantize step and both matmul variants cannot drift apart.
    quant_args = {"bits": 4, "group_size": 32, "mode": "mxfp4"}

    mx.random.seed(0)
    x = mx.random.normal(shape=(2, 512))
    w = mx.random.normal(shape=(512, 512))
    wq, s = mx.quantize(w, **quant_args)

    def mm(s, x, wq):
        return mx.quantized_matmul(x, wq, s, **quant_args).sum()

    rhs_indices = mx.array(0)

    def gmm(s, x, wq):
        return mx.gather_qmm(
            x, wq, s, rhs_indices=rhs_indices, **quant_args
        ).sum()

    # Only the gradient call sits inside each assertRaises block, so the
    # expected ValueError can only come from the differentiation itself.
    with self.assertRaises(ValueError):
        mx.grad(mm)(s, x, wq)

    with self.assertRaises(ValueError):
        mx.grad(gmm)(s, x, wq)
|
||||
|
||||
|
||||
# Script entry point: when this file is executed directly (not imported),
# invoke the test runner provided by mlx_tests.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user