feat: move lora into mlx-lm (#337)

* feat: Add lora and qlora training to mlx-lm


---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by Anchen on 2024-01-23 08:44:37 -08:00; committed by GitHub
parent 85c1ff8fd6
commit 362e88a744
13 changed files with 987 additions and 111 deletions

llms/README.md

@ -14,6 +14,10 @@ pip install mlx-lm
conda install -c conda-forge mlx-lm
```
The `mlx-lm` package also supports LoRA and QLoRA fine-tuning. For more details
on this see the [LoRA
documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md).
### Python API
You can use `mlx-lm` as a module:
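For reference, a minimal sketch of the module usage (the same snippet is embedded in the model card generated by `upload_to_hub` later in this diff):

```python
from mlx_lm import load, generate

# Any compatible Hugging Face repo or local path to a converted model works here.
model, tokenizer = load("mistralai/Mistral-7B-v0.1")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```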

llms/mlx_lm/LORA.md (new file)

@ -0,0 +1,165 @@
# Fine-Tuning with LoRA or QLoRA
You can use the `mlx-lm` package to fine-tune an LLM with low rank
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Mistral
- Llama
- Phi2
- Mixtral
## Contents
* [Run](#Run)
* [Fine-tune](#Fine-tune)
* [Evaluate](#Evaluate)
* [Generate](#Generate)
* [Fuse and Upload](#Fuse-and-Upload)
* [Data](#Data)
* [Memory Issues](#Memory-Issues)
## Run
The main command is `mlx_lm.lora`. To see a full list of options run:
```shell
python -m mlx_lm.lora --help
```
Note that in the following commands, the `--model` argument can be any compatible
Hugging Face repo or a local path to a converted model.
### Fine-tune
To fine-tune a model use:
```shell
python -m mlx_lm.lora \
--model <path_to_model> \
--train \
--data <path_to_data> \
--iters 600
```
When using `--train`, the `--data` argument must point to a directory containing
`train.jsonl` and `valid.jsonl`; when using `--test`, it must contain a `test.jsonl`.
For more details on the data format see the section on [Data](#Data).
For example, to fine-tune a Mistral 7B you can use `--model
mistralai/Mistral-7B-v0.1`.
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.npz>`.
### Evaluate
To compute test set perplexity (the exponential of the average test loss) use:
```shell
python -m mlx_lm.lora \
--model <path_to_model> \
--adapter-file <path_to_adapters.npz> \
--data <path_to_data> \
--test
```
## Fuse and Upload
You can generate a model fused with the low-rank adapters using the
`mlx_lm.fuse` command. This command also allows you to upload the fused model
to the Hugging Face Hub.
To see supported options run:
```shell
python -m mlx_lm.fuse --help
```
To generate the fused model run:
```shell
python -m mlx_lm.fuse --model <path_to_model>
```
By default, this loads the adapters from `adapters.npz` and saves the fused
model in `lora_fused_model/`. Both paths are configurable with `--adapter-file`
and `--save-path`.
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```shell
python -m mlx_lm.fuse \
--model mistralai/Mistral-7B-v0.1 \
--upload-repo mlx-community/my-4bit-lora-mistral \
--hf-path mistralai/Mistral-7B-v0.1
```
## Data
The LoRA command expects you to provide a dataset with `--data`. The MLX
Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
file should look like:
```
{"text": "This is an example for the model."}
```
Note that other keys will be ignored by the loader.
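If you are generating the dataset programmatically, here is a minimal sketch for writing files in the expected format (the small `write_jsonl` helper is illustrative, not part of the package):

```python
import json
from pathlib import Path

def write_jsonl(texts, path):
    # One JSON object per line with a "text" key, as the loader expects.
    with open(path, "w") as fid:
        for text in texts:
            fid.write(json.dumps({"text": text}) + "\n")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
write_jsonl(["This is an example for the model."], data_dir / "train.jsonl")
write_jsonl(["This is a validation example."], data_dir / "valid.jsonl")
```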
## Memory Issues
Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory. Here are some tips to reduce memory use should you need to do so:
1. Try quantization (QLoRA). You can use QLoRA by first generating a quantized
   model with `python -m mlx_lm.convert` and the `-q` flag, then passing the
   quantized model to `--model`.
2. Try using a smaller batch size with `--batch-size`. The default is `4`, so
   setting this to `2` or `1` will reduce memory consumption at the cost of
   somewhat slower training.
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
   is `16`, so try `8` or `4`. This reduces the amount of memory needed for
   backpropagation, but may also reduce the quality of the fine-tuned model if
   you are fine-tuning with a lot of data.
4. Longer examples require more memory. If it makes sense for your data, break
   long examples into smaller sequences when creating the
   `{train, valid, test}.jsonl` files.
For example, for a machine with 32 GB the following should run reasonably fast:
```shell
python -m mlx_lm.lora \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--lora-layers 4 \
--data wikisql
```
The above command on an M1 Max with 32 GB runs at about 250 tokens per second,
using the MLX Examples
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
data set.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)

llms/mlx_lm/README.md

@ -3,5 +3,8 @@
This is an example of large language model text generation that can pull models from
the Hugging Face Hub.
For more information on this example, see the
[README](../README.md) in the parent directory.
For more information on this example, see the [README](../README.md) in the
parent directory.
This package also supports fine tuning with LoRA or QLoRA. For more information
see the [LoRA documentation](LORA.md).

llms/mlx_lm/convert.py

@ -2,17 +2,21 @@ import argparse
import copy
import glob
import json
import shutil
from pathlib import Path
from typing import Dict, Tuple
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
import transformers
from mlx.utils import tree_flatten
from .utils import get_model_path, linear_class_predicate, load_model
MAX_FILE_SIZE_GB = 15
from .utils import (
fetch_from_hub,
get_model_path,
linear_class_predicate,
save_weights,
upload_to_hub,
)
def configure_parser() -> argparse.ArgumentParser:
@ -55,22 +59,9 @@ def configure_parser() -> argparse.ArgumentParser:
return parser
def fetch_from_hub(
model_path: str,
) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]:
model_path = get_model_path(model_path)
model = load_model(model_path)
config = transformers.AutoConfig.from_pretrained(model_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
return model, config.to_dict(), tokenizer
def quantize_model(
model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> tuple:
) -> Tuple:
"""
Applies quantization to the model weights.
@ -81,7 +72,7 @@ def quantize_model(
q_bits (int): Bits per weight for quantization.
Returns:
tuple: Tuple containing quantized weights and config.
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
@ -94,76 +85,6 @@ def quantize_model(
return quantized_weights, quantized_config
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
"""
Splits the weights into smaller shards.
Args:
weights (dict): Model weights.
max_file_size_gb (int): Maximum size of each shard in gigabytes.
Returns:
list: List of weight shards.
"""
max_file_size_bytes = max_file_size_gb << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {upload_repo}
This model was converted to MLX format from [`{hf_path}`]().
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx-lm
```
```python
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
)
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
@ -174,7 +95,8 @@ def convert(
upload_repo: str = None,
):
print("[INFO] Loading")
model, config, tokenizer = fetch_from_hub(hf_path)
model_path = get_model_path(hf_path)
model, config, tokenizer = fetch_from_hub(model_path)
weights = dict(tree_flatten(model.parameters()))
dtype = mx.float16 if quantize else getattr(mx, dtype)
@ -185,12 +107,17 @@ def convert(
model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size, q_bits)
if isinstance(mlx_path, str):
mlx_path = Path(mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
save_weights(mlx_path, weights)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)

llms/mlx_lm/fuse.py (new file)

@ -0,0 +1,91 @@
import argparse
import glob
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Union
from mlx.utils import tree_flatten, tree_unflatten
from .tuner.lora import LoRALinear
from .tuner.utils import apply_lora_layers
from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--save-path",
default="lora_fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
"--adapter-file",
type=str,
default="adapters.npz",
help="Path to the trained adapter weights (npz or safetensors).",
)
parser.add_argument(
"--hf-path",
type=str,
default=None,
help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
return parser.parse_args()
def main() -> None:
print("Loading pretrained model")
args = parse_arguments()
model_path = get_model_path(args.model)
model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()
model = apply_lora_layers(model, args.adapter_file)
fused_linears = [
(n, m.to_linear())
for n, m in model.named_modules()
if isinstance(m, LoRALinear)
]
model.update_modules(tree_unflatten(fused_linears))
weights = dict(tree_flatten(model.parameters()))
save_path = Path(args.save_path)
save_weights(save_path, weights)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, save_path)
tokenizer.save_pretrained(save_path)
with open(save_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
if args.upload_repo is not None:
hf_path = args.hf_path or (
args.model if not Path(args.model).exists() else None
)
if hf_path is None:
raise ValueError(
"Must provide original Hugging Face repo to upload local model."
)
upload_to_hub(args.save_path, args.upload_repo, hf_path)
if __name__ == "__main__":
main()

llms/mlx_lm/lora.py (new file)

@ -0,0 +1,252 @@
import argparse
import json
import math
from pathlib import Path
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten
from .models import llama, mixtral, phi2
from .tuner.lora import LoRALinear
from .tuner.trainer import TrainingArgs, evaluate, train
from .utils import generate, load
SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]
def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
# Generation args
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="The maximum number of tokens to generate",
)
parser.add_argument(
"--temp", type=float, default=0.8, help="The sampling temperature"
)
parser.add_argument(
"--prompt",
"-p",
type=str,
help="The prompt for generation",
default=None,
)
# Training args
parser.add_argument(
"--train",
action="store_true",
help="Do training",
)
parser.add_argument(
"--data",
type=str,
default="data/",
help="Directory with {train, valid, test}.jsonl files",
)
parser.add_argument(
"--lora-layers",
type=int,
default=16,
help="Number of layers to fine-tune",
)
parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
parser.add_argument(
"--iters", type=int, default=1000, help="Iterations to train for."
)
parser.add_argument(
"--val-batches",
type=int,
default=25,
help="Number of validation batches, -1 uses the entire validation set.",
)
parser.add_argument(
"--learning-rate", type=float, default=1e-5, help="Adam learning rate."
)
parser.add_argument(
"--steps-per-report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps-per-eval",
type=int,
default=200,
help="Number of training steps between validations.",
)
parser.add_argument(
"--resume-adapter-file",
type=str,
default=None,
help="Load path to resume training with the given adapter weights.",
)
parser.add_argument(
"--adapter-file",
type=str,
default="adapters.npz",
help="Save/load path for the trained adapter weights.",
)
parser.add_argument(
"--save-every",
type=int,
default=100,
help="Save the model every N iterations.",
)
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
)
parser.add_argument(
"--test-batches",
type=int,
default=500,
help="Number of test set batches, -1 uses the entire test set.",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
return parser
class Dataset:
"""
Light-weight wrapper to hold lines from a jsonl file
"""
def __init__(self, path: Path, key: str = "text"):
if not path.exists():
self._data = None
else:
with open(path, "r") as fid:
self._data = [json.loads(l) for l in fid]
self._key = key
def __getitem__(self, idx: int):
return self._data[idx][self._key]
def __len__(self):
if self._data is None:
return 0
return len(self._data)
def load_dataset(args):
names = ("train", "valid", "test")
train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
if args.train and len(train) == 0:
raise ValueError(
"Training set not found or empty. Must provide training set for fine-tuning."
)
if args.train and len(valid) == 0:
raise ValueError(
"Validation set not found or empty. Must provide validation set for fine-tuning."
)
if args.test and len(test) == 0:
raise ValueError(
"Test set not found or empty. Must provide test set for evaluation."
)
return train, valid, test
if __name__ == "__main__":
parser = build_parser()
args = parser.parse_args()
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load(args.model)
if model.__class__ not in SUPPORTED_MODELS:
raise ValueError(
f"Model {model.__class__} not supported. "
f"Supported models: { SUPPORTED_MODELS}"
)
# Freeze all layers other than LORA linears
model.freeze()
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
if hasattr(l, "block_sparse_moe"):
l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args)
# Resume training the given adapters.
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
# init training args
trainingArgs = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=args.adapter_file,
)
if args.train:
print("Training")
model.train()
opt = optim.Adam(learning_rate=args.learning_rate)
# Train model
train(
model=model,
tokenizer=tokenizer,
args=trainingArgs,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
)
# Load the LoRA adapter weights which we assume should exist by this point
if not Path(args.adapter_file).is_file():
raise ValueError(
f"Adapter file {args.adapter_file} missing. "
"Use --train to learn and save the adapters.npz."
)
model.load_weights(args.adapter_file, strict=False)
if args.test:
print("Testing")
model.eval()
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if args.prompt is not None:
print("Generating")
generate(
model=model,
tokenizer=tokenizer,
temp=args.temp,
max_tokens=args.max_tokens,
prompt=args.prompt,
verbose=True,
)

llms/mlx_lm/models/mixtral.py

@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
@ -158,12 +159,28 @@ class MixtralSparseMoeBlock(nn.Module):
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
inds = mx.stop_gradient(
mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
) # TODO remove it once we figure out how to fine tune TopK in MOE
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
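# Training path below: the routing indices are pulled to the host so each expert
# runs once over its assigned tokens; gradients flow through the expert outputs
# and the softmax scores, while the indices themselves are stop-gradiented above.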
if self.training:
mx.eval(inds)
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
@ -229,7 +246,7 @@ class MixtralModel(nn.Module):
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h[:, T - 1 : T, :]), cache
return self.norm(h), cache
class Model(nn.Module):


llms/mlx_lm/tuner/lora.py (new file)

@ -0,0 +1,88 @@
import math
import mlx.core as mx
import mlx.nn as nn
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
input_dims *= 32 // linear.bits
lora_lin = LoRALinear(
input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
)
lora_lin.linear = linear
return lora_lin
def to_linear(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
is_quantized = isinstance(linear, nn.QuantizedLinear)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = mx.float16
weight = mx.dequantize(
weight,
linear.scales,
linear.biases,
linear.group_size,
linear.bits,
)
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
if is_quantized:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
rank: int = 8,
bias: bool = False,
scale: float = 20.0,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, rank),
)
self.lora_b = mx.zeros(shape=(rank, output_dims))
def __call__(self, x):
dtype = self.linear.weight.dtype
if isinstance(self.linear, nn.QuantizedLinear):
dtype = self.linear.scales.dtype
y = self.linear(x.astype(dtype))
z = (x @ self.lora_a) @ self.lora_b
return y + self.scale * z
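A minimal usage sketch of the class above (assuming an installed `mlx_lm` package with the module layout added in this commit); the layer sizes are illustrative:

```python
import mlx.nn as nn

from mlx_lm.tuner.lora import LoRALinear

# Wrap an existing (typically frozen) linear layer; lora_a / lora_b hold the
# trainable low-rank update.
base = nn.Linear(4096, 4096)
lora_layer = LoRALinear.from_linear(base, rank=8, scale=20.0)

# After training, collapse the low-rank update back into a plain
# (or re-quantized) linear layer for export.
fused = lora_layer.to_linear()
```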

llms/mlx_lm/tuner/trainer.py (new file)

@ -0,0 +1,204 @@
import os
import time
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
@dataclass
class TrainingArgs:
lora_layers: int = field(
default=16, metadata={"help": "Number of layers to fine-tune"}
)
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
val_batches: int = field(
default=25,
metadata={
"help": "Number of validation batches, -1 uses the entire validation set."
},
)
steps_per_report: int = field(
default=10,
metadata={"help": "Number of training steps between loss reporting."},
)
steps_per_eval: int = field(
default=200, metadata={"help": "Number of training steps between validations."}
)
steps_per_save: int = field(
default=100, metadata={"help": "Save the model every number steps"}
)
max_seq_length: int = field(
default=2048, metadata={"help": "Maximum sequence length."}
)
adapter_file: str = field(
default="adapter.npz",
metadata={"help": "Save/load path for the trained adapter weights."},
)
def default_loss(model, inputs, targets, lengths):
logits, _ = model(inputs)
logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
while True:
# Shuffle indices
indices = np.arange(len(dataset))
indices = np.random.permutation(indices)
# Collect batches from dataset
for i in range(0, len(indices) - batch_size + 1, batch_size):
# Encode batch
batch = [
tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size)
]
lengths = [len(x) for x in batch]
# Check if any sequence is longer than max_seq_length
if max(lengths) > max_seq_length:
print(
"[WARNING] Some sequences are longer than 2048 tokens. "
"Consider pre-splitting your data to save memory."
)
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
for j in range(batch_size):
batch_arr[j, : lengths[j]] = batch[j]
batch = mx.array(batch_arr)
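# Inputs drop the final token and targets drop the first, giving standard
# next-token prediction pairs; lengths let the loss mask out the padding.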
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
if not train:
break
def evaluate(
model,
dataset,
tokenizer,
batch_size,
num_batches,
max_seq_length=2048,
loss: callable = default_loss,
):
all_losses = []
ntokens = 0
for it, batch in zip(
range(num_batches),
iterate_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
losses, toks = loss(model, *batch)
all_losses.append((losses * toks).item())
ntokens += toks.item()
return np.sum(all_losses) / ntokens
def train(
model,
tokenizer,
optimizer,
train_dataset,
val_dataset,
args: TrainingArgs = TrainingArgs(),
loss: callable = default_loss,
):
# Create value and grad function for loss
loss_value_and_grad = nn.value_and_grad(model, loss)
losses = []
n_tokens = 0
print("Starting training..., iters:", args.iters)
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(args.iters),
iterate_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
),
):
# Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
# Model update
optimizer.update(model, grad)
mx.eval(model.parameters(), optimizer.state, lvalue)
# Record loss
losses.append(lvalue.item())
n_tokens += toks.item()
# Report training loss if needed
if (it + 1) % args.steps_per_report == 0:
train_loss = np.mean(losses)
stop = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {args.steps_per_report / (stop - start):.3f}, "
f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
)
losses = []
n_tokens = 0
start = time.perf_counter()
# Report validation loss if needed
if it == 0 or (it + 1) % args.steps_per_eval == 0:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
)
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val took {(time.perf_counter() - stop):.3f}s"
)
start = time.perf_counter()
# Save adapter weights if needed
if (it + 1) % args.steps_per_save == 0:
save_adapter(model=model, adapter_file=args.adapter_file)
print(
f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
)
# save final adapter weights
save_adapter(model=model, adapter_file=args.adapter_file)
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
def save_adapter(
model: nn.Module,
adapter_file: str,
):
flattened_tree = tree_flatten(model.trainable_parameters())
mx.savez(adapter_file, **dict(flattened_tree))

llms/mlx_lm/tuner/utils.py (new file)

@ -0,0 +1,22 @@
import mlx.core as mx
from mlx.utils import tree_unflatten
from .lora import LoRALinear
def apply_lora_layers(model, adapter_file: str):
adapters = list(mx.load(adapter_file).items())
linear_replacements = {}
lora_layers = set(
[name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
)
for name, module in model.named_modules():
if name in lora_layers:
replacement_module = LoRALinear.from_linear(module)
linear_replacements[name] = replacement_module
model.update_modules(tree_unflatten(list(linear_replacements.items())))
model.update(tree_unflatten(adapters))
return model
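As an aside, trained adapters can also be used at inference time without fusing; a minimal sketch assuming the package layout introduced in this commit:

```python
from mlx_lm import load, generate
from mlx_lm.tuner.utils import apply_lora_layers

# Load the base (or quantized) model, splice in the trained adapter weights,
# and generate without producing a fused copy of the model.
model, tokenizer = load("mlx_model")
model = apply_lora_layers(model, "adapters.npz")
model.eval()

response = generate(model, tokenizer, prompt="hello", verbose=True)
```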

llms/mlx_lm/utils.py

@ -1,13 +1,14 @@
import copy
import glob
import json
import logging
from pathlib import Path
from typing import Generator, Tuple
from typing import Any, Dict, Generator, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, PreTrainedTokenizer
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
# Local imports
from .models import llama, mixtral, phi2, plamo, qwen
@ -21,6 +22,7 @@ MODEL_MAPPING = {
"qwen": qwen,
"plamo": plamo,
}
MAX_FILE_SIZE_GB = 15
linear_class_predicate = (
lambda m: isinstance(m, nn.Linear)
@ -204,6 +206,7 @@ def load_model(model_path: Path) -> nn.Module:
mx.eval(model.parameters())
model.eval()
return model
@ -211,7 +214,7 @@ def load(
path_or_hf_repo: str, tokenizer_config={}
) -> Tuple[nn.Module, PreTrainedTokenizer]:
"""
Load the model from a given path or a huggingface repository.
Load the model and tokenizer from a given path or a huggingface repository.
Args:
path_or_hf_repo (str): The path or the huggingface repository to load the model from.
@ -229,3 +232,103 @@ def load(
model = load_model(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
return model, tokenizer
def fetch_from_hub(
model_path: Path,
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path)
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, config.to_dict(), tokenizer
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
"""
Splits the weights into smaller shards.
Args:
weights (dict): Model weights.
max_file_size_gb (int): Maximum size of each shard in gigabytes.
Returns:
list: List of weight shards.
"""
max_file_size_bytes = max_file_size_gb << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {upload_repo}
This model was converted to MLX format from [`{hf_path}`]().
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx-lm
```
```python
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
)
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
"""Save model weights into specified directory."""
if isinstance(save_path, str):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
shards_count = len(shards)
shard_file_format = (
"model-{:05d}-of-{:05d}.safetensors"
if shards_count > 1
else "model.safetensors"
)
for i, shard in enumerate(shards):
shard_name = shard_file_format.format(i + 1, shards_count)
mx.save_safetensors(str(save_path / shard_name), shard)

llms/mlx_lm/models/phi2.py

@ -4,7 +4,7 @@ import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -198,7 +198,7 @@ class Model(nn.Module):
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])