Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-01 04:24:36 +08:00)
No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
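The addmm changes are easiest to see with the gemv-shaped inputs the new tests exercise. Below is a minimal standalone sketch (shapes are taken from the test hunks that follow; the alpha and beta values are illustrative) of mx.addmm with a vector operand and a broadcast bias:

    import numpy as np
    import mlx.core as mx

    alpha, beta = 1.5, 0.5  # illustrative scalars, not from the tests

    # gemv shape: vector (16,) times batched matrix (32, 16, 128) -> (32, 128)
    a = mx.array(np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32))
    b = mx.array(np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32))
    c = mx.ones((32, 128))  # bias; broadcastable shapes like (1,) or (128,) also work

    # addmm computes alpha * (a @ b) + beta * c in one call
    d = mx.addmm(c, a, b, alpha, beta)
    print(d.shape)  # (32, 128)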
@@ -393,6 +393,77 @@ class TestBlas(mlx_tests.MLXTestCase):
                 mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
             )
 
+    def test_matrix_vector_attn(self):
+        # Multi-query style attention check
+        for dtype in self.dtypes:
+            # fmt: off
+            for (B, D, n_kv_heads, factor, qsl, ksl) in (
+                (1, 16, 8, 4, 1, 256),
+                (1, 16, 8, 4, 32, 256),
+                (1, 16, 8, 4, 256, 1),
+                (4, 16, 8, 4, 1, 256),
+                (4, 16, 8, 4, 256, 1),
+            ):
+                # fmt: on
+                with self.subTest(
+                    B=B,  # Batch size
+                    D=D,  # Dimension of mm
+                    n_kv_heads=n_kv_heads,  # key-value heads
+                    factor=factor,  # factor to get query heads
+                    qsl=qsl,  # Query sequence length
+                    ksl=ksl,  # Key sequence length
+                    dtype=dtype,  # Data type
+                ):
+
+                    np_dtype = getattr(np, dtype)
+
+                    # Fix shapes for kqv
+                    n_q_heads = n_kv_heads * factor
+                    Dk = D * n_kv_heads
+                    Dq = D * n_q_heads
+                    scale = 1.0 / math.sqrt(Dk)
+
+                    shape_queries = (B, qsl, Dq)
+                    shape_keys = (B, ksl, Dk)
+                    shape_values = (B, ksl, Dk)
+
+                    # Prepare numpy arrays
+                    q_np = np.random.uniform(-scale, scale, size=shape_queries).astype(np_dtype)
+                    k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
+                    v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)
+
+                    # Rearrange to move heads up
+                    q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
+                    k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
+                    v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
+
+                    # Do attn style matmul
+                    s_np = q_np_reshape @ k_np_reshape
+                    o_np = s_np @ v_np_reshape
+                    o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
+
+                    # Test mlx
+                    q_mx = mx.array(q_np)
+                    k_mx = mx.array(k_np)
+                    v_mx = mx.array(v_np)
+
+                    # Rearrange to move heads up
+                    q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
+                    k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
+                    v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
+
+                    # Do attn style matmul
+                    s_mx = q_mx_reshape @ k_mx_reshape
+                    o_mx = s_mx @ v_mx_reshape
+                    o_mx = o_mx.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
+
+                    # Check against np
+                    self.assertListEqual(list(s_np.shape), list(s_mx.shape))
+                    self.assertTrue(np.allclose(s_np, s_mx, atol=1e-4))
+
+                    self.assertListEqual(list(o_np.shape), list(o_mx.shape))
+                    self.assertTrue(np.allclose(o_np, o_mx, atol=1e-4))
+
     def test_matrix_vector_edgecases(self):
         for dtype in self.dtypes:
             with self.subTest(dtype=dtype):
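The hunk above drives batched matmuls in which each key-value head is broadcast across `factor` query heads, which is exactly the irregular-batch-stride case the commit targets. A NumPy-only sketch of the reshape pattern, with shapes taken from one of the test cases:

    import numpy as np

    B, D, n_kv_heads, factor, qsl, ksl = 1, 16, 8, 4, 32, 256
    q = np.zeros((B, qsl, D * n_kv_heads * factor))
    k = np.zeros((B, ksl, D * n_kv_heads))

    # Queries become (B, n_kv_heads, factor, qsl, D); keys (B, n_kv_heads, 1, D, ksl)
    qr = q.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
    kr = k.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)

    # The size-1 axis of kr broadcasts one kv head over `factor` query heads,
    # so the batched matmul sees repeated batch entries rather than copies
    print((qr @ kr).shape)  # (1, 8, 4, 32, 256)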
@@ -503,16 +574,29 @@ class TestBlas(mlx_tests.MLXTestCase):
                 self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
                 self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 
+            # Matmul with vector
+            a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+            b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
+            a_mlx = mx.array(a_npy)
+            b_mlx = mx.array(b_npy)
+
+            for c_shape in ((1,), (128,), (32, 128)):
+                c_npy = np.ones(c_shape).astype(np.float32)
+                c_mlx = mx.array(c_npy)
+
+                d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+                d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+                self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+                self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
             # Matmul with vector
             a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
             b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
             a_mlx = mx.array(a_npy)
             b_mlx = mx.array(b_npy)
 
-            for c_shape in (
-                (1,),
-                (32, 128),
-            ):
+            for c_shape in ((1,), (32, 128)):
                 c_npy = np.ones(c_shape).astype(np.float32)
                 c_mlx = mx.array(c_npy)
 
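Each c_shape in the hunk above relies on the bias broadcasting against the (32, 128) matmul result. A quick illustrative check of that broadcasting, using the same shapes:

    import numpy as np

    prod = np.zeros((32, 128))  # stands in for a_npy @ b_npy
    for c_shape in ((1,), (128,), (32, 128)):
        print((prod + np.ones(c_shape)).shape)  # (32, 128) each time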
@@ -564,16 +648,12 @@ class TestBlas(mlx_tests.MLXTestCase):
         out_ref, dout_ref = mx.vjp(
             f_ref,
             [c, a, b],
-            [
-                cotan,
-            ],
+            [cotan],
         )
         out_test, dout_test = mx.vjp(
             f_test,
             [c, a, b],
-            [
-                cotan,
-            ],
+            [cotan],
         )
 
         self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
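The hunk above compares gradients of a reference implementation against the addmm path via mx.vjp. A minimal sketch of that call pattern (the shapes here are hypothetical, not the test's own):

    import mlx.core as mx

    a = mx.random.normal((8, 16))
    b = mx.random.normal((16, 4))
    c = mx.zeros((8, 4))
    cotan = mx.ones((8, 4))  # cotangent for the single output

    f = lambda c, a, b: mx.addmm(c, a, b, 1.0, 1.0)

    # mx.vjp returns the outputs and one pulled-back cotangent per primal
    out, (dc, da, db) = mx.vjp(f, [c, a, b], [cotan])
    print(out[0].shape, dc.shape, da.shape, db.shape)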