From 7b258f33ac097fd7035f805e733c4ba7efbdc851 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 9 Jan 2024 11:14:52 -0800 Subject: [PATCH] Move lora example to use the same model format / conversion as `hf_llm` (#252) * huffing face the lora example to allow more models * fixes * comments * more readme nits * fusion + works better for qlora * nits' * comments --- llms/hf_llm/README.md | 4 +- llms/hf_llm/generate.py | 1 - llms/hf_llm/models.py | 5 +- lora/README.md | 92 ++++++++++---- lora/convert.py | 92 +++++--------- lora/fuse.py | 80 ++++++++++++ lora/lora.py | 105 ++++------------ lora/models.py | 271 +++++++++++++++++++++++++++++++++------- lora/requirements.txt | 5 +- lora/utils.py | 90 +++++++++++++ 10 files changed, 521 insertions(+), 224 deletions(-) create mode 100644 lora/fuse.py create mode 100644 lora/utils.py diff --git a/llms/hf_llm/README.md b/llms/hf_llm/README.md index 809417cc..b7762be3 100644 --- a/llms/hf_llm/README.md +++ b/llms/hf_llm/README.md @@ -60,7 +60,7 @@ You can convert (change the data type or quantize) models using the `convert.py` script. This script takes a Hugging Face repo as input and outputs a model directory (which you can optionally also upload to Hugging Face). -For example, to make 4-bit quantized a model, run: +For example, to make a 4-bit quantized model, run: ``` python convert.py --hf-path -q @@ -73,5 +73,5 @@ python convert.py --help ``` You can upload new models to the [Hugging Face MLX -Community](https://huggingface.co/mlx-community) by specifying `--upload-name`` +Community](https://huggingface.co/mlx-community) by specifying `--upload-name` to `convert.py`. diff --git a/llms/hf_llm/generate.py b/llms/hf_llm/generate.py index 0b8f7ea2..514d8e4d 100644 --- a/llms/hf_llm/generate.py +++ b/llms/hf_llm/generate.py @@ -39,7 +39,6 @@ def generate( tic = time.time() tokens.append(token.item()) - # if (n + 1) % 10 == 0: s = tokenizer.decode(tokens) print(s[skip:], end="", flush=True) skip = len(s) diff --git a/llms/hf_llm/models.py b/llms/hf_llm/models.py index a706e4ea..d0f71e1b 100644 --- a/llms/hf_llm/models.py +++ b/llms/hf_llm/models.py @@ -10,7 +10,6 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from mlx.utils import tree_unflatten from transformers import AutoTokenizer @@ -250,9 +249,7 @@ def load(path_or_hf_repo: str): model.load_weights(list(weights.items())) mx.eval(model.parameters()) - tokenizer = AutoTokenizer.from_pretrained( - model_path, - ) + tokenizer = AutoTokenizer.from_pretrained(model_path) return model, tokenizer diff --git a/lora/README.md b/lora/README.md index c2980086..7581aced 100644 --- a/lora/README.md +++ b/lora/README.md @@ -1,8 +1,9 @@ # Fine-Tuning with LoRA or QLoRA -This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a -Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target -task. The example also supports quantized LoRA (QLoRA).[^qlora] +This is an example of using MLX to fine-tune an LLM with low rank adaptation +(LoRA) for a target task.[^lora] The example also supports quantized LoRA +(QLoRA).[^qlora] The example works with Llama and Mistral style +models available on Hugging Face. In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to generate SQL queries from natural language. However, the example is intended to @@ -11,11 +12,13 @@ be general should you wish to use a custom dataset. 
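If you do plan to bring your own data, it helps to know the shape the example expects up front: the `Dataset` wrapper in `lora.py` reads JSON-lines files. The sketch below writes a single record in that style; the `"text"` key, the `train.jsonl` file name, and the flattened table/question/query prompt are illustrative assumptions rather than something this patch prescribes.

```python
# Illustrative sketch only: one JSON-lines training record of the kind the
# LoRA example consumes. The "text" key, file name, and prompt layout are
# assumptions for demonstration, not mandated by this patch.
import json

record = {
    "text": (
        "table: people (name, age)\n"
        "Q: How many people are older than 30?\n"
        "A: SELECT COUNT(name) FROM people WHERE age > 30"
    )
}
with open("train.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")
```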
 ## Contents
 
 * [Setup](#Setup)
+  * [Convert](#convert)
 * [Run](#Run)
   * [Fine-tune](#Fine-tune)
   * [Evaluate](#Evaluate)
   * [Generate](#Generate)
 * [Results](#Results)
+* [Fuse and Upload](#Fuse-and-Upload)
 * [Custom Data](#Custom-Data)
 * [Memory Issues](#Memory-Issues)
 
@@ -28,36 +31,49 @@ Install the dependencies:
 pip install -r requirements.txt
 ```
 
-Next, download and convert the model. The Mistral weights can be downloaded with:
+### Convert
+
+This step is only needed if you want to quantize a model (for QLoRA) or change
+its default data type.
+
+You can convert models using the `convert.py` script. This script takes a Hugging
+Face repo as input and outputs a model directory (which you can optionally also
+upload to Hugging Face).
+
+To make a 4-bit quantized model, run:
 
 ```
-curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
-tar -xf mistral-7B-v0.1.tar
+python convert.py --hf-path <hf_repo> -q
 ```
 
-If you do not have access to the Llama weights you will need to [request
-access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
-from Meta.
-
-Convert the model with:
+For example, the following will make a 4-bit quantized Mistral 7B and store it
+in `mlx_model` by default:
 
 ```
-python convert.py \
-    --torch-path <path_to_torch_model> \
-    --mlx-path <path_to_mlx_model>
+python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q
 ```
 
-If you wish to use QLoRA, then convert the model with 4-bit quantization using
-the `-q` option.
+For more options run:
+
+```
+python convert.py --help
+```
+
+You can upload new models to the [Hugging Face MLX
+Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
+to `convert.py`.
 
 ## Run
 
-The main script is `lora.py`. To see a full list of options run
+The main script is `lora.py`. To see a full list of options run:
 
 ```
 python lora.py --help
 ```
 
+Note: in the following, the `--model` argument can be any compatible Hugging
+Face repo or a local path to a converted model.
+
 ### Fine-tune
 
 To fine-tune a model use:
@@ -71,9 +87,6 @@ python lora.py --model <path_to_model> \
 If `--model` points to a quantized model, then the training will use QLoRA,
 otherwise it will use regular LoRA.
 
-Note, the model path should have the MLX weights, the tokenizer, and the
-`config.json` which will all be output by the `convert.py` script.
-
 By default, the adapter weights are saved in `adapters.npz`. You can specify
 the output location with `--adapter-file`.
 
@@ -82,7 +95,7 @@ You can resume fine-tuning with an existing adapter with
 
 ### Evaluate
 
-To compute test set perplexity use
+To compute test set perplexity use:
 
 ```
 python lora.py --model <path_to_model> \
@@ -92,7 +105,7 @@ python lora.py --model <path_to_model> \
 
 ### Generate
 
-For generation use
+For generation use:
 
 ```
 python lora.py --model <path_to_model> \
@@ -121,6 +134,37 @@ training and validation loss at a few points over the course of training.
 
 The model trains at around 475 tokens per second on an M2 Ultra.
 
+## Fuse and Upload
+
+You can generate a fused model with the low-rank adapters included using the
+`fuse.py` script. This script also optionally allows you to upload the fused
+model to the [Hugging Face MLX
+Community](https://huggingface.co/mlx-community).
+
+To generate the fused model run:
+
+```
+python fuse.py
+```
+
+This will by default load the base model from `mlx_model/`, the adapters from
+`adapters.npz`, and save the fused model in the path `lora_fused_model/`. All
+of these are configurable. You can see the list of options with:
+
+```
+python fuse.py --help
+```
+
+To upload a fused model, supply the `--upload-name` and `--hf-path` arguments
+to `fuse.py`. The latter is the repo name of the original model, which is
+useful for attribution and model versioning.
+
+For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
+
+```
+python fuse.py --upload-name My-4-bit-model --hf-path mistralai/Mistral-7B-v0.1
+```
+
 ## Custom Data
 
 You can make your own dataset for fine-tuning with LoRA. You can specify the
@@ -164,7 +208,7 @@ For example, for a machine with 32 GB the following should run reasonably fast:
 
 ```
 python lora.py \
-    --model <path_to_model> \
+    --model mistralai/Mistral-7B-v0.1 \
     --train \
     --batch-size 1 \
     --lora-layers 4
@@ -175,6 +219,4 @@ The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
 
 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
 [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
-[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
-[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
 [^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
diff --git a/lora/convert.py b/lora/convert.py
index 38342080..98697587 100644
--- a/lora/convert.py
+++ b/lora/convert.py
@@ -2,33 +2,23 @@
 import argparse
 import copy
-import json
-import shutil
-from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
-import torch
-from mlx.utils import tree_flatten, tree_map, tree_unflatten
-
-from lora import Model, ModelArgs
+import utils
+from mlx.utils import tree_flatten
+from models import Model, ModelArgs
 
 
 def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
     # Load the model:
-    model = Model(ModelArgs(**config))
-    weights = tree_map(mx.array, weights)
-    model.update(tree_unflatten(list(weights.items())))
+    model = Model(ModelArgs.from_dict(config))
+    model.load_weights(list(weights.items()))
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(
-        model,
-        args.q_group_size,
-        args.q_bits,
-    )
+    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
 
     # Update the config:
     quantized_config["quantization"] = {
@@ -42,19 +32,18 @@ def quantize(weights, config, args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Convert Mistral or Llama models to MLX.",
+        description="Convert Hugging Face model to MLX format"
     )
     parser.add_argument(
-        "--torch-path",
+        "--hf-path",
         type=str,
-        default="mistral-7B-v0.1/",
-        help="Path to the torch model directory",
+        help="Path to the Hugging Face model.",
     )
     parser.add_argument(
         "--mlx-path",
         type=str,
-        default="mlx_model/",
-        help="The directory to store the mlx model",
+        default="mlx_model",
+        help="Path to save the MLX model.",
     )
     parser.add_argument(
         "-q",
@@ -74,50 +63,31 @@ if __name__ == "__main__":
         type=int,
         default=4,
     )
-    args = parser.parse_args()
-
-    args = parser.parse_args()
-
-    torch_path = Path(args.torch_path)
-    mlx_path = Path(args.mlx_path)
-    mlx_path.mkdir(parents=True, exist_ok=True)
-
-    # Copy the tokenizer
-    tokenizer_path = torch_path / "tokenizer.model"
-    if not tokenizer_path.exists():
- print(f"Make sure there is a file tokenizer.model in {args.torch_path}") - exit(0) - shutil.copyfile( - str(tokenizer_path), - str(mlx_path / "tokenizer.model"), + parser.add_argument( + "--dtype", + help="Type to save the parameters, ignored if -q is given.", + type=str, + choices=["float16", "bfloat16", "float32"], + default="float16", + ) + parser.add_argument( + "--upload-name", + help="The name of model to upload to Hugging Face MLX Community", + type=str, + default=None, ) - # Load the torch model weights to numpy: - weights = torch.load(str(torch_path / "consolidated.00.pth")) - for k, v in weights.items(): - weights[k] = v.to(torch.float16).numpy() + args = parser.parse_args() - # Standardize the params - with open(torch_path / "params.json", "r") as f: - config = json.loads(f.read()) - unused = ["multiple_of", "sliding_window"] - for k in unused: - config.pop(k, None) - n_heads = config["n_heads"] - if "n_kv_heads" not in config: - config["n_kv_heads"] = n_heads - if "head_dim" not in config: - config["head_dim"] = config["dim"] // n_heads - if "hidden_dim" not in config: - config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] - if config.get("vocab_size", -1) < 0: - config["vocab_size"] = weights["output.weight"].shape[0] + print("[INFO] Loading") + weights, config, tokenizer = utils.fetch_from_hub(args.hf_path) + dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) + weights = {k: v.astype(dtype) for k, v in weights.items()} if args.quantize: print("[INFO] Quantizing") weights, config = quantize(weights, config, args) - np.savez(str(mlx_path / "weights.npz"), **weights) - - with open(mlx_path / "config.json", "w") as outfile: - json.dump(config, outfile, indent=4) + utils.save_model(args.mlx_path, weights, tokenizer, config) + if args.upload_name is not None: + utils.upload_to_hub(args.mlx_path, args.upload_name, args.hf_path) diff --git a/lora/fuse.py b/lora/fuse.py new file mode 100644 index 00000000..c2c9624a --- /dev/null +++ b/lora/fuse.py @@ -0,0 +1,80 @@ +# Copyright © 2023 Apple Inc. + +import argparse +from pathlib import Path + +import mlx.core as mx +import models +import utils +from mlx.utils import tree_flatten, tree_unflatten + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") + parser.add_argument( + "--model", + default="mlx_model", + help="The path to the local model directory or Hugging Face repo.", + ) + parser.add_argument( + "--save-path", + default="lora_fused_model", + help="The path to save the fused model.", + ) + parser.add_argument( + "--adapter-file", + type=str, + default="adapters.npz", + help="Path to the trained adapter weights (npz or safetensors).", + ) + parser.add_argument( + "--hf-path", + help=( + "Path to the original Hugging Face model. This is " + "required for upload if --model is a local directory." 
+ ), + type=str, + default=None, + ) + parser.add_argument( + "--upload-name", + help="The name of model to upload to Hugging Face MLX Community", + type=str, + default=None, + ) + + print("Loading pretrained model") + args = parser.parse_args() + + model, tokenizer, config = models.load(args.model) + + # Load adapters and get number of LoRA layers + adapters = list(mx.load(args.adapter_file).items()) + lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]]) + + # Freeze all layers other than LORA linears + model.freeze() + for l in model.model.layers[-lora_layers:]: + l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj) + + model.update(tree_unflatten(adapters)) + fused_linears = [ + (n, m.to_linear()) + for n, m in model.named_modules() + if isinstance(m, models.LoRALinear) + ] + + model.update_modules(tree_unflatten(fused_linears)) + weights = dict(tree_flatten(model.parameters())) + utils.save_model(args.save_path, weights, tokenizer._tokenizer, config) + + if args.upload_name is not None: + hf_path = args.hf_path + if not Path(args.model).exists(): + # If the model path doesn't exist, assume it's an HF repo + hf_path = args.model + elif hf_path is None: + raise ValueError( + "Must provide original Hugging Face repo to upload local model." + ) + utils.upload_to_hub(args.save_path, args.upload_name, hf_path) diff --git a/lora/lora.py b/lora/lora.py index 1c8fe17f..e895a87b 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -5,15 +5,13 @@ import json import math import time from pathlib import Path -from typing import List import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim +import models import numpy as np from mlx.utils import tree_flatten, tree_unflatten -from models import LoRALinear, Model, ModelArgs -from sentencepiece import SentencePieceProcessor def build_parser(): @@ -21,7 +19,7 @@ def build_parser(): parser.add_argument( "--model", default="mlx_model", - help="A path to the model files containing the tokenizer, weights, config.", + help="The path to the local model directory or Hugging Face repo.", ) # Generation args parser.add_argument( @@ -111,34 +109,6 @@ def build_parser(): return parser -class Tokenizer: - def __init__(self, model_path: str): - assert Path(model_path).exists(), model_path - self._model = SentencePieceProcessor(model_file=model_path) - self._sep = "▁" - assert self._model.vocab_size() == self._model.get_piece_size() - - def encode(self, s: str, eos: bool = False) -> List[int]: - toks = [self._model.bos_id(), *self._model.encode(s)] - if eos: - toks.append(self.eos_id) - return toks - - @property - def eos_id(self) -> int: - return self._model.eos_id() - - def decode(self, t: List[int]) -> str: - out = self._model.decode(t) - if t and self._model.id_to_piece(t[0])[0] == self._sep: - return " " + out - return out - - @property - def vocab_size(self) -> int: - return self._model.vocab_size() - - class Dataset: """ Light-weight wrapper to hold lines from a jsonl file @@ -295,56 +265,27 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args): def generate(model, prompt, tokenizer, args): print(args.prompt, end="", flush=True) - prompt = mx.array(tokenizer.encode(args.prompt)) - def generate_step(): - temp = args.temp - - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) - - logits, cache = model(prompt[None]) - y = sample(logits[:, -1, :]) - 
yield y - - while True: - logits, cache = model(y[:, None], cache) - y = sample(logits.squeeze(1)) - yield y + prompt = tokenizer.encode(args.prompt) tokens = [] - for token, _ in zip(generate_step(), range(args.num_tokens)): - tokens.append(token) + skip = 0 + for token, n in zip( + models.generate(prompt, model, args.temp), + range(args.max_tokens), + ): + if token == tokenizer.eos_token_id: + break - if (len(tokens) % 10) == 0: - mx.eval(tokens) - s = tokenizer.decode([t.item() for t in tokens]) - print(s, end="", flush=True) - tokens = [] - - mx.eval(tokens) - s = tokenizer.decode([t.item() for t in tokens]) - print(s, flush=True) - - -def load_model(folder: str): - model_path = Path(folder) - tokenizer = Tokenizer(str(model_path / "tokenizer.model")) - with open(model_path / "config.json", "r") as f: - config = json.loads(f.read()) - quantization = config.pop("quantization", None) - model_args = ModelArgs(**config) - model = Model(model_args) - if quantization is not None: - nn.QuantizedLinear.quantize_module(model, **quantization) - - weights = mx.load(str(model_path / "weights.npz")) - weights = tree_unflatten(list(weights.items())) - model.update(weights) - return model, tokenizer + tokens.append(token.item()) + s = tokenizer.decode(tokens) + print(s[skip:], end="", flush=True) + skip = len(s) + print(tokenizer.decode(tokens)[skip:], flush=True) + print("=" * 10) + if len(tokens) == 0: + print("No tokens generated for this prompt") + return if __name__ == "__main__": @@ -354,13 +295,13 @@ if __name__ == "__main__": np.random.seed(args.seed) print("Loading pretrained model") - model, tokenizer = load_model(args.model) + model, tokenizer, _ = models.load(args.model) # Freeze all layers other than LORA linears model.freeze() - for l in model.layers[-args.lora_layers :]: - l.attention.wq = LoRALinear.from_linear(l.attention.wq) - l.attention.wv = LoRALinear.from_linear(l.attention.wv) + for l in model.model.layers[len(model.model.layers) - args.lora_layers :]: + l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj) p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 print(f"Total parameters {p:.3f}M") diff --git a/lora/models.py b/lora/models.py index 3208de35..3b7d4848 100644 --- a/lora/models.py +++ b/lora/models.py @@ -1,23 +1,81 @@ # Copyright © 2023 Apple Inc. 
+ +import glob +import inspect +import json import math from dataclasses import dataclass -from typing import List, Optional, Tuple +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_map, tree_unflatten +import numpy as np +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer @dataclass class ModelArgs: - dim: int - n_layers: int - head_dim: int - hidden_dim: int - n_heads: int - n_kv_heads: int - norm_eps: float + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float vocab_size: int + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + model_type: str = None + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class Tokenizer: + def __init__(self, model_path: str): + self._tokenizer = AutoTokenizer.from_pretrained(model_path) + self._eos = self._tokenizer.eos_token_id + self._bos = self._tokenizer.bos_token_id + + def encode(self, s: str, eos: bool = False) -> mx.array: + toks = self._tokenizer( + s, + return_tensors="np", + return_attention_mask=False, + )[ + "input_ids" + ][0] + if eos: + toks = np.concatenate([toks, [self._eos]]) + return mx.array(toks) + + @property + def eos_id(self) -> int: + return self._eos + + def decode(self, t: List[int]) -> str: + return self._tokenizer.decode(t) class LoRALinear(nn.Module): @@ -32,14 +90,58 @@ class LoRALinear(nn.Module): lora_lin.linear = linear return lora_lin + def to_linear(self): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, nn.QuantizedLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = mx.float16 + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = (self.scale * self.lora_b.T).astype(dtype) + lora_a = self.lora_a.T.astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized: + fused_linear = nn.QuantizedLinear.from_linear( + fused_linear, + linear.group_size, + linear.bits, + ) + + return fused_linear + def __init__( - self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False + self, + input_dims: int, + output_dims: int, + lora_rank: int = 8, + bias: bool = False, + scale: float = 20.0, ): super().__init__() # Regular linear layer weights self.linear = nn.Linear(input_dims, output_dims, bias=bias) + # Scale for low-rank update + self.scale = scale + # Low rank lora weights scale = 1 / math.sqrt(input_dims) self.lora_a = mx.random.uniform( @@ -55,7 +157,7 @@ class LoRALinear(nn.Module): dtype = 
self.linear.scales.dtype y = self.linear(x.astype(dtype)) z = (x @ self.lora_a) @ self.lora_b - return y + 2.0 * z + return y + self.scale * z class RMSNorm(nn.Module): @@ -75,20 +177,31 @@ class RMSNorm(nn.Module): class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.args = args - self.n_heads: int = args.n_heads - self.n_kv_heads: int = args.n_kv_heads + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.repeats = self.n_heads // self.n_kv_heads + self.repeats = n_heads // n_kv_heads - self.scale = self.args.head_dim**-0.5 + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 - self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) - self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) - self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) - self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) - self.rope = nn.RoPE(args.head_dim, traditional=True) + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + rope_scale = ( + 1 / args.rope_scaling["factor"] + if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" + else 1 + ) + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) def __call__( self, @@ -98,7 +211,7 @@ class Attention(nn.Module): ) -> mx.array: B, L, D = x.shape - queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Prepare the queries, keys and values for the attention computation queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) @@ -127,30 +240,29 @@ class Attention(nn.Module): scores += mask scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.wo(output), (keys, values) + return self.o_proj(output), (keys, values) -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): super().__init__() - - self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) - self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) - self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.attention = Attention(args) - self.feed_forward = FeedForward(args=args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + 
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.args = args def __call__( @@ -159,31 +271,32 @@ class TransformerBlock(nn.Module): mask: Optional[mx.array] = None, cache: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: - r, cache = self.attention(self.attention_norm(x), mask, cache) + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r - r = self.feed_forward(self.ffn_norm(h)) + r = self.mlp(self.post_attention_layernorm(h)) out = h + r return out, cache -class Model(nn.Module): +class LlamaModel(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args self.vocab_size = args.vocab_size - self.n_layers = args.n_layers + self.num_hidden_layers = args.num_hidden_layers assert self.vocab_size > 0 - self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) - self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] - self.norm = RMSNorm(args.dim, eps=args.norm_eps) - self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( self, inputs: mx.array, cache=None, ): - h = self.tok_embeddings(inputs) + h = self.embed_tokens(inputs) mask = None if h.shape[1] > 1: @@ -196,4 +309,70 @@ class Model(nn.Module): for e, layer in enumerate(self.layers): h, cache[e] = layer(h, mask, cache[e]) - return self.output(self.norm(h)), cache + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model = LlamaModel(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache + + +def load(path_or_hf_repo: str): + # If the path exists, it will try to load model form it + # otherwise download and cache from the hf_repo and cache + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + ) + + with open(model_path / "config.json", "r") as f: + config = json.loads(f.read()) + quantization = config.get("quantization", None) + model_args = ModelArgs.from_dict(config) + + weight_files = glob.glob(str(model_path / "*.safetensors")) + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + model = Model(model_args) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) + + model.load_weights(list(weights.items())) + + mx.eval(model.parameters()) + return model, Tokenizer(model_path), config + + +def generate(prompt: mx.array, model: Model, temp: float = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + y = prompt + cache = None + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits) + yield y diff --git a/lora/requirements.txt b/lora/requirements.txt index 4ab8df9e..9abc3e88 100644 --- a/lora/requirements.txt +++ b/lora/requirements.txt @@ -1,4 +1,3 @@ mlx>=0.0.7 -sentencepiece -torch -numpy \ 
No newline at end of file +transformers +numpy diff --git a/lora/utils.py b/lora/utils.py new file mode 100644 index 00000000..182c66dc --- /dev/null +++ b/lora/utils.py @@ -0,0 +1,90 @@ +# Copyright © 2023 Apple Inc. + +import glob +import json +from pathlib import Path + +import mlx.core as mx +import transformers +from huggingface_hub import snapshot_download + + +def fetch_from_hub(hf_path: str): + model_path = snapshot_download( + repo_id=hf_path, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + weight_files = glob.glob(f"{model_path}/*.safetensors") + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + config = transformers.AutoConfig.from_pretrained(hf_path) + tokenizer = transformers.AutoTokenizer.from_pretrained( + hf_path, + ) + return weights, config.to_dict(), tokenizer + + +def upload_to_hub(path: str, name: str, hf_path: str): + import os + + from huggingface_hub import HfApi, ModelCard, logging + + repo_id = f"mlx-community/{name}" + + card = ModelCard.load(hf_path) + card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] + card.text = f""" +# {name} +This model was converted to MLX format from [`{hf_path}`](). +Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model. +## Use with mlx +```bash +pip install mlx +git clone https://github.com/ml-explore/mlx-examples.git +cd mlx-examples/llms/hf_llm +python generate.py --model {repo_id} --prompt "My name is" +``` +""" + card.save(os.path.join(path, "README.md")) + + logging.set_verbosity_info() + + api = HfApi() + api.create_repo(repo_id=repo_id, exist_ok=True) + api.upload_folder( + folder_path=path, + repo_id=repo_id, + repo_type="model", + ) + + +def make_shards(weights: dict, max_file_size_gibibyte: int = 15): + max_file_size_bytes = max_file_size_gibibyte << 30 + shards = [] + shard, shard_size = {}, 0 + for k, v in weights.items(): + estimated_size = v.size * v.dtype.size + if shard_size + estimated_size > max_file_size_bytes: + shards.append(shard) + shard, shard_size = {}, 0 + shard[k] = v + shard_size += estimated_size + shards.append(shard) + return shards + + +def save_model(save_dir: str, weights, tokenizer, config): + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + shards = make_shards(weights) + for i, shard in enumerate(shards): + # TODO use HF file name scheme for simplicity + mx.save_safetensors(str(save_dir / f"weights.{i:02d}.safetensors"), shard) + tokenizer.save_pretrained(save_dir) + with open(save_dir / "config.json", "w") as fid: + json.dump(config, fid, indent=4)
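Since `save_model` writes the converted weights as one or more `weights.NN.safetensors` shards alongside the tokenizer files and `config.json`, a quick sanity check after conversion is to merge the shards back into a single weight dictionary, the same way `models.load` does earlier in this patch. A minimal sketch, assuming the default `mlx_model` output directory:

```python
# Minimal sanity check (assumes convert.py was run with the default
# --mlx-path mlx_model): merge every safetensors shard written by save_model
# back into one flat {parameter name: array} dictionary, mirroring models.load.
import glob

import mlx.core as mx

weights = {}
for shard in sorted(glob.glob("mlx_model/*.safetensors")):
    weights.update(mx.load(shard))  # each shard maps parameter names to arrays

total_bytes = sum(v.size * v.dtype.size for v in weights.values())
print(f"{len(weights)} arrays, {total_bytes / 2**30:.2f} GiB")
```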