Masked mm (#978)

* Add block masked matmul op and primitive
This commit is contained in:
Jagrit Digani
2024-04-16 14:45:39 -07:00
committed by GitHub
parent 107ba2891a
commit b18468bf81
15 changed files with 1137 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
import unittest
@@ -681,6 +681,119 @@ class TestBlas(mlx_tests.MLXTestCase):
mx.eval(c)
self.assertEqual(c.shape, (0, 0))
def test_block_masked_matmul(self):
def np_block_masked_mm(
a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None
):
# Get mask adjusted shapes
M = a.shape[-2]
N = b.shape[-1]
K = a.shape[-1]
# Expand mask dims
def expand_mask(mask, block_size, Y, X):
mask = np.expand_dims(mask, (-3, -1))
mask_shape = list(mask.shape)
mask_shape[-1] = block_size
x = mask_shape[-2] * block_size
mask_shape[-3] = block_size
y = mask_shape[-4] * block_size
mask = np.broadcast_to(mask, mask_shape)
mask_shape = mask_shape[:-4] + [y, x]
return mask.reshape(mask_shape)[..., :Y, :X]
if lhs_mask is not None:
lhs_mask = expand_mask(lhs_mask, block_size, M, K)
a = lhs_mask * a
if rhs_mask is not None:
rhs_mask = expand_mask(rhs_mask, block_size, K, N)
b = rhs_mask * b
out = a @ b
if out_mask is not None:
out_mask = expand_mask(out_mask, block_size, M, N)
out = out * out_mask
return out
def test_shape(M, N, K, block_size, transpose=False, np_dtype=np.float32):
with self.subTest(
M=M,
N=N,
K=K,
block_size=block_size,
np_dtype=np_dtype,
transpose=transpose,
):
tm = (M + block_size - 1) // block_size
tn = (N + block_size - 1) // block_size
tk = (K + block_size - 1) // block_size
a_np = np.random.normal(size=(M, K)).astype(np_dtype)
b_np = np.random.normal(size=(K, N)).astype(np_dtype)
a_np_mask = np.random.normal(size=(tm, tk)) < 0.0
b_np_mask = np.random.normal(size=(tk, tn)) < 0.0
out_np_mask = np.random.normal(size=(tm, tn)) < 0.0
a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
)
if transpose:
b_np = np.random.normal(size=(N, K)).astype(np_dtype)
b_mx = mx.array(b_np)
b_np = b_np.T
b_mx = b_mx.T
out_np = np_block_masked_mm(
a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
)
out_mx = mx.block_masked_mm(
a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask
)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask)
out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
out_np = np_block_masked_mm(
a_np, b_np, block_size, None, a_np_mask, b_np_mask
)
out_mx = mx.block_masked_mm(
a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask
)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
shapes = (
(16, 16, 16, 32),
(64, 64, 16, 32),
(128, 128, 128, 32),
(256, 256, 128, 64),
)
for M, N, K, block_size in shapes:
test_shape(M, N, K, block_size, transpose=False)
test_shape(M, N, K, block_size, transpose=True)
# Test gemv
a_np = np.random.normal(size=(64, 64)).astype(np.float32)
b_np = np.random.normal(size=(64,)).astype(np.float32)
mask_np = np.array([True, False]).astype(np.bool_)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
mask_mx = mx.array(mask_np)
c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx)
c_np = a_np @ b_np
c_np[32:] = 0.0
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
if __name__ == "__main__":
unittest.main()