diff --git a/llms/README.md b/llms/README.md
index 3787fc94..b2af7fb7 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -14,6 +14,10 @@ pip install mlx-lm
 conda install -c conda-forge mlx-lm
 ```
 
+The `mlx-lm` package also supports LoRA and QLoRA fine-tuning. For more details
+on this see the [LoRA
+documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md).
+
 ### Python API
 
 You can use `mlx-lm` as a module:
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
new file mode 100644
index 00000000..b8377f88
--- /dev/null
+++ b/llms/mlx_lm/LORA.md
@@ -0,0 +1,165 @@
+# Fine-Tuning with LoRA or QLoRA
+
+You can use the `mlx-lm` package to fine-tune an LLM with low-rank
+adaptation (LoRA) for a target task.[^lora] The example also supports quantized
+LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
+
+- Mistral
+- Llama
+- Phi2
+- Mixtral
+
+## Contents
+
+* [Run](#Run)
+  * [Fine-tune](#Fine-tune)
+  * [Evaluate](#Evaluate)
+  * [Generate](#Generate)
+* [Fuse and Upload](#Fuse-and-Upload)
+* [Data](#Data)
+* [Memory Issues](#Memory-Issues)
+
+## Run
+
+The main command is `mlx_lm.lora`. To see a full list of options run:
+
+```shell
+python -m mlx_lm.lora --help
+```
+
+Note: in the following, the `--model` argument can be any compatible Hugging
+Face repo or a local path to a converted model.
+
+### Fine-tune
+
+To fine-tune a model use:
+
+```shell
+python -m mlx_lm.lora \
+    --model <path_to_model> \
+    --train \
+    --data <path_to_data> \
+    --iters 600
+```
+
+The `--data` argument must specify a path to a directory containing
+`train.jsonl` and `valid.jsonl` when using `--train`, and `test.jsonl` when
+using `--test`. For more details on the data format see the section on
+[Data](#Data).
+
+For example, to fine-tune a Mistral 7B you can use `--model
+mistralai/Mistral-7B-v0.1`.
+
+If `--model` points to a quantized model, then the training will use QLoRA;
+otherwise it will use regular LoRA.
+
+By default, the adapter weights are saved in `adapters.npz`. You can specify
+the output location with `--adapter-file`.
+
+You can resume fine-tuning with an existing adapter with
+`--resume-adapter-file <path_to_adapters.npz>`.
+
+### Evaluate
+
+To compute test set perplexity use:
+
+```shell
+python -m mlx_lm.lora \
+    --model <path_to_model> \
+    --adapter-file <path_to_adapters.npz> \
+    --data <path_to_data> \
+    --test
+```
+
+## Fuse and Upload
+
+You can generate a model fused with the low-rank adapters using the
+`mlx_lm.fuse` command. This command also allows you to upload the fused model
+to the Hugging Face Hub.
+
+To see supported options run:
+
+```shell
+python -m mlx_lm.fuse --help
+```
+
+To generate the fused model run:
+
+```shell
+python -m mlx_lm.fuse --model <path_to_model>
+```
+
+By default, this loads the adapters from `adapters.npz` and saves the fused
+model in the path `lora_fused_model/`. Both locations are configurable with
+`--adapter-file` and `--save-path`.
+
+To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
+to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
+useful for attribution and model versioning.
+
+For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
+
+```shell
+python -m mlx_lm.fuse \
+    --model mistralai/Mistral-7B-v0.1 \
+    --upload-repo mlx-community/my-4bit-lora-mistral \
+    --hf-path mistralai/Mistral-7B-v0.1
+```
+
+## Data
+
+The LoRA command expects you to provide a dataset with `--data`. The MLX
+Examples GitHub repo has an [example of the WikiSQL
+data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
+correct format.
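+
+If your raw examples are already in memory, a minimal sketch along the lines of
+the following produces this layout (the snippet is illustrative only and not
+part of `mlx-lm`; the example strings, split ratio, and `data/` directory name
+are assumptions, and the per-line format it writes is described below):
+
+```python
+import json
+from pathlib import Path
+
+
+def write_jsonl(path, examples):
+    # One JSON object per line, with the raw text under the "text" key.
+    with open(path, "w") as fid:
+        for text in examples:
+            fid.write(json.dumps({"text": text}) + "\n")
+
+
+examples = ["This is an example for the model.", "Another training example."]
+split = int(0.9 * len(examples))  # assumed 90/10 train/validation split
+
+data_dir = Path("data")
+data_dir.mkdir(exist_ok=True)
+write_jsonl(data_dir / "train.jsonl", examples[:split])
+write_jsonl(data_dir / "valid.jsonl", examples[split:])
+```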
+
+For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
+`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
+loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
+file should look like:
+
+```
+{"text": "This is an example for the model."}
+```
+
+Note that other keys will be ignored by the loader.
+
+## Memory Issues
+
+Fine-tuning a large model with LoRA requires a machine with a decent amount
+of memory. Here are some tips to reduce memory use should you need to do so:
+
+1. Try quantization (QLoRA). You can generate a quantized model with
+   `mlx_lm.convert` and the `-q` flag; pointing `--model` at a quantized model
+   enables QLoRA automatically.
+
+2. Try using a smaller batch size with `--batch-size`. The default is `4`, so
+   setting it to `2` or `1` will reduce memory consumption at the cost of
+   slightly slower training.
+
+3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
+   is `16`, so you can try `8` or `4`. This reduces the amount of memory
+   needed for back propagation. It may also reduce the quality of the
+   fine-tuned model if you are fine-tuning with a lot of data.
+
+4. Longer examples require more memory. If it makes sense for your data, one
+   thing you can do is break your examples into smaller sequences when making
+   the `{train, valid, test}.jsonl` files.
+
+For example, for a machine with 32 GB the following should run reasonably fast:
+
+```shell
+python -m mlx_lm.lora \
+    --model mistralai/Mistral-7B-v0.1 \
+    --train \
+    --batch-size 1 \
+    --lora-layers 4 \
+    --data wikisql
+```
+
+The above command on an M1 Max with 32 GB runs at about 250 tokens per second,
+using the MLX Example
+[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
+data set.
+
+[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
+[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).
diff --git a/llms/mlx_lm/README.md b/llms/mlx_lm/README.md
index f7fd1dad..66f2b5e9 100644
--- a/llms/mlx_lm/README.md
+++ b/llms/mlx_lm/README.md
@@ -3,5 +3,8 @@
 This an example of large language model text generation that can pull models
 from the Hugging Face Hub.
 
-For more information on this example, see the
-[README](../README.md) in the parent directory.
+For more information on this example, see the [README](../README.md) in the
+parent directory.
+
+This package also supports fine-tuning with LoRA or QLoRA. For more information
+see the [LoRA documentation](LORA.md).
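+
+As a quick check that a converted model loads and generates, you can run the
+same snippet that the conversion script writes into generated model cards (the
+repo name below is only an illustration; any compatible Hugging Face repo or
+local path works):
+
+```python
+from mlx_lm import load, generate
+
+# Load the model and tokenizer, then generate a short completion.
+model, tokenizer = load("mlx-community/my-4bit-lora-mistral")
+response = generate(model, tokenizer, prompt="hello", verbose=True)
+```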
diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index ead4f0e4..3f590f1c 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -2,17 +2,21 @@ import argparse import copy import glob import json +import shutil from pathlib import Path -from typing import Dict, Tuple +from typing import Tuple import mlx.core as mx import mlx.nn as nn -import transformers from mlx.utils import tree_flatten -from .utils import get_model_path, linear_class_predicate, load_model - -MAX_FILE_SIZE_GB = 15 +from .utils import ( + fetch_from_hub, + get_model_path, + linear_class_predicate, + save_weights, + upload_to_hub, +) def configure_parser() -> argparse.ArgumentParser: @@ -55,22 +59,9 @@ def configure_parser() -> argparse.ArgumentParser: return parser -def fetch_from_hub( - model_path: str, -) -> Tuple[Dict, dict, transformers.PreTrainedTokenizer]: - model_path = get_model_path(model_path) - - model = load_model(model_path) - - config = transformers.AutoConfig.from_pretrained(model_path) - tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) - - return model, config.to_dict(), tokenizer - - def quantize_model( model: nn.Module, config: dict, q_group_size: int, q_bits: int -) -> tuple: +) -> Tuple: """ Applies quantization to the model weights. @@ -81,7 +72,7 @@ def quantize_model( q_bits (int): Bits per weight for quantization. Returns: - tuple: Tuple containing quantized weights and config. + Tuple: Tuple containing quantized weights and config. """ quantized_config = copy.deepcopy(config) @@ -94,76 +85,6 @@ def quantize_model( return quantized_weights, quantized_config -def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: - """ - Splits the weights into smaller shards. - - Args: - weights (dict): Model weights. - max_file_size_gb (int): Maximum size of each shard in gigabytes. - - Returns: - list: List of weight shards. - """ - max_file_size_bytes = max_file_size_gb << 30 - shards = [] - shard, shard_size = {}, 0 - for k, v in weights.items(): - estimated_size = v.size * v.dtype.size - if shard_size + estimated_size > max_file_size_bytes: - shards.append(shard) - shard, shard_size = {}, 0 - shard[k] = v - shard_size += estimated_size - shards.append(shard) - return shards - - -def upload_to_hub(path: str, upload_repo: str, hf_path: str): - """ - Uploads the model to Hugging Face hub. - - Args: - path (str): Local path to the model. - upload_repo (str): Name of the HF repo to upload to. - hf_path (str): Path to the original Hugging Face model. - """ - import os - - from huggingface_hub import HfApi, ModelCard, logging - - card = ModelCard.load(hf_path) - card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] - card.text = f""" -# {upload_repo} -This model was converted to MLX format from [`{hf_path}`](). -Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. 
-## Use with mlx - -```bash -pip install mlx-lm -``` - -```python -from mlx_lm import load, generate - -model, tokenizer = load("{upload_repo}") -response = generate(model, tokenizer, prompt="hello", verbose=True) -``` -""" - card.save(os.path.join(path, "README.md")) - - logging.set_verbosity_info() - - api = HfApi() - api.create_repo(repo_id=upload_repo, exist_ok=True) - api.upload_folder( - folder_path=path, - repo_id=upload_repo, - repo_type="model", - ) - - def convert( hf_path: str, mlx_path: str = "mlx_model", @@ -174,7 +95,8 @@ def convert( upload_repo: str = None, ): print("[INFO] Loading") - model, config, tokenizer = fetch_from_hub(hf_path) + model_path = get_model_path(hf_path) + model, config, tokenizer = fetch_from_hub(model_path) weights = dict(tree_flatten(model.parameters())) dtype = mx.float16 if quantize else getattr(mx, dtype) @@ -185,12 +107,17 @@ def convert( model.load_weights(list(weights.items())) weights, config = quantize_model(model, config, q_group_size, q_bits) - mlx_path = Path(mlx_path) - mlx_path.mkdir(parents=True, exist_ok=True) - shards = make_shards(weights) - for i, shard in enumerate(shards): - mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard) + if isinstance(mlx_path, str): + mlx_path = Path(mlx_path) + + save_weights(mlx_path, weights) + + py_files = glob.glob(str(model_path / "*.py")) + for file in py_files: + shutil.copy(file, mlx_path) + tokenizer.save_pretrained(mlx_path) + with open(mlx_path / "config.json", "w") as fid: json.dump(config, fid, indent=4) diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py new file mode 100644 index 00000000..60a13f94 --- /dev/null +++ b/llms/mlx_lm/fuse.py @@ -0,0 +1,91 @@ +import argparse +import glob +import json +import shutil +from pathlib import Path +from typing import Any, Dict, Union + +from mlx.utils import tree_flatten, tree_unflatten + +from .tuner.lora import LoRALinear +from .tuner.utils import apply_lora_layers +from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") + parser.add_argument( + "--model", + default="mlx_model", + help="The path to the local model directory or Hugging Face repo.", + ) + parser.add_argument( + "--save-path", + default="lora_fused_model", + help="The path to save the fused model.", + ) + parser.add_argument( + "--adapter-file", + type=str, + default="adapters.npz", + help="Path to the trained adapter weights (npz or safetensors).", + ) + parser.add_argument( + "--hf-path", + type=str, + default=None, + help="Path to the original Hugging Face model. 
Required for upload if --model is a local directory.", + ) + parser.add_argument( + "--upload-repo", + help="The Hugging Face repo to upload the model to.", + type=str, + default=None, + ) + return parser.parse_args() + + +def main() -> None: + print("Loading pretrained model") + args = parse_arguments() + + model_path = get_model_path(args.model) + model, config, tokenizer = fetch_from_hub(model_path) + + model.freeze() + model = apply_lora_layers(model, args.adapter_file) + fused_linears = [ + (n, m.to_linear()) + for n, m in model.named_modules() + if isinstance(m, LoRALinear) + ] + + model.update_modules(tree_unflatten(fused_linears)) + weights = dict(tree_flatten(model.parameters())) + + save_path = Path(args.save_path) + + save_weights(save_path, weights) + + py_files = glob.glob(str(model_path / "*.py")) + for file in py_files: + shutil.copy(file, save_path) + + tokenizer.save_pretrained(save_path) + + with open(save_path / "config.json", "w") as fid: + json.dump(config, fid, indent=4) + + if args.upload_repo is not None: + hf_path = args.hf_path or ( + args.model if not Path(args.model).exists() else None + ) + if hf_path is None: + raise ValueError( + "Must provide original Hugging Face repo to upload local model." + ) + upload_to_hub(args.save_path, args.upload_repo, hf_path) + + +if __name__ == "__main__": + main() diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py new file mode 100644 index 00000000..2bcb8099 --- /dev/null +++ b/llms/mlx_lm/lora.py @@ -0,0 +1,252 @@ +import argparse +import json +import math +from pathlib import Path + +import mlx.optimizers as optim +import numpy as np +from mlx.utils import tree_flatten + +from .models import llama, mixtral, phi2 +from .tuner.lora import LoRALinear +from .tuner.trainer import TrainingArgs, evaluate, train +from .utils import generate, load + +SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model] + + +def build_parser(): + parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") + parser.add_argument( + "--model", + default="mlx_model", + help="The path to the local model directory or Hugging Face repo.", + ) + # Generation args + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=100, + help="The maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", type=float, default=0.8, help="The sampling temperature" + ) + parser.add_argument( + "--prompt", + "-p", + type=str, + help="The prompt for generation", + default=None, + ) + + # Training args + parser.add_argument( + "--train", + action="store_true", + help="Do training", + ) + parser.add_argument( + "--data", + type=str, + default="data/", + help="Directory with {train, valid, test}.jsonl files", + ) + parser.add_argument( + "--lora-layers", + type=int, + default=16, + help="Number of layers to fine-tune", + ) + parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.") + parser.add_argument( + "--iters", type=int, default=1000, help="Iterations to train for." + ) + parser.add_argument( + "--val-batches", + type=int, + default=25, + help="Number of validation batches, -1 uses the entire validation set.", + ) + parser.add_argument( + "--learning-rate", type=float, default=1e-5, help="Adam learning rate." 
+ ) + parser.add_argument( + "--steps-per-report", + type=int, + default=10, + help="Number of training steps between loss reporting.", + ) + parser.add_argument( + "--steps-per-eval", + type=int, + default=200, + help="Number of training steps between validations.", + ) + parser.add_argument( + "--resume-adapter-file", + type=str, + default=None, + help="Load path to resume training with the given adapter weights.", + ) + parser.add_argument( + "--adapter-file", + type=str, + default="adapters.npz", + help="Save/load path for the trained adapter weights.", + ) + parser.add_argument( + "--save-every", + type=int, + default=100, + help="Save the model every N iterations.", + ) + parser.add_argument( + "--test", + action="store_true", + help="Evaluate on the test set after training", + ) + parser.add_argument( + "--test-batches", + type=int, + default=500, + help="Number of test set batches, -1 uses the entire test set.", + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + return parser + + +class Dataset: + """ + Light-weight wrapper to hold lines from a jsonl file + """ + + def __init__(self, path: Path, key: str = "text"): + if not path.exists(): + self._data = None + else: + with open(path, "r") as fid: + self._data = [json.loads(l) for l in fid] + self._key = key + + def __getitem__(self, idx: int): + return self._data[idx][self._key] + + def __len__(self): + if self._data is None: + return 0 + return len(self._data) + + +def load_dataset(args): + names = ("train", "valid", "test") + train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names) + if args.train and len(train) == 0: + raise ValueError( + "Training set not found or empty. Must provide training set for fine-tuning." + ) + if args.train and len(valid) == 0: + raise ValueError( + "Validation set not found or empty. Must provide validation set for fine-tuning." + ) + if args.test and len(test) == 0: + raise ValueError( + "Test set not found or empty. Must provide test set for evaluation." + ) + return train, valid, test + + +if __name__ == "__main__": + parser = build_parser() + args = parser.parse_args() + + np.random.seed(args.seed) + + print("Loading pretrained model") + model, tokenizer = load(args.model) + + if model.__class__ not in SUPPORTED_MODELS: + raise ValueError( + f"Model {model.__class__} not supported. " + f"Supported models: { SUPPORTED_MODELS}" + ) + + # Freeze all layers other than LORA linears + model.freeze() + for l in model.model.layers[len(model.model.layers) - args.lora_layers :]: + l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) + if hasattr(l, "block_sparse_moe"): + l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate) + + p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 + print(f"Total parameters {p:.3f}M") + p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 + print(f"Trainable parameters {p:.3f}M") + + print("Loading datasets") + train_set, valid_set, test_set = load_dataset(args) + + # Resume training the given adapters. 
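+    # The adapter file contains only the LoRA parameters, so it is loaded with strict=False.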
+ if args.resume_adapter_file is not None: + print(f"Loading pretrained adapters from {args.resume_adapter_file}") + model.load_weights(args.resume_adapter_file, strict=False) + # init training args + trainingArgs = TrainingArgs( + batch_size=args.batch_size, + iters=args.iters, + val_batches=args.val_batches, + steps_per_report=args.steps_per_report, + steps_per_eval=args.steps_per_eval, + steps_per_save=args.save_every, + adapter_file=args.adapter_file, + ) + if args.train: + print("Training") + model.train() + opt = optim.Adam(learning_rate=args.learning_rate) + # Train model + train( + model=model, + tokenizer=tokenizer, + args=trainingArgs, + optimizer=opt, + train_dataset=train_set, + val_dataset=valid_set, + ) + + # Load the LoRA adapter weights which we assume should exist by this point + if not Path(args.adapter_file).is_file(): + raise ValueError( + f"Adapter file {args.adapter_file} missing. " + "Use --train to learn and save the adapters.npz." + ) + model.load_weights(args.adapter_file, strict=False) + + if args.test: + print("Testing") + model.eval() + + test_loss = evaluate( + model=model, + dataset=test_set, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.test_batches, + ) + + test_ppl = math.exp(test_loss) + + print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") + + if args.prompt is not None: + print("Generating") + generate( + model=model, + tokenizer=tokenizer, + temp=args.temp, + max_tokens=args.max_tokens, + prompt=args.prompt, + verbose=True, + ) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index d35900cf..63401de1 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn +import numpy as np from .base import BaseModelArgs @@ -158,18 +159,34 @@ class MixtralSparseMoeBlock(nn.Module): x = x.reshape(-1, x.shape[-1]) gates = self.gate(x) - inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] + + inds = mx.stop_gradient( + mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] + ) # TODO remove it once we figure out how to fine tune TopK in MOE + scores = mx.softmax( mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), axis=-1, ).astype(gates.dtype) - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt[None, :]) - y = mx.concatenate(y) + if self.training: + mx.eval(inds) + inds = np.array(inds) + y = mx.zeros((x.shape[0], ne, x.shape[-1])) + for e, expert in enumerate(self.experts): + idx1, idx2 = map(mx.array, np.where(inds == e)) + if idx1.size == 0: + continue + y[idx1, idx2] = expert(x[idx1]) + + y = (y * scores[:, :, None]).sum(axis=1) + else: + y = [] + for xt, st, it in zip(x, scores, inds.tolist()): + yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1) + yt = (yt * st).sum(axis=-1) + y.append(yt[None, :]) + y = mx.concatenate(y) return y.reshape(orig_shape) @@ -229,7 +246,7 @@ class MixtralModel(nn.Module): for e, layer in enumerate(self.layers): h, cache[e] = layer(h, mask, cache[e]) - return self.norm(h[:, T - 1 : T, :]), cache + return self.norm(h), cache class Model(nn.Module): diff --git a/llms/mlx_lm/tuner/__init__.py b/llms/mlx_lm/tuner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py new file mode 100644 index 00000000..f0ec601b --- /dev/null +++ 
b/llms/mlx_lm/tuner/lora.py @@ -0,0 +1,88 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class LoRALinear(nn.Module): + @staticmethod + def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0): + # TODO remove when input_dims and output_dims are attributes + # on linear and quantized linear + output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits + lora_lin = LoRALinear( + input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale + ) + lora_lin.linear = linear + return lora_lin + + def to_linear(self): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, nn.QuantizedLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = mx.float16 + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = (self.scale * self.lora_b.T).astype(dtype) + lora_a = self.lora_a.T.astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized: + fused_linear = nn.QuantizedLinear.from_linear( + fused_linear, + linear.group_size, + linear.bits, + ) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + rank: int = 8, + bias: bool = False, + scale: float = 20.0, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, rank), + ) + self.lora_b = mx.zeros(shape=(rank, output_dims)) + + def __call__(self, x): + dtype = self.linear.weight.dtype + if isinstance(self.linear, nn.QuantizedLinear): + dtype = self.linear.scales.dtype + y = self.linear(x.astype(dtype)) + z = (x @ self.lora_a) @ self.lora_b + return y + self.scale * z diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py new file mode 100644 index 00000000..3c255703 --- /dev/null +++ b/llms/mlx_lm/tuner/trainer.py @@ -0,0 +1,204 @@ +import os +import time +from dataclasses import dataclass, field + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_flatten + + +@dataclass +class TrainingArgs: + lora_layers: int = field( + default=16, metadata={"help": "Number of layers to fine-tune"} + ) + batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) + iters: int = field(default=100, metadata={"help": "Iterations to train for."}) + val_batches: int = field( + default=25, + metadata={ + "help": "Number of validation batches, -1 uses the entire validation set." 
+ }, + ) + steps_per_report: int = field( + default=10, + metadata={"help": "Number of training steps between loss reporting."}, + ) + steps_per_eval: int = field( + default=200, metadata={"help": "Number of training steps between validations."} + ) + steps_per_save: int = field( + default=100, metadata={"help": "Save the model every number steps"} + ) + max_seq_length: int = field( + default=2048, metadata={"help": "Maximum sequence length."} + ) + adapter_file: str = field( + default="adapter.npz", + metadata={"help": "Save/load path for the trained adapter weights."}, + ) + + +def default_loss(model, inputs, targets, lengths): + logits, _ = model(inputs) + logits = logits.astype(mx.float32) + + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce, ntoks + + +def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + while True: + # Shuffle indices + indices = np.arange(len(dataset)) + indices = np.random.permutation(indices) + # Collect batches from dataset + for i in range(0, len(indices) - batch_size + 1, batch_size): + # Encode batch + batch = [ + tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size) + ] + lengths = [len(x) for x in batch] + + # Check if any sequence is longer than max_seq_length + if max(lengths) > max_seq_length: + print( + "[WARNING] Some sequences are longer than 2048 tokens. " + "Consider pre-splitting your data to save memory." + ) + + # Pad to the max length + batch_arr = np.zeros((batch_size, max(lengths)), np.int32) + + for j in range(batch_size): + batch_arr[j, : lengths[j]] = batch[j] + batch = mx.array(batch_arr) + yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + + if not train: + break + + +def evaluate( + model, + dataset, + tokenizer, + batch_size, + num_batches, + max_seq_length=2048, + loss: callable = default_loss, +): + all_losses = [] + ntokens = 0 + for it, batch in zip( + range(num_batches), + iterate_batches( + dataset=dataset, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_length=max_seq_length, + ), + ): + losses, toks = loss(model, *batch) + all_losses.append((losses * toks).item()) + ntokens += toks.item() + + return np.sum(all_losses) / ntokens + + +def train( + model, + tokenizer, + optimizer, + train_dataset, + val_dataset, + args: TrainingArgs = TrainingArgs(), + loss: callable = default_loss, +): + # Create value and grad function for loss + loss_value_and_grad = nn.value_and_grad(model, loss) + + losses = [] + n_tokens = 0 + print("Starting training..., iters:", args.iters) + # Main training loop + start = time.perf_counter() + for it, batch in zip( + range(args.iters), + iterate_batches( + dataset=train_dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + train=True, + ), + ): + # Forward and backward pass + (lvalue, toks), grad = loss_value_and_grad(model, *batch) + + # Model update + optimizer.update(model, grad) + + mx.eval(model.parameters(), optimizer.state, lvalue) + + # Record loss + losses.append(lvalue.item()) + n_tokens += toks.item() + + # Report training loss if needed + if (it + 1) % args.steps_per_report == 0: + train_loss = np.mean(losses) + + stop = time.perf_counter() + print( + f"Iter {it + 1}: Train loss {train_loss:.3f}, " + f"It/sec {args.steps_per_report / (stop - start):.3f}, " + f"Tokens/sec {float(n_tokens) / (stop - start):.3f}" + ) + losses = [] + n_tokens = 0 
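+            # Restart the timer so the next report measures only the upcoming window.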
+ start = time.perf_counter() + + # Report validation loss if needed + if it == 0 or (it + 1) % args.steps_per_eval == 0: + stop = time.perf_counter() + val_loss = evaluate( + model=model, + dataset=val_dataset, + loss=loss, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.val_batches, + ) + print( + f"Iter {it + 1}: " + f"Val loss {val_loss:.3f}, " + f"Val took {(time.perf_counter() - stop):.3f}s" + ) + + start = time.perf_counter() + + # Save adapter weights if needed + if (it + 1) % args.steps_per_save == 0: + save_adapter(model=model, adapter_file=args.adapter_file) + print( + f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}." + ) + # save final adapter weights + save_adapter(model=model, adapter_file=args.adapter_file) + print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.") + + +def save_adapter( + model: nn.Module, + adapter_file: str, +): + flattened_tree = tree_flatten(model.trainable_parameters()) + + mx.savez(adapter_file, **dict(flattened_tree)) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py new file mode 100644 index 00000000..993f216a --- /dev/null +++ b/llms/mlx_lm/tuner/utils.py @@ -0,0 +1,22 @@ +import mlx.core as mx +from mlx.utils import tree_unflatten + +from .lora import LoRALinear + + +def apply_lora_layers(model, adapter_file: str): + adapters = list(mx.load(adapter_file).items()) + linear_replacements = {} + lora_layers = set( + [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters] + ) + + for name, module in model.named_modules(): + if name in lora_layers: + replacement_module = LoRALinear.from_linear(module) + linear_replacements[name] = replacement_module + + model.update_modules(tree_unflatten(list(linear_replacements.items()))) + + model.update(tree_unflatten(adapters)) + return model diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b9a81e16..bc35f9f6 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -1,13 +1,14 @@ +import copy import glob import json import logging from pathlib import Path -from typing import Generator, Tuple +from typing import Any, Dict, Generator, Tuple, Union import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from transformers import AutoTokenizer, PreTrainedTokenizer +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer # Local imports from .models import llama, mixtral, phi2, plamo, qwen @@ -21,6 +22,7 @@ MODEL_MAPPING = { "qwen": qwen, "plamo": plamo, } +MAX_FILE_SIZE_GB = 15 linear_class_predicate = ( lambda m: isinstance(m, nn.Linear) @@ -204,6 +206,7 @@ def load_model(model_path: Path) -> nn.Module: mx.eval(model.parameters()) + model.eval() return model @@ -211,7 +214,7 @@ def load( path_or_hf_repo: str, tokenizer_config={} ) -> Tuple[nn.Module, PreTrainedTokenizer]: """ - Load the model from a given path or a huggingface repository. + Load the model and tokenizer from a given path or a huggingface repository. Args: model_path (Path): The path or the huggingface repository to load the model from. 
@@ -229,3 +232,103 @@ def load( model = load_model(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config) return model, tokenizer + + +def fetch_from_hub( + model_path: Path, +) -> Tuple[Dict, dict, PreTrainedTokenizer]: + model = load_model(model_path) + + config = AutoConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + return model, config.to_dict(), tokenizer + + +def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: + """ + Splits the weights into smaller shards. + + Args: + weights (dict): Model weights. + max_file_size_gb (int): Maximum size of each shard in gigabytes. + + Returns: + list: List of weight shards. + """ + max_file_size_bytes = max_file_size_gb << 30 + shards = [] + shard, shard_size = {}, 0 + for k, v in weights.items(): + estimated_size = v.size * v.dtype.size + if shard_size + estimated_size > max_file_size_bytes: + shards.append(shard) + shard, shard_size = {}, 0 + shard[k] = v + shard_size += estimated_size + shards.append(shard) + return shards + + +def upload_to_hub(path: str, upload_repo: str, hf_path: str): + """ + Uploads the model to Hugging Face hub. + + Args: + path (str): Local path to the model. + upload_repo (str): Name of the HF repo to upload to. + hf_path (str): Path to the original Hugging Face model. + """ + import os + + from huggingface_hub import HfApi, ModelCard, logging + + card = ModelCard.load(hf_path) + card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] + card.text = f""" +# {upload_repo} +This model was converted to MLX format from [`{hf_path}`](). +Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. +## Use with mlx + +```bash +pip install mlx-lm +``` + +```python +from mlx_lm import load, generate + +model, tokenizer = load("{upload_repo}") +response = generate(model, tokenizer, prompt="hello", verbose=True) +``` +""" + card.save(os.path.join(path, "README.md")) + + logging.set_verbosity_info() + + api = HfApi() + api.create_repo(repo_id=upload_repo, exist_ok=True) + api.upload_folder( + folder_path=path, + repo_id=upload_repo, + repo_type="model", + ) + + +def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None: + """Save model weights into specified directory.""" + if isinstance(save_path, str): + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + + shards = make_shards(weights) + shards_count = len(shards) + shard_file_format = ( + "model-{:05d}-of-{:05d}.safetensors" + if shards_count > 1 + else "model.safetensors" + ) + + for i, shard in enumerate(shards): + shard_name = shard_file_format.format(i + 1, shards_count) + mx.save_safetensors(str(save_path / shard_name), shard) diff --git a/llms/phixtral/phixtral.py b/llms/phixtral/phixtral.py index 6774bc1f..0aa37ebf 100644 --- a/llms/phixtral/phixtral.py +++ b/llms/phixtral/phixtral.py @@ -4,7 +4,7 @@ import json import math from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -198,7 +198,7 @@ class Model(nn.Module): x: mx.array, mask: mx.array = None, cache: mx.array = None, - ) -> tuple[mx.array, mx.array]: + ) -> Tuple[mx.array, mx.array]: mask = None if x.shape[1] > 1: mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])