mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
10
mlx/ops.h
10
mlx/ops.h
@@ -1185,6 +1185,16 @@ array addmm(
|
||||
const float& beta = 1.f,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Compute matrix product with block masking */
|
||||
array block_masked_mm(
|
||||
array a,
|
||||
array b,
|
||||
int block_size,
|
||||
std::optional<array> mask_out = std::nullopt,
|
||||
std::optional<array> mask_lhs = std::nullopt,
|
||||
std::optional<array> mask_rhs = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Extract a diagonal or construct a diagonal array */
|
||||
array diagonal(
|
||||
const array& a,
|
||||
|
||||
Reference in New Issue
Block a user