Metal shaders for memory efficient self attention on large sequences (#964)

* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives.
Uses flash attention 1 for scale correction (see the sketch after this list).

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
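
For orientation, the "scale correction" mentioned above refers to the flash-attention-1 style online softmax: K/V are processed in blocks, and previously accumulated partial outputs are rescaled whenever the running row maximum changes, so the full score matrix is never materialized. Below is a minimal NumPy sketch of that accumulation; the block size and names are illustrative and this is not the Metal kernel from this commit.

```python
import numpy as np

def blockwise_sdpa(q, k, v, scale, block=64):
    """Flash-attention-1 style accumulation over K/V blocks (illustrative only).

    Keeps a running row maximum (m), running denominator (l) and a rescaled
    partial output (o), so the full L x L score matrix is never materialized.
    """
    o = np.zeros((q.shape[0], v.shape[1]), dtype=np.float32)
    m = np.full((q.shape[0], 1), -np.inf, dtype=np.float32)
    l = np.zeros((q.shape[0], 1), dtype=np.float32)
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                       # scores for this block
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        corr = np.exp(m - m_new)                     # rescale earlier partials
        p = np.exp(s - m_new)                        # block softmax numerator
        l = l * corr + p.sum(axis=-1, keepdims=True)
        o = o * corr + p @ vb
        m = m_new
    return o / l
```

Up to floating-point error, the result matches the unfused `softmax(q @ k.T * scale) @ v`.
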
Brian Keene authored on 2024-06-03 12:16:19 -04:00, committed by GitHub
parent 3576b547c5
commit 1865299a30
7 changed files with 1244 additions and 9 deletions


@@ -135,7 +135,6 @@ void init_fast(nb::module_& parent_module) {
v (array): Input values array.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
mask (array, optional): An additive mask to apply to the query-key scores.
Returns:
array: The output array.
)pbdoc");
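
As a quick orientation for the binding documented above, a usage sketch could look like the following; the shapes and the causal-mask construction here are illustrative and not taken from this commit.

```python
import math
import mlx.core as mx

B, H, L, D = 1, 8, 128, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

# Additive causal mask: 0 where attention is allowed, -inf above the diagonal.
mask = mx.triu(mx.full((L, L), float("-inf")), k=1)

out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(D), mask=mask
)
```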


@@ -32,9 +32,80 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
return mlx_primitives_sdpa(q, k, v, scale)
class TestFastSDPA(mlx_tests.MLXTestCase):
class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
def test_fast_sdpa(self):
# Not yet supported:
# * K pre-transposed in kernel, V pre-transposed in kernel
np.random.seed(0)
R = 20
L = R
Dk = 64
H = 3
scale = float(1.0 / np.sqrt(Dk))
q_npy = np.random.normal(0.0, 1.0, (1, H, R, Dk)).astype(np.float32)
k_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)
v_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)
q_mlx = mx.array(q_npy)
k_mlx = mx.array(k_npy)
v_mlx = mx.array(v_npy)
reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)
o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale, mask=None
)
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))
dtypes = [np.float32]
Dk = 64
if self.is_apple_silicon:
dtypes.append(np.half)
for SEQUENCE_LENGTH in [63, 129, 400]:
for DTYPE in dtypes:
B = 2
H = 24
n_kv_heads = H
q_npy = np.random.normal(0.0, 1.0, (B, H, SEQUENCE_LENGTH, Dk)).astype(
DTYPE
)
k_npy = np.random.normal(
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
).astype(DTYPE)
v_npy = np.random.normal(
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
).astype(DTYPE)
q_mlx = mx.array(q_npy)
k_mlx = mx.array(k_npy)
v_mlx = mx.array(v_npy)
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale
)
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
rtol = 1e-3
atol = 1e-2
if SEQUENCE_LENGTH > 500:
rtol = 1e-2
if DTYPE == np.half:
rtol = 1e-2
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
class TestFastSDPA(mlx_tests.MLXTestCase):
def test_fast_sdpa(self):
# Not yet supported:
# * K pre-transposed in kernel, V pre-transposed in kernel
np.random.seed(0)
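
For context, `mlx_primitives_sdpa` (defined earlier in this test file) is the unfused reference that the fused kernel is checked against. In essence it computes the following; this is a simplified sketch and omits any dtype upcasting or GQA head handling the real helper may perform.

```python
import mlx.core as mx

def sdpa_reference(q, k, v, scale):
    # Unfused scaled dot-product attention built from basic primitives,
    # used only as a numerical reference for mx.fast.scaled_dot_product_attention.
    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    return mx.softmax(scores, axis=-1) @ v
```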