From 1d53354b510194f189fff21f9470c0b1ae6f5ec2 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 31 Oct 2024 12:06:34 -0700 Subject: [PATCH] fix sed --- llms/mlx_lm/models/llama.py | 2 +- llms/mlx_lm/models/qwen2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 6f72dd6e..438278e5 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -191,7 +191,7 @@ class Attention(nn.Module): keys = self.rope(keys) output = scaled_dot_product_attention( - queries, keys, values, cache=cache, cache=cache, scale=self.scale, mask=mask + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 468ffb43..fac59d78 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -90,7 +90,7 @@ class Attention(nn.Module): keys = self.rope(keys) output = scaled_dot_product_attention( - queries, keys, values, cache=cache, cache=cache, scale=self.scale, mask=mask + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output)