mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-23 05:58:09 +08:00
Block sparse mm (#1058)
@@ -3716,6 +3716,38 @@ void init_ops(nb::module_& m) {
            mask_lhs (array, optional): Boolean mask for a (default: ``None``)
            mask_rhs (array, optional): Boolean mask for b (default: ``None``)

      )pbdoc");
  m.def(
      "block_sparse_mm",
      &block_sparse_mm,
      nb::arg(),
      nb::arg(),
      "lhs_indices"_a = nb::none(),
      "rhs_indices"_a = nb::none(),
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def block_sparse_mm(a: array, b: array, /, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Matrix multiplication with matrix-level gather.

        Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
        This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.

        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.

        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
        ``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``

        For ``b`` with shape ``(B1, B2, ..., BS, K, N)``,
        ``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``

        Args:
            a (array): Input array.
            b (array): Input array.
            lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
            rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)

      )pbdoc");
  m.def(
      "diagonal",
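
The gather-then-multiply semantics described in the `block_sparse_mm` docstring can be sketched in plain Python. This is only an illustrative reference, not the MLX implementation: `matmul_2d` and `gather_mm_reference` are hypothetical helper names, the batch dimensions are assumed to be already flattened into a list of matrices, and both index lists are assumed to have the same length (broadcasting between index arrays and the `None` defaults are not modelled here).

```python
def matmul_2d(x, y):
    """Multiply an M x K matrix by a K x N matrix (nested lists)."""
    k = len(y)
    assert all(len(row) == k for row in x), "inner dimensions must match"
    n = len(y[0])
    return [[sum(row[p] * y[p][j] for p in range(k)) for j in range(n)]
            for row in x]


def gather_mm_reference(a_batch, b_batch, lhs_indices, rhs_indices):
    """Hypothetical reference for the documented behaviour.

    a_batch: list of M x K matrices (batch dims of `a` flattened).
    b_batch: list of K x N matrices (batch dims of `b` flattened).
    lhs_indices / rhs_indices: equal-length lists of flat batch indices.
    Returns one M x N product per index pair.
    """
    assert len(lhs_indices) == len(rhs_indices), "index lists must match"
    out = []
    for i, j in zip(lhs_indices, rhs_indices):
        # Gather one matrix from each operand, then multiply the pair.
        out.append(matmul_2d(a_batch[i], b_batch[j]))
    return out


# Two 2 x 2 matrices on the left, three 2 x 3 matrices on the right.
a_batch = [
    [[1, 0], [0, 1]],  # flat index 0: identity
    [[2, 0], [0, 2]],  # flat index 1: 2 * identity
]
b_batch = [
    [[1, 2, 3], [4, 5, 6]],
    [[1, 1, 1], [1, 1, 1]],
    [[0, 0, 0], [0, 0, 1]],
]

# Pairs (1, 0), (0, 2) and (1, 1) are gathered and multiplied.
result = gather_mm_reference(a_batch, b_batch, [1, 0, 1], [0, 2, 1])
```

Each entry of `result` is the product of one gathered pair, so its length follows the index lists rather than the batch sizes of the operands. The fused operation is described as more efficient because it avoids materialising the gathered copies that this sketch builds implicitly.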