diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index de43ec26d5..90a57221fc 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -525,17 +525,17 @@ class TestQuantized(mlx_tests.MLXTestCase): parameters = [ # L, K, D, E, I, transpose - (128, 1024, 1024, 32, 4, True), - (128, 1024, 544, 32, 4, True), - (433, 1024, 1024, 32, 4, True), - (433, 1024, 555, 32, 4, True), - (433, 2048, 1024, 32, 4, True), - (128, 1024, 1024, 32, 4, False), - (128, 1024, 544, 32, 4, False), - (433, 1024, 1024, 32, 4, False), - (433, 1024, 544, 32, 4, False), - (433, 1024, 555, 32, 4, False), - (433, 2048, 1024, 32, 4, False), + (32, 512, 512, 4, 2, True), + (32, 512, 544, 4, 2, True), + (133, 512, 512, 4, 2, True), + (133, 512, 555, 4, 2, True), + (133, 512, 512, 4, 2, True), + (64, 512, 512, 4, 2, False), + (64, 512, 544, 4, 2, False), + (133, 512, 512, 4, 2, False), + (133, 512, 544, 4, 2, False), + (133, 512, 555, 4, 2, False), + (64, 512, 512, 4, 2, False), ] for L, K, D, E, I, transpose in parameters: K, D = (K, D) if transpose else (D, K)