CI weirdness due to large arrays

This commit is contained in:
Angelos Katharopoulos
2025-07-07 00:18:42 -07:00
parent 86dc1a2683
commit 8ea5729ee4

View File

@@ -571,10 +571,10 @@ class TestQuantized(mlx_tests.MLXTestCase):
sorted_indices=sort,
)
x = mx.random.normal((128, 1, 1024))
w, s, b = mx.quantize(mx.random.normal((8, 1024, 1024)))
indices = mx.sort(mx.random.randint(0, 8, shape=(128,)))
cotan = mx.random.normal((128, 1, 1024))
x = mx.random.normal((16, 1, 256))
w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
cotan = mx.random.normal((16, 1, 256))
(o1,), (dx1, ds1, db1) = mx.vjp(
lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),