faster gather qmm sorted test

This commit is contained in:
Awni Hannun
2025-08-04 17:05:36 -07:00
parent 7d86a5c108
commit 296860d2fa

View File

@@ -525,17 +525,17 @@ class TestQuantized(mlx_tests.MLXTestCase):
parameters = [
# L, K, D, E, I, transpose
(128, 1024, 1024, 32, 4, True),
(128, 1024, 544, 32, 4, True),
(433, 1024, 1024, 32, 4, True),
(433, 1024, 555, 32, 4, True),
(433, 2048, 1024, 32, 4, True),
(128, 1024, 1024, 32, 4, False),
(128, 1024, 544, 32, 4, False),
(433, 1024, 1024, 32, 4, False),
(433, 1024, 544, 32, 4, False),
(433, 1024, 555, 32, 4, False),
(433, 2048, 1024, 32, 4, False),
(32, 512, 512, 4, 2, True),
(32, 512, 544, 4, 2, True),
(133, 512, 512, 4, 2, True),
(133, 512, 555, 4, 2, True),
(133, 512, 512, 4, 2, True),
(64, 512, 512, 4, 2, False),
(64, 512, 544, 4, 2, False),
(133, 512, 512, 4, 2, False),
(133, 512, 544, 4, 2, False),
(133, 512, 555, 4, 2, False),
(64, 512, 512, 4, 2, False),
]
for L, K, D, E, I, transpose in parameters:
K, D = (K, D) if transpose else (D, K)