Gather mm new kernel and small refactoring (#2040)

Angelos Katharopoulos
2025-04-14 16:37:36 -07:00
committed by GitHub
parent e9e268336b
commit 99eefd2ec0
23 changed files with 1260 additions and 378 deletions


@@ -4464,9 +4464,10 @@ void init_ops(nb::module_& m) {
  "lhs_indices"_a = nb::none(),
  "rhs_indices"_a = nb::none(),
  nb::kw_only(),
+ "sorted_indices"_a = false,
  "stream"_a = nb::none(),
  nb::sig(
- "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+ "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
  R"pbdoc(
  Matrix multiplication with matrix-level gather.
@@ -4485,11 +4486,16 @@ void init_ops(nb::module_& m) {
  For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices``
  contains indices from the range ``[0, B1 * B2 * ... * BS)``
+ If only one index is passed and it is sorted, the ``sorted_indices``
+ flag can be passed for a possible faster implementation.
  Args:
  a (array): Input array.
  b (array): Input array.
  lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``
  rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``
+ sorted_indices (bool, optional): May allow a faster implementation
+ if the passed indices are sorted. Default: ``False``.
  Returns:
  array: The output array.
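
Note: the docstring above describes a matrix-level gather followed by a matrix multiplication. As a rough illustration of those semantics only (not the kernel added in this commit), here is a pure-Python sketch. The names `matmul`, `gather_mm_reference`, and `is_sorted` are hypothetical helpers, the batch dimensions are assumed to be already flattened into one leading axis, the two index lists are assumed to have equal length (no broadcasting), and "sorted" is assumed to mean non-decreasing.

```python
def matmul(x, y):
    """Multiply two matrices given as nested lists (rows of numbers)."""
    inner = len(y)
    assert all(len(row) == inner for row in x), "inner dimensions must match"
    return [
        [sum(row[k] * y[k][j] for k in range(inner)) for j in range(len(y[0]))]
        for row in x
    ]


def gather_mm_reference(a_batch, b_batch, lhs_indices, rhs_indices):
    """Hypothetical reference: out[i] = a_batch[lhs[i]] @ b_batch[rhs[i]].

    a_batch and b_batch are lists of matrices (flattened batch axis).
    Index broadcasting, which the real op supports, is not modeled here.
    """
    assert len(lhs_indices) == len(rhs_indices)
    return [matmul(a_batch[i], b_batch[j]) for i, j in zip(lhs_indices, rhs_indices)]


def is_sorted(indices):
    """Assumed precondition for sorted_indices=True: non-decreasing indices."""
    return all(p <= q for p, q in zip(indices, indices[1:]))


# Two 1x2 matrices on the left, two 2x1 matrices on the right.
a_batch = [[[1, 2]], [[3, 4]]]
b_batch = [[[1], [1]], [[2], [0]]]
out = gather_mm_reference(a_batch, b_batch, [0, 0, 1], [0, 1, 1])
```

With MLX itself, the new keyword would be passed as `sorted_indices=True` according to the signature in this diff. A caller would only do so when the indices are known to be sorted, which a check like `is_sorted` can confirm on the host.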


@@ -1108,7 +1108,7 @@ class TestBlas(mlx_tests.MLXTestCase):
  lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
  rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
  M = a.shape[-2]
- N = b.shape[-2]
+ N = b.shape[-1]
  K = a.shape[-1]
  a = a.reshape((-1, M, K))
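
Note: the test fix above concerns which axis of `b` supplies `N`. A small sketch of the shape bookkeeping follows, assuming the usual matmul convention that `a` ends in `(M, K)` and `b` ends in `(K, N)`. The helper name `matmul_dims` is hypothetical and works on plain shape tuples.

```python
def matmul_dims(a_shape, b_shape):
    """Return (M, N, K) for a @ b, where a is (..., M, K) and b is (..., K, N)."""
    M, K = a_shape[-2], a_shape[-1]
    K_b, N = b_shape[-2], b_shape[-1]
    if K != K_b:
        raise ValueError(f"inner dimensions differ: {K} vs {K_b}")
    return M, N, K


# Non-square example: reading N from b_shape[-2] would give 7 (which is K), not 2.
dims = matmul_dims((3, 5, 7), (3, 7, 2))
```

When `b` is square in its last two axes, both expressions return the same value, so the old line would only give a wrong `N` for non-square `b`.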