diff --git a/lora/models.py b/lora/models.py index b3f18113..e0bfa9f9 100644 --- a/lora/models.py +++ b/lora/models.py @@ -47,7 +47,7 @@ class LoRALinear(nn.Module): self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) def __call__(self, x): - y = self.linear(x.astype(self.linear.weight.dtype)).astype(x.dtype) + y = self.linear(x.astype(self.linear.weight.dtype)) z = (x @ self.lora_a) @ self.lora_b return y + 2.0 * z