From bbfcc103d7b30c0ea9f1a1b6cdf6162d970c40f8 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 24 Mar 2024 19:34:51 -0700
Subject: [PATCH] cast around lora adapters (#613)

---
 llms/mlx_lm/tuner/lora.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py
index 2ad0656a..76894509 100644
--- a/llms/mlx_lm/tuner/lora.py
+++ b/llms/mlx_lm/tuner/lora.py
@@ -97,9 +97,6 @@ class LoRALinear(nn.Module):
         self.lora_b = mx.zeros(shape=(r, output_dims))
 
     def __call__(self, x):
-        dtype = self.linear.weight.dtype
-        if isinstance(self.linear, nn.QuantizedLinear):
-            dtype = self.linear.scales.dtype
-        y = self.linear(x.astype(dtype))
+        y = self.linear(x)
         z = (self.dropout(x) @ self.lora_a) @ self.lora_b
-        return y + self.scale * z
+        return y + (self.scale * z).astype(x.dtype)
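
Note (not part of the patch): the sketch below illustrates the dtype flow after
this change, assuming the base linear layer runs in float16 while the adapter
matrices lora_a / lora_b stay in float32. The class name LoRALinearSketch, its
constructor arguments, and the usage lines are illustrative, not the
repository's actual API.

    import mlx.core as mx
    import mlx.nn as nn


    class LoRALinearSketch(nn.Module):
        def __init__(self, input_dims, output_dims, r=8, scale=20.0):
            super().__init__()
            # Base projection; in mlx_lm this may also be an nn.QuantizedLinear.
            self.linear = nn.Linear(input_dims, output_dims, bias=False)
            self.dropout = nn.Dropout(p=0.0)
            self.scale = scale
            # Adapter weights default to float32, even when the base layer is
            # half precision or quantized.
            bound = 1 / (input_dims**0.5)
            self.lora_a = mx.random.uniform(
                low=-bound, high=bound, shape=(input_dims, r)
            )
            self.lora_b = mx.zeros(shape=(r, output_dims))

        def __call__(self, x):
            # Base path runs in the base layer's own dtype; no input cast needed.
            y = self.linear(x)
            # Adapter path promotes to float32 because lora_a / lora_b are float32.
            z = (self.dropout(x) @ self.lora_a) @ self.lora_b
            # Cast only the adapter contribution back so the output matches x.dtype.
            return y + (self.scale * z).astype(x.dtype)


    layer = LoRALinearSketch(16, 16)
    layer.linear.weight = layer.linear.weight.astype(mx.float16)  # half-precision base
    x = mx.random.normal((2, 16)).astype(mx.float16)
    print(layer(x).dtype)  # float16: adapter math ran in float32, result cast back

Before this patch, the input was cast to the base weight (or quantization
scales) dtype and the float32 adapter term promoted the sum, so the layer could
return float32 outputs from a float16 model; casting the scaled adapter output
to x.dtype keeps the residual add and the layer's output in the model's compute
dtype.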