diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 6f72dd6e..438278e5 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -191,7 +191,7 @@ class Attention(nn.Module):
             keys = self.rope(keys)
 
         output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, cache=cache, scale=self.scale, mask=mask
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
 
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 468ffb43..fac59d78 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -90,7 +90,7 @@ class Attention(nn.Module):
             keys = self.rope(keys)
 
         output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, cache=cache, scale=self.scale, mask=mask
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
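
Context for the fix: Python rejects a repeated keyword argument at compile time, so the duplicated cache=cache in the removed lines is a hard SyntaxError rather than a silent override. A minimal sketch (using a hypothetical call string, not the mlx_lm API) that reproduces the error:

    # Compiling a call expression with a repeated keyword raises SyntaxError.
    snippet = "f(q, k, v, cache=cache, cache=cache, scale=s, mask=m)"
    try:
        compile(snippet, "<attention-call>", "eval")
    except SyntaxError as err:
        print(err.msg)  # "keyword argument repeated" (exact wording varies by Python version)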