From 8b05bb6d18e05104d43305e6dfe8f86eb718ec4d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 7 Mar 2024 17:41:23 -0800
Subject: [PATCH] [mlx-lm] Use sdpa in llama / mistral model (#515)

* use sdpa

* update a few more models

* version

* fix stablelm type
---
 llms/mlx_lm/models/gemma.py      | 16 +++++-----------
 llms/mlx_lm/models/llama.py      | 15 ++++-----------
 llms/mlx_lm/models/qwen2.py      | 15 ++++-----------
 llms/mlx_lm/models/stablelm.py   | 18 +++++-------------
 llms/mlx_lm/models/starcoder2.py | 16 +++++-----------
 llms/mlx_lm/requirements.txt     |  2 +-
 llms/mlx_lm/version.py           |  2 +-
 7 files changed, 25 insertions(+), 59 deletions(-)

diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index 2a99c3c9..6b35f257 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -49,8 +49,6 @@ class Attention(nn.Module):
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.head_dim = head_dim = args.head_dim
 
-        self.repeats = n_heads // n_kv_heads
-
         self.scale = head_dim**-0.5
 
         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
@@ -79,10 +77,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -93,11 +87,11 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
 
 
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 66105896..817e5cc1 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -43,8 +43,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        self.repeats = n_heads // n_kv_heads
-
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
 
@@ -80,10 +78,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -94,11 +88,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
 
 
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 7cb5e106..3773cbfb 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -44,8 +44,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        self.repeats = n_heads // n_kv_heads
-
         head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
 
@@ -81,10 +79,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -95,11 +89,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
 
 
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index f03051bc..73b01961 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -32,7 +32,6 @@ class Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
-        self.repeats = self.num_heads // self.num_key_value_heads
         self.rope_theta = config.rope_theta
         self.partial_rotary_factor = config.partial_rotary_factor
 
@@ -82,10 +81,6 @@ class Attention(nn.Module):
             B, L, self.num_key_value_heads, self.head_dim
         ).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
             key_cache, value_cache = cache
@@ -102,14 +97,11 @@ class Attention(nn.Module):
 
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores = scores + mask
-
-        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.o_proj(values_hat), (keys, values)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        ).astype(values.dtype)
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output), (keys, values)
 
 
 class MLP(nn.Module):
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 582bdbcb..aefcf88c 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -31,8 +31,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        self.repeats = self.n_heads // self.n_kv_heads
-
         head_dim = args.hidden_size // args.num_attention_heads
         self.scale = head_dim**-0.5
 
@@ -57,10 +55,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -71,11 +65,11 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
 
 
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 049049e7..518871ef 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.4
+mlx>=0.6
 numpy
 transformers>=4.38.0
 protobuf
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index b5f92ab1..87ee07a7 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.0.14"
+__version__ = "0.1.0"
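
Note: below is a minimal, self-contained sketch (not part of the applied patch) of the attention pattern the updated models converge on. The batch size, sequence length, head counts, and random inputs are illustrative assumptions; only the fused scaled_dot_product_attention call, the output reshaping, and the mlx>=0.6 requirement come from the diff above.

# Sketch of the fused-SDPA call pattern used after this patch (requires mlx>=0.6).
# All dimensions and the random tensors are placeholders for illustration.
import mlx.core as mx

B, L, n_heads, n_kv_heads, head_dim = 1, 8, 32, 8, 128
scale = head_dim**-0.5

queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_kv_heads, L, head_dim))
values = mx.random.normal((B, n_kv_heads, L, head_dim))

# The fused kernel accepts n_kv_heads != n_heads directly, which is why the
# patch drops the explicit mx.repeat over the key/value head axis.
output = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=scale, mask=None
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)  # (B, L, n_heads * head_dim)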