mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
@@ -3645,6 +3645,44 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: ``alpha * (a @ b) + beta * c``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"block_masked_mm",
|
||||
&block_masked_mm,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"block_size"_a = 64,
|
||||
"mask_out"_a = nb::none(),
|
||||
"mask_lhs"_a = nb::none(),
|
||||
"mask_rhs"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array, mask_lhs: array, mask_rhs: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Matrix multiplication with block masking.
|
||||
|
||||
Perform the (possibly batched) matrix multiplication of two arrays and with blocks
|
||||
of size ``block_size x block_size`` optionally masked out.
|
||||
|
||||
Assuming ``a`` with shape (..., `M`, `K`) and b with shape (..., `K`, `N`)
|
||||
|
||||
* ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
|
||||
mask_out (array, optional): Boolean mask for output (default: ``None``)
|
||||
mask_lhs (array, optional): Boolean mask for a (default: ``None``)
|
||||
mask_rhs (array, optional): Boolean mask for b (default: ``None``)
|
||||
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"diagonal",
|
||||
&diagonal,
|
||||
|
Reference in New Issue
Block a user