diff --git a/lora/lora.py b/lora/lora.py
index 997b14cb..2e0fa0a1 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -184,6 +184,7 @@ def load(args):
 def loss(model, inputs, targets, lengths):
     # Run model on inputs
     logits, _ = model(inputs)
+    logits = logits.astype(mx.float32)
 
     # Mask padding tokens
     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
@@ -326,7 +327,7 @@ def generate(model, prompt, tokenizer, args):
         print(s, flush=True)
 
 
-def load_model(folder: str, dtype=mx.float32):
+def load_model(folder: str, dtype=mx.float16):
     model_path = Path(folder)
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
     with open(model_path / "params.json", "r") as f:
diff --git a/lora/models.py b/lora/models.py
index 52024531..b3f18113 100644
--- a/lora/models.py
+++ b/lora/models.py
@@ -47,7 +47,7 @@ class LoRALinear(nn.Module):
         self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
 
     def __call__(self, x):
-        y = self.linear(x)
+        y = self.linear(x.astype(self.linear.weight.dtype)).astype(x.dtype)
         z = (x @ self.lora_a) @ self.lora_b
         return y + 2.0 * z
 
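
Note: the hunks above keep the frozen base weights in float16 (load_model) while the loss and the LoRA path compute in float32. Below is a minimal standalone sketch of that dtype handling, not the repo's code; it assumes only that mlx is installed, and the names in_dims, out_dims, and rank are illustrative.

    import mlx.core as mx
    import mlx.nn as nn

    in_dims, out_dims, rank = 8, 8, 4

    # Frozen base projection stored in half precision, as load_model now loads it.
    linear = nn.Linear(in_dims, out_dims, bias=False)
    linear.weight = linear.weight.astype(mx.float16)

    # LoRA factors kept in float32 (illustrative init; lora_b starts at zero as in the diff).
    lora_a = 0.01 * mx.random.normal(shape=(in_dims, rank))
    lora_b = mx.zeros(shape=(rank, out_dims))

    x = mx.random.normal(shape=(2, in_dims))  # float32 activations

    # Cast the input down for the fp16 base matmul, then cast the result back
    # so the LoRA path and the loss keep computing in float32.
    y = linear(x.astype(linear.weight.dtype)).astype(x.dtype)
    z = (x @ lora_a) @ lora_b
    print((y + 2.0 * z).dtype)  # float32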