[mlx-lm] Use sdpa in llama / mistral model (#515)

* use sdpa

* update a few more models

* version

* fix stablelm type
Authored by Awni Hannun on 2024-03-07 17:41:23 -08:00; committed by GitHub
parent 7cdd1b69ac
commit 8b05bb6d18
7 changed files with 25 additions and 59 deletions
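
The change is the same pattern in each of the model files below: the hand-rolled attention (explicit scores, mask addition, and softmax, plus an mx.repeat expansion of grouped key/value heads) is replaced with a single mx.fast.scaled_dot_product_attention call, which accepts keys and values with fewer heads than the queries. A minimal sketch of the equivalence, using made-up shapes rather than anything taken from these models:

import mlx.core as mx

# Made-up shapes: 8 query heads sharing 2 key/value heads (grouped-query attention).
B, n_heads, n_kv_heads, L, D = 1, 8, 2, 16, 64
scale = D**-0.5
queries = mx.random.normal((B, n_heads, L, D))
keys = mx.random.normal((B, n_kv_heads, L, D))
values = mx.random.normal((B, n_kv_heads, L, D))

# Old path: expand K/V to the full head count, then compute attention by hand.
repeats = n_heads // n_kv_heads
k_full = mx.repeat(keys, repeats, axis=1)
v_full = mx.repeat(values, repeats, axis=1)
scores = (queries * scale) @ k_full.transpose(0, 1, 3, 2)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(queries.dtype)
reference = scores @ v_full

# New path: one fused call; the grouped K/V heads are handled without mx.repeat.
output = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=scale, mask=None
)
print(mx.allclose(reference, output, atol=1e-4))  # expect True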


@@ -49,8 +49,6 @@ class Attention(nn.Module):
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.head_dim = head_dim = args.head_dim
 
-        self.repeats = n_heads // n_kv_heads
-
         self.scale = head_dim**-0.5
 
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
@@ -79,10 +77,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -93,11 +87,11 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)


@@ -43,8 +43,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        self.repeats = n_heads // n_kv_heads
-
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
@@ -80,10 +78,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -94,11 +88,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)


@@ -44,8 +44,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        self.repeats = n_heads // n_kv_heads
-
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
@@ -81,10 +79,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -95,11 +89,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)


@@ -32,7 +32,6 @@ class Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
-        self.repeats = self.num_heads // self.num_key_value_heads
         self.rope_theta = config.rope_theta
         self.partial_rotary_factor = config.partial_rotary_factor
@@ -82,10 +81,6 @@ class Attention(nn.Module):
             B, L, self.num_key_value_heads, self.head_dim
         ).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
             key_cache, value_cache = cache
@@ -102,14 +97,11 @@ class Attention(nn.Module):
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores = scores + mask
-
-        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.o_proj(values_hat), (keys, values)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        ).astype(values.dtype)
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output), (keys, values)
 
 
 class MLP(nn.Module):
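
One detail specific to this file: the scale is written as math.sqrt(1 / head_dim), which is the same 1/sqrt(d) factor the other models spell as head_dim**-0.5, and the trailing .astype(values.dtype) keeps the previous behaviour of returning the output in the values' dtype. A small check of the scale equivalence, with an arbitrarily chosen head size:

import math

head_dim = 64  # arbitrary example value, not taken from the diff
assert math.isclose(math.sqrt(1 / head_dim), head_dim**-0.5)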


@@ -31,8 +31,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        self.repeats = self.n_heads // self.n_kv_heads
-
         head_dim = args.hidden_size // args.num_attention_heads
         self.scale = head_dim**-0.5
@@ -57,10 +55,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -71,11 +65,11 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)


@@ -1,4 +1,4 @@
-mlx>=0.4
+mlx>=0.6
 numpy
 transformers>=4.38.0
 protobuf
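
The mlx floor moves from 0.4 to 0.6, presumably so that mx.fast.scaled_dot_product_attention is available at runtime. A quick, hypothetical environment check along those lines (it assumes mlx.core exposes __version__; it is not part of this change):

import mlx.core as mx

# Hypothetical sanity check for the new minimum version and the fused kernel.
major, minor = (int(p) for p in mx.__version__.split(".")[:2])
assert (major, minor) >= (0, 6), f"mlx {mx.__version__} is too old; need >= 0.6"
assert hasattr(mx.fast, "scaled_dot_product_attention")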


@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.0.14"
+__version__ = "0.1.0"