Gather mm new kernel and small refactoring (#2040)

This commit is contained in:
Angelos Katharopoulos
2025-04-14 16:37:36 -07:00
committed by GitHub
parent e9e268336b
commit 99eefd2ec0
23 changed files with 1260 additions and 378 deletions

View File

@@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
M = a.shape[-2]
N = b.shape[-2]
N = b.shape[-1]
K = a.shape[-1]
a = a.reshape((-1, M, K))