mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
@@ -3645,6 +3645,44 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: ``alpha * (a @ b) + beta * c``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"block_masked_mm",
|
||||
&block_masked_mm,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"block_size"_a = 64,
|
||||
"mask_out"_a = nb::none(),
|
||||
"mask_lhs"_a = nb::none(),
|
||||
"mask_rhs"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array, mask_lhs: array, mask_rhs: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Matrix multiplication with block masking.
|
||||
|
||||
Perform the (possibly batched) matrix multiplication of two arrays and with blocks
|
||||
of size ``block_size x block_size`` optionally masked out.
|
||||
|
||||
Assuming ``a`` with shape (..., `M`, `K`) and b with shape (..., `K`, `N`)
|
||||
|
||||
* ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
|
||||
mask_out (array, optional): Boolean mask for output (default: ``None``)
|
||||
mask_lhs (array, optional): Boolean mask for a (default: ``None``)
|
||||
mask_rhs (array, optional): Boolean mask for b (default: ``None``)
|
||||
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"diagonal",
|
||||
&diagonal,
|
||||
|
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import unittest
|
||||
@@ -681,6 +681,119 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
mx.eval(c)
|
||||
self.assertEqual(c.shape, (0, 0))
|
||||
|
||||
def test_block_masked_matmul(self):
|
||||
def np_block_masked_mm(
|
||||
a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None
|
||||
):
|
||||
# Get mask adjusted shapes
|
||||
M = a.shape[-2]
|
||||
N = b.shape[-1]
|
||||
K = a.shape[-1]
|
||||
|
||||
# Expand mask dims
|
||||
def expand_mask(mask, block_size, Y, X):
|
||||
mask = np.expand_dims(mask, (-3, -1))
|
||||
mask_shape = list(mask.shape)
|
||||
mask_shape[-1] = block_size
|
||||
x = mask_shape[-2] * block_size
|
||||
mask_shape[-3] = block_size
|
||||
y = mask_shape[-4] * block_size
|
||||
mask = np.broadcast_to(mask, mask_shape)
|
||||
mask_shape = mask_shape[:-4] + [y, x]
|
||||
return mask.reshape(mask_shape)[..., :Y, :X]
|
||||
|
||||
if lhs_mask is not None:
|
||||
lhs_mask = expand_mask(lhs_mask, block_size, M, K)
|
||||
a = lhs_mask * a
|
||||
|
||||
if rhs_mask is not None:
|
||||
rhs_mask = expand_mask(rhs_mask, block_size, K, N)
|
||||
b = rhs_mask * b
|
||||
|
||||
out = a @ b
|
||||
|
||||
if out_mask is not None:
|
||||
out_mask = expand_mask(out_mask, block_size, M, N)
|
||||
out = out * out_mask
|
||||
return out
|
||||
|
||||
def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
|
||||
with self.subTest(
|
||||
M=M,
|
||||
N=N,
|
||||
K=K,
|
||||
block_size=block_size,
|
||||
np_dtype=np_dtype,
|
||||
transpose=transpose,
|
||||
):
|
||||
tm = (M + block_size - 1) // block_size
|
||||
tn = (N + block_size - 1) // block_size
|
||||
tk = (K + block_size - 1) // block_size
|
||||
|
||||
a_np = np.random.normal(size=(M, K)).astype(np_dtype)
|
||||
b_np = np.random.normal(size=(K, N)).astype(np_dtype)
|
||||
|
||||
a_np_mask = np.random.normal(size=(tm, tk)) < 0.0
|
||||
b_np_mask = np.random.normal(size=(tk, tn)) < 0.0
|
||||
out_np_mask = np.random.normal(size=(tm, tn)) < 0.0
|
||||
|
||||
a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
|
||||
mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
|
||||
)
|
||||
|
||||
if transpose:
|
||||
b_np = np.random.normal(size=(N, K)).astype(np_dtype)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
b_np = b_np.T
|
||||
b_mx = b_mx.T
|
||||
|
||||
out_np = np_block_masked_mm(
|
||||
a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
|
||||
)
|
||||
out_mx = mx.block_masked_mm(
|
||||
a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask
|
||||
)
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
|
||||
|
||||
out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask)
|
||||
out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask)
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
|
||||
|
||||
out_np = np_block_masked_mm(
|
||||
a_np, b_np, block_size, None, a_np_mask, b_np_mask
|
||||
)
|
||||
out_mx = mx.block_masked_mm(
|
||||
a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask
|
||||
)
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
|
||||
|
||||
shapes = (
|
||||
(16, 16, 16, 32),
|
||||
(64, 64, 16, 32),
|
||||
(128, 128, 128, 32),
|
||||
(256, 256, 128, 64),
|
||||
)
|
||||
|
||||
for M, N, K, block_size in shapes:
|
||||
test_shape(M, N, K, block_size, transpose=False)
|
||||
test_shape(M, N, K, block_size, transpose=True)
|
||||
|
||||
# Test gemv
|
||||
a_np = np.random.normal(size=(64, 64)).astype(np.float32)
|
||||
b_np = np.random.normal(size=(64,)).astype(np.float32)
|
||||
mask_np = np.array([True, False]).astype(np.bool_)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
mask_mx = mx.array(mask_np)
|
||||
|
||||
c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx)
|
||||
c_np = a_np @ b_np
|
||||
c_np[32:] = 0.0
|
||||
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user