diff --git a/llms/gguf_llm/models.py b/llms/gguf_llm/models.py
index 45976f33..e60b60d5 100644
--- a/llms/gguf_llm/models.py
+++ b/llms/gguf_llm/models.py
@@ -107,12 +107,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py
index 9b9a602a..39456d97 100644
--- a/llms/mistral/mistral.py
+++ b/llms/mistral/mistral.py
@@ -73,11 +73,8 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
-        keys, values = map(repeat, (keys, values))
+        keys = mx.repeat(keys, self.repeats, axis=1)
+        values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index b1f14706..8a884817 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -93,11 +93,8 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
-        keys, values = map(repeat, (keys, values))
+        keys = mx.repeat(keys, self.repeats, axis=1)
+        values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index f44a94e7..f9f96525 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -93,12 +93,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index fbd4c7a3..5b4875eb 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -95,12 +95,9 @@ class MixtralAttention(nn.Module):
             0, 2, 1, 3
         )
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.num_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 93bba876..ce8c226d 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -86,12 +86,9 @@ class PhiAttention(nn.Module):
             B, L, self.num_key_value_heads, self.head_dim
         ).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.num_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 42abe688..f3f868ad 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -93,12 +93,9 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         if cache is not None:
             key_cache, value_cache = cache
diff --git a/llms/mlx_lm/models/stablelm_epoch.py b/llms/mlx_lm/models/stablelm_epoch.py
index a0fe0d30..2d492295 100644
--- a/llms/mlx_lm/models/stablelm_epoch.py
+++ b/llms/mlx_lm/models/stablelm_epoch.py
@@ -87,12 +87,9 @@ class Attention(nn.Module):
             B, L, self.num_key_value_heads, self.head_dim
         ).transpose(0, 2, 1, 3)
 
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.num_heads, L, -1])
-
         if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
+            keys = mx.repeat(keys, self.repeats, axis=1)
+            values = mx.repeat(values, self.repeats, axis=1)
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
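
The change is mechanical in every file: the hand-rolled expand_dims/concatenate/reshape helper is replaced by a single mx.repeat call along the key/value head axis. As a quick sanity check (not part of the diff itself), the standalone snippet below compares the two formulations on a toy grouped-query-attention shape; the dimensions and variable names here are made up for illustration.

import mlx.core as mx

# Toy grouped-query-attention shapes (hypothetical values, for illustration only).
B, n_kv_heads, L, D = 2, 4, 5, 8
repeats = 3
n_heads = n_kv_heads * repeats

keys = mx.random.normal((B, n_kv_heads, L, D))

# Removed helper: insert a repeat axis, tile it, then fold it into the head axis.
old = mx.concatenate([mx.expand_dims(keys, 2)] * repeats, axis=2).reshape(
    [B, n_heads, L, -1]
)

# Replacement from the diff: repeat each KV head `repeats` times along axis 1.
new = mx.repeat(keys, repeats, axis=1)

assert mx.array_equal(old, new)

Both paths place the copies of each KV head consecutively along the head dimension, so the attention scores computed against the expanded keys and values are unchanged.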