No copy gems (#801)

* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides (a sketch of these call patterns follows this list)
* Update tests
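
A minimal, illustrative sketch (not part of this commit) of the call patterns the bullets above describe, assuming mlx.core and numpy are available; the shapes and variable names here are made up for the example:

    import numpy as np
    import mlx.core as mx

    # Batched gemm with a transposed operand: the last two axes still map onto a
    # contiguous layout, so (per the gemm bullet) no copy should be required.
    a = mx.array(np.random.normal(size=(8, 64, 32)).astype(np.float32))
    b = mx.array(np.random.normal(size=(8, 64, 16)).astype(np.float32))
    c = a.transpose(0, 2, 1) @ b              # -> shape (8, 32, 16)

    # addmm with a gemv shape: a 1-D vector times a batched matrix plus a
    # broadcast bias, mirroring the shapes exercised in the updated tests below.
    vec = mx.array(np.random.normal(size=(64,)).astype(np.float32))
    mat = mx.array(np.random.normal(size=(8, 64, 16)).astype(np.float32))
    bias = mx.ones((16,))
    out = mx.addmm(bias, vec, mat, 1.0, 1.0)  # alpha=1.0, beta=1.0 -> shape (8, 16)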
Author: Jagrit Digani
Date: 2024-03-12 13:13:41 -07:00
Committed by: GitHub
Parent: d0c544a868
Commit: 5ad133f8bb
12 changed files with 799 additions and 448 deletions


@@ -393,6 +393,77 @@ class TestBlas(mlx_tests.MLXTestCase):
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
)
def test_matrix_vector_attn(self):
# Multi-query style attention check
for dtype in self.dtypes:
# fmt: off
for (B, D, n_kv_heads, factor, qsl, ksl) in (
(1, 16, 8, 4, 1, 256),
(1, 16, 8, 4, 32, 256),
(1, 16, 8, 4, 256, 1),
(4, 16, 8, 4, 1, 256),
(4, 16, 8, 4, 256, 1),
):
# fmt: on
with self.subTest(
B=B, # Batch size
D=D, # Dimension of mm
n_kv_heads=n_kv_heads, # key-value heads
factor=factor, # factor to get query heads
qsl=qsl, # Query sequence length
ksl=ksl, # Key sequence length
dtype=dtype # Data type
):
np_dtype = getattr(np, dtype)
# Fix shapes for kqv
n_q_heads = n_kv_heads * factor
Dk = D * n_kv_heads
Dq = D * n_q_heads
scale = 1. / math.sqrt(Dk)
shape_queries = (B, qsl, Dq)
shape_keys = (B, ksl, Dk)
shape_values = (B, ksl, Dk)
# Prepare numpy arrays
q_np = np.random.uniform(-scale, scale, size=shape_queries).astype(np_dtype)
k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)
# Rearrange to move heads up
q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
# Do attn style matmul
s_np = q_np_reshape @ k_np_reshape
o_np = s_np @ v_np_reshape
o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
# Test mlx
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
# Rearrange to move heads up
q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
# Do attn style matmul
s_mx = q_mx_reshape @ k_mx_reshape
o_mx = (s_mx @ v_mx_reshape)
o_mx = o_mx.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
# Check against np
self.assertListEqual(list(s_np.shape), list(s_mx.shape))
self.assertTrue(np.allclose(s_np, s_mx, atol=1e-4))
self.assertListEqual(list(o_np.shape), list(o_mx.shape))
self.assertTrue(np.allclose(o_np, o_mx, atol=1e-4))
def test_matrix_vector_edgecases(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
@@ -503,16 +574,29 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
@@ -564,16 +648,12 @@ class TestBlas(mlx_tests.MLXTestCase):
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())