diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index f1a051665..2c62c6307 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -571,10 +571,10 @@ class TestQuantized(mlx_tests.MLXTestCase): sorted_indices=sort, ) - x = mx.random.normal((128, 1, 1024)) - w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024))) - indices = mx.sort(mx.random.randint(0, 8, shape=(128,))) - cotan = mx.random.normal((128, 1, 1024)) + x = mx.random.normal((16, 1, 256)) + w, s, b = mx.quantize(mx.random.normal((4, 256, 256))) + indices = mx.sort(mx.random.randint(0, 4, shape=(16,))) + cotan = mx.random.normal((16, 1, 256)) (o1,), (dx1, ds1, db1) = mx.vjp( lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),