From 8c8f9d6440bc2ab0c4707dfc191019bc9394d82e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 15 Dec 2023 10:42:18 -0800
Subject: [PATCH] keep base weights in fp16

---
 lora/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lora/models.py b/lora/models.py
index b3f18113..e0bfa9f9 100644
--- a/lora/models.py
+++ b/lora/models.py
@@ -47,7 +47,7 @@ class LoRALinear(nn.Module):
         self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
 
     def __call__(self, x):
-        y = self.linear(x.astype(self.linear.weight.dtype)).astype(x.dtype)
+        y = self.linear(x.astype(self.linear.weight.dtype))
         z = (x @ self.lora_a) @ self.lora_b
         return y + 2.0 * z
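
For reference, a minimal sketch of the LoRALinear module as it reads after this
patch. The class name, the lora_a/lora_b attribute names, the 2.0 scale, and the
lora_b initialization come from the hunk context above; the rest of the __init__
body and the usage lines are illustrative assumptions, not part of the patch.

    import math
    import mlx.core as mx
    import mlx.nn as nn

    class LoRALinear(nn.Module):
        def __init__(self, input_dims, output_dims, lora_rank=8, bias=False):
            super().__init__()
            # Base projection; with this patch its output stays in the
            # weight dtype (e.g. fp16) instead of being cast back to x.dtype.
            self.linear = nn.Linear(input_dims, output_dims, bias=bias)
            # Trainable low-rank factors (initialization here is assumed,
            # except for lora_b which matches the hunk context).
            scale = 1 / math.sqrt(input_dims)
            self.lora_a = mx.random.uniform(
                low=-scale, high=scale, shape=(input_dims, lora_rank)
            )
            self.lora_b = mx.zeros(shape=(lora_rank, output_dims))

        def __call__(self, x):
            # Cast the input to the base weight dtype; per this patch the
            # result is no longer cast back to x.dtype.
            y = self.linear(x.astype(self.linear.weight.dtype))
            # Low-rank adapter update.
            z = (x @ self.lora_a) @ self.lora_b
            return y + 2.0 * z

    # Hypothetical usage with made-up dimensions:
    layer = LoRALinear(16, 16)
    out = layer(mx.random.normal((2, 16)))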