Block sparse mm (#1058)

This commit is contained in:
Jagrit Digani
2024-05-02 14:03:58 -07:00
committed by GitHub
parent 17f57df797
commit f390957685
15 changed files with 1323 additions and 75 deletions

View File

@@ -1183,6 +1183,14 @@ array block_masked_mm(
std::optional<array> mask_rhs = std::nullopt,
StreamOrDevice s = {});
/** Compute matrix product with matrix-level gather */
array block_sparse_mm(
array a,
array b,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,