Masked mm (#978)

* Add block masked matmul op and primitive
This commit is contained in:
Jagrit Digani
2024-04-16 14:45:39 -07:00
committed by GitHub
parent 107ba2891a
commit b18468bf81
15 changed files with 1137 additions and 2 deletions

View File

@@ -1185,6 +1185,16 @@ array addmm(
const float& beta = 1.f,
StreamOrDevice s = {});
/** Compute matrix product with block masking */
array block_masked_mm(
array a,
array b,
int block_size,
std::optional<array> mask_out = std::nullopt,
std::optional<array> mask_lhs = std::nullopt,
std::optional<array> mask_rhs = std::nullopt,
StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,