faster gather qmm sorted test (#2463)

This commit is contained in:
Awni Hannun
2025-08-05 06:27:40 -07:00
committed by GitHub
parent ca973d1e83
commit fa89f0b150

View File

@@ -525,17 +525,17 @@ class TestQuantized(mlx_tests.MLXTestCase):
parameters = [ parameters = [
# L, K, D, E, I, transpose # L, K, D, E, I, transpose
(128, 1024, 1024, 32, 4, True), (32, 512, 512, 4, 2, True),
(128, 1024, 544, 32, 4, True), (32, 512, 544, 4, 2, True),
(433, 1024, 1024, 32, 4, True), (133, 512, 512, 4, 2, True),
(433, 1024, 555, 32, 4, True), (133, 512, 555, 4, 2, True),
(433, 2048, 1024, 32, 4, True), (133, 512, 512, 4, 2, True),
(128, 1024, 1024, 32, 4, False), (64, 512, 512, 4, 2, False),
(128, 1024, 544, 32, 4, False), (64, 512, 544, 4, 2, False),
(433, 1024, 1024, 32, 4, False), (133, 512, 512, 4, 2, False),
(433, 1024, 544, 32, 4, False), (133, 512, 544, 4, 2, False),
(433, 1024, 555, 32, 4, False), (133, 512, 555, 4, 2, False),
(433, 2048, 1024, 32, 4, False), (64, 512, 512, 4, 2, False),
] ]
for L, K, D, E, I, transpose in parameters: for L, K, D, E, I, transpose in parameters:
K, D = (K, D) if transpose else (D, K) K, D = (K, D) if transpose else (D, K)