Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00

Commit: updates on 0.0.7

@@ -1,8 +1,8 @@
-# LoRA
+# Fine-Tuning with LoRA or QLoRA

 This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
 Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
-task. The example alsos upport quantized LoRA (QLoRA).[^qlora]
+task. The example also supports quantized LoRA (QLoRA).[^qlora]

 In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
 generate SQL queries from natural language. However, the example is intended to
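
Aside, not part of the diff: the low rank adaptation the README refers to keeps the pretrained weights frozen and trains only a pair of small matrices per adapted layer. A minimal MLX sketch of the idea, with toy shapes chosen for illustration and the fixed 2.0 scale that appears in the model code further down:

import mlx.core as mx

# Toy sizes; a real 7B attention projection is roughly 4096 x 4096.
d_in, d_out, rank = 8, 8, 2

W = mx.random.normal((d_out, d_in))             # frozen pretrained weight
lora_a = mx.random.normal((d_in, rank)) * 0.01  # trainable low-rank factor A
lora_b = mx.zeros((rank, d_out))                # trainable low-rank factor B, starts at zero

x = mx.random.normal((1, d_in))
y = x @ W.T                  # frozen base projection
z = (x @ lora_a) @ lora_b    # low-rank update; the only trained path
out = y + 2.0 * z            # same combination used in LoRALinear below
print(out.shape)             # (1, 8); equals y until lora_b is trained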

@@ -65,13 +65,13 @@ if __name__ == "__main__":
         action="store_true",
     )
     parser.add_argument(
-        "--q_group_size",
+        "--q-group-size",
         help="Group size for quantization.",
         type=int,
         default=64,
     )
     parser.add_argument(
-        "--q_bits",
+        "--q-bits",
         help="Bits per weight for quantization.",
         type=int,
         default=4,
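
The flag renames above do not change how the values are read: argparse maps dashes in long option names to underscores when it builds the attribute name, so the values are still available as args.q_group_size and args.q_bits. A small standalone check (not code from the repo):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--q-group-size", type=int, default=64)
parser.add_argument("--q-bits", type=int, default=4)

args = parser.parse_args(["--q-bits", "8"])
# Dashes become underscores in the parsed attribute names.
print(args.q_group_size, args.q_bits)  # -> 64 8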

@@ -95,9 +95,9 @@ if __name__ == "__main__":
     )

     # Load the torch model weights to numpy:
-    state = torch.load(str(torch_path / "consolidated.00.pth"))
-    weights = {k: v.to(torch.float16).numpy() for k, v in state.items()}
-    del state
+    weights = torch.load(str(torch_path / "consolidated.00.pth"))
+    for k, v in weights.items():
+        weights[k] = v.to(torch.float16).numpy()

     # Standardize the params
     with open(torch_path / "params.json", "r") as f:
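
Either version of the block above leaves `weights` as a dict of float16 numpy arrays. What happens next is outside this hunk; presumably, as in other MLX conversion scripts, the arrays are written to a single .npz that MLX can load directly. A sketch under that assumption (the helper and file name are illustrative, not from the repo):

import numpy as np
import mlx.core as mx

def save_and_reload(weights, path="weights.npz"):
    # One .npz entry per parameter tensor (keys like "layers.0.attention.wq.weight").
    np.savez(path, **weights)
    # MLX reads the .npz back as a dict of mx.array.
    return mx.load(path)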

@@ -17,9 +17,7 @@ from sentencepiece import SentencePieceProcessor


 def build_parser():
-    parser = argparse.ArgumentParser(
-        description="LoRA finetuning with Llama or Mistral"
-    )
+    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
     parser.add_argument(
         "--model",
         default="mlx_model",

@@ -366,8 +364,6 @@ if __name__ == "__main__":
     for l in model.layers[-args.lora_layers :]:
         l.attention.wq = LoRALinear.from_linear(l.attention.wq)
         l.attention.wv = LoRALinear.from_linear(l.attention.wv)
-        # TODO, don't need this if we get rid of stop grad in quantized linear
-        l.attention.wo = LoRALinear.from_linear(l.attention.wo)

     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
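
LoRALinear.from_linear itself is not part of this diff. A rough sketch of what such a wrapper does, consistent with the forward pass in the hunk further down (a frozen base linear plus small lora_a/lora_b factors); the initialization details here are an assumption, not the repo's exact code:

import math
import mlx.core as mx
import mlx.nn as nn

class LoRALinear(nn.Module):
    # Sketch only; the real class also has to handle quantized base layers.
    @staticmethod
    def from_linear(linear, rank=8):
        out_dims, in_dims = linear.weight.shape
        lora = LoRALinear(in_dims, out_dims, rank)
        lora.linear = linear  # reuse the pretrained weights as the base projection
        return lora

    def __init__(self, in_dims, out_dims, rank=8):
        super().__init__()
        self.linear = nn.Linear(in_dims, out_dims, bias=False)
        scale = 1.0 / math.sqrt(in_dims)
        self.lora_a = mx.random.uniform(low=-scale, high=scale, shape=(in_dims, rank))
        self.lora_b = mx.zeros((rank, out_dims))

    def __call__(self, x):
        y = self.linear(x.astype(self.linear.weight.dtype))
        z = (x @ self.lora_a) @ self.lora_b
        return y + 2.0 * z

The tree_flatten total in the context lines counts all parameters, base plus adapters; only the adapter matrices are actually trained, assuming the base model is frozen elsewhere in the script.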

@@ -5,29 +5,9 @@ from typing import List, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten


-# TODO remove this for v0.0.7
-def patch_ql(self, x):
-    x = mx.quantized_matmul(
-        x.astype(mx.float16),
-        self.weight.T,
-        scales=self.scales,
-        biases=self.biases,
-        group_size=self.group_size,
-        bits=self.bits,
-    )
-    x = mx.stop_gradient(x)
-    if "bias" in self:
-        x = x + self.bias
-    return x
-
-
-nn.QuantizedLinear.__call__ = patch_ql
-
-
 @dataclass
 class ModelArgs:
     dim: int
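
The deleted patch_ql shim re-implemented nn.QuantizedLinear.__call__ with an mx.stop_gradient on the quantized matmul output, which blocked gradients from flowing backward through quantized layers; the wo wrapper removed in the earlier hunk appears to have been a workaround for that same limitation. With mlx>=0.0.7 the stock layer is used as-is. A usage sketch, assuming the QuantizedLinear.from_linear helper available in MLX of that era:

import mlx.core as mx
import mlx.nn as nn

# Assumed API of the 0.0.x era: QuantizedLinear.from_linear(layer, group_size, bits)
# quantizes an existing float linear layer.
linear = nn.Linear(512, 512, bias=False)
qlinear = nn.QuantizedLinear.from_linear(linear, group_size=64, bits=4)

x = mx.random.normal((1, 512))
print(qlinear(x).shape)  # (1, 512); gradients now flow through this layer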

@@ -73,8 +53,7 @@ class LoRALinear(nn.Module):
         else:
             y = self.linear(x.astype(self.linear.weight.dtype))
         z = (x @ self.lora_a) @ self.lora_b
-        out = y + 2.0 * z
-        return out
+        return y + 2.0 * z


 class RMSNorm(nn.Module):

@@ -1,3 +1,3 @@
-mlx
+mlx>=0.0.7
 sentencepiece
 torch