updates on 0.0.7

Awni Hannun
2024-01-03 13:38:45 -08:00
parent 837fcc2097
commit e81cab43e4
5 changed files with 10 additions and 35 deletions

View File

@@ -1,8 +1,8 @@
-# LoRA
+# Fine-Tuning with LoRA or QLoRA
 
 This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
 Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
-task. The example alsos upport quantized LoRA (QLoRA).[^qlora]
+task. The example also supports quantized LoRA (QLoRA).[^qlora]
 
 In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
 generate SQL queries from natural language. However, the example is intended to
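
As a quick illustration of what this README change describes (not part of the diff): LoRA keeps the pretrained projection frozen and adds a trainable low-rank correction. The `2.0` scale and the `lora_a`/`lora_b` names come from the `LoRALinear` forward pass further down in this commit; the shapes here are made up for the sketch.

```python
import mlx.core as mx

# Illustrative sizes only: 4096 matches a 7B model's hidden size, 8 is a
# typical LoRA rank.
d_model, rank = 4096, 8

x = mx.random.normal((1, d_model))
w = mx.random.normal((d_model, d_model))    # frozen pretrained weight
lora_a = mx.random.normal((d_model, rank))  # trainable low-rank factor A
lora_b = mx.zeros((rank, d_model))          # trainable low-rank factor B

# Base projection plus a scaled low-rank update, mirroring `y + 2.0 * z`
# in the LoRALinear change later in this diff.
y = x @ w.T + 2.0 * ((x @ lora_a) @ lora_b)
```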

View File

@@ -65,13 +65,13 @@ if __name__ == "__main__":
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
@@ -95,9 +95,9 @@ if __name__ == "__main__":
     )
 
     # Load the torch model weights to numpy:
-    state = torch.load(str(torch_path / "consolidated.00.pth"))
-    weights = {k: v.to(torch.float16).numpy() for k, v in state.items()}
-    del state
+    weights = torch.load(str(torch_path / "consolidated.00.pth"))
+    for k, v in weights.items():
+        weights[k] = v.to(torch.float16).numpy()
 
     # Standardize the params
     with open(torch_path / "params.json", "r") as f:

View File

@@ -17,9 +17,7 @@ from sentencepiece import SentencePieceProcessor
 def build_parser():
-    parser = argparse.ArgumentParser(
-        description="LoRA finetuning with Llama or Mistral"
-    )
+    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
         default="mlx_model",
@@ -366,8 +364,6 @@ if __name__ == "__main__":
     for l in model.layers[-args.lora_layers :]:
         l.attention.wq = LoRALinear.from_linear(l.attention.wq)
         l.attention.wv = LoRALinear.from_linear(l.attention.wv)
-        # TODO, don't need this if we get rid of stop grad in quantized linear
-        l.attention.wo = LoRALinear.from_linear(l.attention.wo)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")

View File

@@ -5,29 +5,9 @@ from typing import List, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from mlx.utils import tree_map, tree_unflatten
-
-
-# TODO remove this for v0.0.7
-def patch_ql(self, x):
-    x = mx.quantized_matmul(
-        x.astype(mx.float16),
-        self.weight.T,
-        scales=self.scales,
-        biases=self.biases,
-        group_size=self.group_size,
-        bits=self.bits,
-    )
-    x = mx.stop_gradient(x)
-    if "bias" in self:
-        x = x + self.bias
-    return x
-
-
-nn.QuantizedLinear.__call__ = patch_ql
 
 
 @dataclass
 class ModelArgs:
     dim: int
@@ -73,8 +53,7 @@ class LoRALinear(nn.Module):
         else:
             y = self.linear(x.astype(self.linear.weight.dtype))
         z = (x @ self.lora_a) @ self.lora_b
-        out = y + 2.0 * z
-        return out
+        return y + 2.0 * z
 
 
 class RMSNorm(nn.Module):
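
For context on the forward pass above (not part of the diff): LoRA factors are conventionally initialized so the wrapped layer starts out identical to its frozen base, with `lora_a` small and random and `lora_b` all zeros. A sketch of such a constructor; the class name, default rank, and init scale are assumptions rather than the example's exact code.

```python
import math

import mlx.core as mx
import mlx.nn as nn


class LoRAFactors(nn.Module):
    """Sketch of LoRA factor initialization around a frozen linear layer."""

    def __init__(self, linear: nn.Linear, rank: int = 8):
        super().__init__()
        output_dims, input_dims = linear.weight.shape
        self.linear = linear  # the frozen pretrained projection
        scale = 1.0 / math.sqrt(input_dims)
        # A: small uniform init; B: zeros, so (x @ lora_a) @ lora_b adds
        # nothing until training updates the factors.
        self.lora_a = mx.random.uniform(
            low=-scale, high=scale, shape=(input_dims, rank)
        )
        self.lora_b = mx.zeros((rank, output_dims))
```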

View File

@@ -1,3 +1,3 @@
-mlx
+mlx>=0.0.7
 sentencepiece
 torch