Float mask update (#1152)

* Float mask update

* Update CPU impl
Jagrit Digani, 2024-05-23 17:20:44 -07:00, committed by GitHub
parent 50dfb664db
commit eab2685c67
8 changed files with 713 additions and 253 deletions
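For orientation before the diff: this change lets the block masks passed to mx.block_masked_mm be float-valued rather than boolean-only, and the updated tests also check gradients with respect to those masks. In the reference implementation below, each block is multiplied by its mask value, so a float mask acts as a per-block scale factor. A minimal sketch of the float-mask call, with illustrative shapes and values that are not taken from this commit (arguments are positional, in the same order the tests use them: a, b, block_size, out_mask, lhs_mask, rhs_mask):

    import mlx.core as mx

    a = mx.random.normal((64, 64))
    b = mx.random.normal((64, 64))

    # With block_size=32, each mask entry covers one 32x32 block,
    # so the masks for 64x64 operands are 2x2.
    lhs_mask = mx.array([[1.0, 0.0], [0.5, 1.0]])  # 0.0 drops a block, 0.5 scales it

    out = mx.block_masked_mm(a, b, 32, None, lhs_mask, None)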


@@ -682,7 +682,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertEqual(c.shape, (0, 0))

     def test_block_masked_matmul(self):
-        def np_block_masked_mm(
+        def ref_block_masked_mm(
             a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None
         ):
             # Get mask adjusted shapes
@@ -690,33 +690,81 @@ class TestBlas(mlx_tests.MLXTestCase):
             N = b.shape[-1]
             K = a.shape[-1]

+            bsx_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
+
             # Expand mask dims
             def expand_mask(mask, block_size, Y, X):
-                mask = np.expand_dims(mask, (-3, -1))
-                mask_shape = list(mask.shape)
+                mask = mx.expand_dims(mask, (-3, -1))
+                mask_shape = list(bsx_shape) + list(mask.shape[-4:])
                 mask_shape[-1] = block_size
                 x = mask_shape[-2] * block_size
                 mask_shape[-3] = block_size
                 y = mask_shape[-4] * block_size
-                mask = np.broadcast_to(mask, mask_shape)
+                mask = mx.broadcast_to(mask, mask_shape)
                 mask_shape = mask_shape[:-4] + [y, x]
                 return mask.reshape(mask_shape)[..., :Y, :X]

+            a_masked = a
+            b_masked = b
+
             if lhs_mask is not None:
-                lhs_mask = expand_mask(lhs_mask, block_size, M, K)
-                a = lhs_mask * a
+                lhs_mask = expand_mask(lhs_mask, block_size, M, K).astype(mx.float32)
+                a_masked = lhs_mask * a_masked

             if rhs_mask is not None:
-                rhs_mask = expand_mask(rhs_mask, block_size, K, N)
-                b = rhs_mask * b
+                rhs_mask = expand_mask(rhs_mask, block_size, K, N).astype(mx.float32)
+                b_masked = rhs_mask * b_masked

-            out = a @ b
+            out = a_masked @ b_masked

             if out_mask is not None:
-                out_mask = expand_mask(out_mask, block_size, M, N)
+                out_mask = expand_mask(out_mask, block_size, M, N).astype(mx.float32)
                 out = out * out_mask
             return out

+        def run_test(a, b, block_size, out_mask, a_mask, b_mask, cotan):
+            def f_ref(a_, b_):
+                return ref_block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)
+
+            def f_test(a_, b_):
+                return mx.block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)
+
+            out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])
+            out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])
+
+            mx.eval((out_ref, dout_ref, out_test, dout_test))
+            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
+
+        def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan):
+            def f_ref(a_, b_, a_mask_, b_mask_):
+                return ref_block_masked_mm(
+                    a_, b_, block_size, out_mask, a_mask_, b_mask_
+                )
+
+            def f_test(a_, b_, a_mask_, b_mask_):
+                return mx.block_masked_mm(
+                    a_, b_, block_size, out_mask, a_mask_, b_mask_
+                )
+
+            out_ref, dout_ref = mx.vjp(f_ref, [a, b, a_mask, b_mask], [cotan])
+            out_test, dout_test = mx.vjp(f_test, [a, b, a_mask, b_mask], [cotan])
+
+            mx.eval((out_ref, dout_ref, out_test, dout_test))
+
+            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
+            for r, t in zip(dout_ref, dout_test):
+                self.assertEqual(r.shape, t.shape)
+                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
+
+        def make_mask(tm_, tn_, batch, np_dtype):
+            arr_np_mask = np.random.normal(size=batch + (tm_, tn_)).astype(np_dtype)
+            arr_np_bool_mask = arr_np_mask < 0.0
+            arr_np_mask[arr_np_bool_mask] = 0.0
+
+            return mx.array(arr_np_bool_mask), mx.array(arr_np_mask)
+
         def test_shape(
             M,
             N,
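The ref_block_masked_mm helper above leans on expand_mask to upsample a per-block mask to element granularity: each mask entry is broadcast over a block_size x block_size tile, and the result is cropped to the true (Y, X) shape. The expand_dims/broadcast_to/reshape dance is equivalent to a plain repeat; a tiny NumPy-only sketch of the same idea, with hypothetical sizes rather than code from this commit:

    import numpy as np

    block_size = 2
    block_mask = np.array([[1.0, 0.0]])  # one entry per 2x2 block

    # Repeat each block entry across its tile, as expand_mask does
    # via expand_dims + broadcast_to + reshape.
    elem_mask = np.repeat(np.repeat(block_mask, block_size, axis=-2), block_size, axis=-1)
    # elem_mask:
    # [[1. 1. 0. 0.]
    #  [1. 1. 0. 0.]]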
@@ -737,49 +785,49 @@ class TestBlas(mlx_tests.MLXTestCase):
                 batch_A=batch_A,
                 batch_B=batch_B,
             ):
-                tm = (M + block_size - 1) // block_size
-                tn = (N + block_size - 1) // block_size
-                tk = (K + block_size - 1) // block_size
+                batch_out = np.broadcast_shapes(batch_A, batch_B)
+                cotan = mx.ones(batch_out + (M, N))

                 a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
                 b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)

-                batch_out = np.broadcast_shapes(batch_A, batch_B)
+                a_mx = mx.array(a_np)
+                b_mx = mx.array(b_np)

-                a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0
-                b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0
-                out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0
+                tm = (M + block_size - 1) // block_size
+                tn = (N + block_size - 1) // block_size
+                tk = (K + block_size - 1) // block_size

-                a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
-                    mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
-                )
+                a_mx_bool_mask, a_mx_mask = make_mask(tm, tk, batch_A, np_dtype)
+                b_mx_bool_mask, b_mx_mask = make_mask(tk, tn, batch_B, np_dtype)
+                out_mx_bool_mask, out_mx_mask = make_mask(tm, tn, batch_out, np_dtype)

+                # Boolean block masks
+                run_test(
+                    a_mx,
+                    b_mx,
+                    block_size,
+                    out_mx_bool_mask,
+                    a_mx_bool_mask,
+                    b_mx_bool_mask,
+                    cotan,
+                )
+                run_test(a_mx, b_mx, block_size, out_mx_bool_mask, None, None, cotan)
+                run_test(
+                    a_mx, b_mx, block_size, None, a_mx_bool_mask, b_mx_bool_mask, cotan
+                )

-                if transpose:
-                    b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype)
-                    b_mx = mx.array(b_np)
-
-                    b_np = np.swapaxes(b_np, -2, -1)
-                    b_mx = mx.swapaxes(b_mx, -2, -1)
-
-                out_np = np_block_masked_mm(
-                    a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
+                # Float block masks
+                run_test(
+                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan
                 )
-                out_mx = mx.block_masked_mm(
-                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask
+                run_test(a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan)
+                run_test_mask_vjp(
+                    a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan
                 )
-                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
-
-                out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask)
-                out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask)
-                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
-
-                out_np = np_block_masked_mm(
-                    a_np, b_np, block_size, None, a_np_mask, b_np_mask
+                run_test_mask_vjp(
+                    a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan
                 )
-                out_mx = mx.block_masked_mm(
-                    a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask
-                )
-                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

         shapes = (
             (16, 16, 16, 32),
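run_test and run_test_mask_vjp compare the fast op against the reference through mx.vjp, so both the primal outputs and the pulled-back gradients (including gradients with respect to the float masks) must agree. As a reminder of the mx.vjp contract the helpers rely on, a toy example with an illustrative function and values that are not from this commit:

    import mlx.core as mx

    def f(x, y):
        return x * y

    # outs == [6.0]; grads == [3.0, 2.0], since d(x*y)/dx = y and d(x*y)/dy = x
    outs, grads = mx.vjp(f, [mx.array(2.0), mx.array(3.0)], [mx.array(1.0)])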
@@ -789,11 +837,10 @@ class TestBlas(mlx_tests.MLXTestCase):
         )

         for M, N, K, block_size in shapes:
-            test_shape(M, N, K, block_size, transpose=False)
-            test_shape(M, N, K, block_size, transpose=True)
+            test_shape(M, N, K, block_size)

         # Test broadcasting
-        test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2))
+        test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2))

         # Test gemv
         a_np = np.random.normal(size=(64, 64)).astype(np.float32)