Block sparse mm (#1058)

This commit is contained in:
Jagrit Digani
2024-05-02 14:03:58 -07:00
committed by GitHub
parent 17f57df797
commit f390957685
15 changed files with 1323 additions and 75 deletions

View File

@@ -3716,6 +3716,38 @@ void init_ops(nb::module_& m) {
mask_lhs (array, optional): Boolean mask for a (default: ``None``)
mask_rhs (array, optional): Boolean mask for b (default: ``None``)
)pbdoc");
m.def(
"block_sparse_mm",
&block_sparse_mm,
nb::arg(),
nb::arg(),
"lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``
For ``b`` with shape ``(B1, B2, ..., BS, M, K)``,
``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``
Args:
a (array): Input array.
b (array): Input array.
lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
)pbdoc");
m.def(
"diagonal",