Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-19 00:04:41 +08:00
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences
  - Updated fast attention: GEMM-ified with Steel primitives
  - Uses flash attention 1 for scale correction
* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
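The "flash attention 1 for scale correction" mentioned above refers to the online-softmax trick: attention is computed one key/value block at a time, and previously accumulated partial sums are rescaled whenever the running row maximum changes. A minimal NumPy sketch of the idea follows; the function name, block size, and single-head layout are illustrative assumptions, not the kernel's actual code.

import numpy as np

def flash_attention_v1_sketch(q, k, v, scale, block_size=64):
    # q: (L, D); k, v: (S, D). Single head, no mask.
    L, D = q.shape
    S = k.shape[0]
    o = np.zeros((L, D), dtype=np.float32)     # running unnormalized output
    m = np.full(L, -np.inf, dtype=np.float32)  # running row max of the scores
    denom = np.zeros(L, dtype=np.float32)      # running softmax denominator

    for start in range(0, S, block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = scale * (q @ kb.T)                 # (L, block) score tile
        m_new = np.maximum(m, s.max(axis=-1))  # updated row max
        corr = np.exp(m - m_new)               # scale correction for old sums
        p = np.exp(s - m_new[:, None])         # unnormalized tile probabilities
        denom = denom * corr + p.sum(axis=-1)
        o = o * corr[:, None] + p @ vb
        m = m_new

    return o / denom[:, None]

# Sanity check against the direct softmax computation:
rng = np.random.default_rng(0)
q = rng.normal(size=(20, 64)).astype(np.float32)
k = rng.normal(size=(100, 64)).astype(np.float32)
v = rng.normal(size=(100, 64)).astype(np.float32)
scale = 1.0 / np.sqrt(64)
s = scale * (q @ k.T)
w = np.exp(s - s.max(axis=-1, keepdims=True))
direct = (w / w.sum(axis=-1, keepdims=True)) @ v
assert np.allclose(flash_attention_v1_sketch(q, k, v, scale), direct, atol=1e-5)

Because only one score tile is materialized at a time, memory use stays independent of the full L x S score matrix, which is the point of the memory-efficient kernels.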
@@ -135,7 +135,6 @@ void init_fast(nb::module_& parent_module) {
        v (array): Input values array.
        scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``).
        mask (array, optional): An additive mask to apply to the query-key scores.

    Returns:
        array: The output array.
  )pbdoc");
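As a quick illustration of the documented signature, a call might look like the sketch below (assumes MLX is installed and a Metal device is available; the shapes are arbitrary):

import math
import mlx.core as mx

# (batch, num_heads, sequence_length, head_dim)
q = mx.random.normal((1, 8, 128, 64))
k = mx.random.normal((1, 8, 128, 64))
v = mx.random.normal((1, 8, 128, 64))
scale = 1.0 / math.sqrt(q.shape[-1])  # the typical scale from the docstring

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
mx.eval(out)  # force lazy evaluation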
@@ -32,9 +32,80 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
    return mlx_primitives_sdpa(q, k, v, scale)


class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
    def test_fast_sdpa(self):
        # Not yet supported:
        # * K pre-transposed in kernel, V pre-transposed in kernel
        np.random.seed(0)
        R = 20
        L = R
        Dk = 64
        H = 3
        scale = float(1.0 / np.sqrt(Dk))
        q_npy = np.random.normal(0.0, 1.0, (1, H, R, Dk)).astype(np.float32)
        k_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)
        v_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)

        q_mlx = mx.array(q_npy)
        k_mlx = mx.array(k_npy)
        v_mlx = mx.array(v_npy)

        reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)

        o_mlx = mx.fast.scaled_dot_product_attention(
            q_mlx, k_mlx, v_mlx, scale=scale, mask=None
        )

        self.assertListEqual(list(reference.shape), list(o_mlx.shape))
        self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))

        dtypes = [np.float32]

        Dk = 64

        if self.is_apple_silicon:
            dtypes.append(np.half)

        for SEQUENCE_LENGTH in [63, 129, 400]:
            for DTYPE in dtypes:
                B = 2
                H = 24
                n_kv_heads = H
                q_npy = np.random.normal(0.0, 1.0, (B, H, SEQUENCE_LENGTH, Dk)).astype(
                    DTYPE
                )
                k_npy = np.random.normal(
                    0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
                ).astype(DTYPE)
                v_npy = np.random.normal(
                    0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
                ).astype(DTYPE)

                q_mlx = mx.array(q_npy)
                k_mlx = mx.array(k_npy)
                v_mlx = mx.array(v_npy)

                reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
                o_mlx = mx.fast.scaled_dot_product_attention(
                    q_mlx, k_mlx, v_mlx, scale=scale
                )

                self.assertListEqual(list(reference.shape), list(o_mlx.shape))
                rtol = 1e-3
                atol = 1e-2

                if SEQUENCE_LENGTH > 500:
                    rtol = 1e-2

                if DTYPE == np.half:
                    rtol = 1e-2

                self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))


class TestFastSDPA(mlx_tests.MLXTestCase):
    def test_fast_sdpa(self):
        # Not yet supported:
        # * K pre-transposed in kernel, V pre-transposed in kernel
        np.random.seed(0)
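For context, `mlx_primitives_sdpa` and `mlx_primitives_sdpa_with_gqa` are pure-MLX reference implementations defined earlier in the test file (outside this hunk). A plausible sketch of such helpers is shown below; the exact bodies in the file may differ, and the repeat-based GQA head expansion is an assumption.

import mlx.core as mx

def mlx_primitives_sdpa(q, k, v, scale):
    # Plain reference attention built from MLX primitives.
    p = (q * scale) @ k.transpose(0, 1, 3, 2)
    return mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype) @ v

def mlx_primitives_sdpa_with_gqa_sketch(q, k, v, scale):
    # Grouped-query attention reference: repeat each K/V head so the
    # head counts match, then fall back to the plain reference. In the
    # test above n_kv_heads == H, so the repeat count is 1.
    n_repeats = q.shape[1] // k.shape[1]
    if n_repeats > 1:
        k = mx.repeat(k, n_repeats, axis=1)
        v = mx.repeat(v, n_repeats, axis=1)
    return mlx_primitives_sdpa(q, k, v, scale)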