use sdpa and exportable functions in transformer multi head attention (#1760)

Awni Hannun 2025-01-09 13:11:55 -08:00 committed by GitHub
parent c7b0300af5
commit 657f466402
2 changed files with 14 additions and 14 deletions


@@ -82,21 +82,15 @@ class MultiHeadAttention(Module):
         values = self.value_proj(values)
 
         num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
+        queries = mx.unflatten(queries, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        keys = mx.unflatten(keys, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
+        values = mx.unflatten(values, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
+        return self.out_proj(output)
 
     @staticmethod
     def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
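For reference, a minimal sketch (not part of the commit; shapes and tolerance are chosen here for illustration) checking that the fused mx.fast.scaled_dot_product_attention call matches the manual scores/softmax path this hunk removes:

import math

import mlx.core as mx

# Illustrative shapes: batch, sequence, heads, per-head dim
B, L, num_heads, head_dim = 2, 5, 4, 8
q = mx.random.normal(shape=(B, L, num_heads * head_dim))
k = mx.random.normal(shape=(B, L, num_heads * head_dim))
v = mx.random.normal(shape=(B, L, num_heads * head_dim))

# Split heads: [batch x num heads x sequence x head dim]
qh = mx.unflatten(q, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
kh = mx.unflatten(k, -1, (num_heads, -1)).transpose(0, 2, 1, 3)
vh = mx.unflatten(v, -1, (num_heads, -1)).transpose(0, 2, 1, 3)

scale = math.sqrt(1 / qh.shape[-1])

# Removed manual path: explicit scores, softmax, weighted sum
scores = (qh * scale) @ kh.transpose(0, 1, 3, 2)
manual = mx.softmax(scores, axis=-1) @ vh

# New fused path
fused = mx.fast.scaled_dot_product_attention(qh, kh, vh, scale=scale, mask=None)

print(mx.allclose(manual, fused, atol=1e-5).item())  # expected: True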


@@ -1835,6 +1835,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertFalse(mx.any(mx.isnan(mask)))
         self.assertTrue(mask[0, -1].item() < 0)
 
+    def test_attention(self):
+        attn = nn.MultiHeadAttention(32, 4)
+        x = mx.random.normal(shape=(2, 5, 32))
+        out = attn(x, x, x)
+        self.assertEqual(out.shape, x.shape)
+
 
 if __name__ == "__main__":
     unittest.main()
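Beyond the shape check in the new test, a short usage sketch for the updated layer (again not part of the commit; shapes are illustrative), combining it with the additive causal mask helper kept in the class:

import mlx.core as mx
import mlx.nn as nn

attn = nn.MultiHeadAttention(dims=32, num_heads=4)
x = mx.random.normal(shape=(2, 5, 32))

# Causal self-attention: the (L, L) additive mask broadcasts over batch and heads
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
out = attn(x, x, x, mask=mask)
print(out.shape)  # (2, 5, 32)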