Move lora example to use the same model format / conversion as hf_llm (#252)

* hugging face the lora example to allow more models

* fixes

* comments

* more readme nits

* fusion + works better for qlora

* nits

* comments
Awni Hannun 2024-01-09 11:14:52 -08:00 committed by GitHub
parent bbd7172eef
commit 7b258f33ac
10 changed files with 521 additions and 224 deletions


@ -60,7 +60,7 @@ You can convert (change the data type or quantize) models using the
`convert.py` script. This script takes a Hugging Face repo as input and outputs `convert.py` script. This script takes a Hugging Face repo as input and outputs
a model directory (which you can optionally also upload to Hugging Face). a model directory (which you can optionally also upload to Hugging Face).
For example, to make 4-bit quantized a model, run: For example, to make a 4-bit quantized model, run:
``` ```
python convert.py --hf-path <hf_repo> -q python convert.py --hf-path <hf_repo> -q
@ -73,5 +73,5 @@ python convert.py --help
``` ```
You can upload new models to the [Hugging Face MLX You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`` Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`. to `convert.py`.


@ -39,7 +39,6 @@ def generate(
tic = time.time() tic = time.time()
tokens.append(token.item()) tokens.append(token.item())
# if (n + 1) % 10 == 0:
s = tokenizer.decode(tokens) s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True) print(s[skip:], end="", flush=True)
skip = len(s) skip = len(s)


@ -10,7 +10,6 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -250,9 +249,7 @@ def load(path_or_hf_repo: str):
model.load_weights(list(weights.items())) model.load_weights(list(weights.items()))
mx.eval(model.parameters()) mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(model_path)
model_path,
)
return model, tokenizer return model, tokenizer


@ -1,8 +1,9 @@
# Fine-Tuning with LoRA or QLoRA # Fine-Tuning with LoRA or QLoRA
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a This is an example of using MLX to fine-tune an LLM with low rank adaptation
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target (LoRA) for a target task.[^lora] The example also supports quantized LoRA
task. The example also supports quantized LoRA (QLoRA).[^qlora] (QLoRA).[^qlora] The example works with Llama and Mistral style
models available on Hugging Face.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to generate SQL queries from natural language. However, the example is intended to
@ -11,11 +12,13 @@ be general should you wish to use a custom dataset.
## Contents ## Contents
* [Setup](#Setup) * [Setup](#Setup)
* [Convert](#convert)
* [Run](#Run) * [Run](#Run)
* [Fine-tune](#Fine-tune) * [Fine-tune](#Fine-tune)
* [Evaluate](#Evaluate) * [Evaluate](#Evaluate)
* [Generate](#Generate) * [Generate](#Generate)
* [Results](#Results) * [Results](#Results)
* [Fuse and Upload](#Fuse-and-Upload)
* [Custom Data](#Custom-Data) * [Custom Data](#Custom-Data)
* [Memory Issues](#Memory-Issues) * [Memory Issues](#Memory-Issues)
@ -28,36 +31,49 @@ Install the dependencies:
pip install -r requirements.txt pip install -r requirements.txt
``` ```
Next, download and convert the model. The Mistral weights can be downloaded with: ### Convert
This step is only needed if you want to quantize the model (for QLoRA) or
change its default data type; `lora.py` can otherwise use a compatible Hugging Face model directly.
You convert models using the `convert.py` script. This script takes a Hugging
Face repo as input and outputs a model directory (which you can optionally also
upload to Hugging Face).
To make a 4-bit quantized model, run:
``` ```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar python convert.py --hf-path <hf_repo> -q
tar -xf mistral-7B-v0.1.tar
``` ```
If you do not have access to the Llama weights you will need to [request For example, the following will make a 4-bit quantized Mistral 7B and by default
access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) store it in `mlx_model`:
from Meta.
Convert the model with:
``` ```
python convert.py \ python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q
--torch-path <path_to_torch_model> \
--mlx-path <path_to_mlx_model>
``` ```
If you wish to use QLoRA, then convert the model with 4-bit quantization using For more options run:
the `-q` option.
```
python convert.py --help
```
You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.
## Run ## Run
The main script is `lora.py`. To see a full list of options run The main script is `lora.py`. To see a full list of options run:
``` ```
python lora.py --help python lora.py --help
``` ```
Note: in the following, the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.
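Internally, `lora.py` resolves `--model` with `models.load`, which accepts either form. A minimal sketch (the repo id and local path below are illustrative, and running from the `lora/` directory is assumed):
```
import models

# A local directory produced by convert.py:
model, tokenizer, config = models.load("mlx_model")

# Or a compatible Hugging Face repo, downloaded and cached automatically:
model, tokenizer, config = models.load("mistralai/Mistral-7B-v0.1")
```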
### Fine-tune ### Fine-tune
To fine-tune a model use: To fine-tune a model use:
@ -71,9 +87,6 @@ python lora.py --model <path_to_model> \
If `--model` points to a quantized model, then the training will use QLoRA, If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA. otherwise it will use regular LoRA.
Note, the model path should have the MLX weights, the tokenizer, and the
`config.json` which will all be output by the `convert.py` script.
By default, the adapter weights are saved in `adapters.npz`. You can specify By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`. the output location with `--adapter-file`.
@ -82,7 +95,7 @@ You can resume fine-tuning with an existing adapter with `--resume-adapter-file
### Evaluate ### Evaluate
To compute test set perplexity use To compute test set perplexity use:
``` ```
python lora.py --model <path_to_model> \ python lora.py --model <path_to_model> \
@ -92,7 +105,7 @@ python lora.py --model <path_to_model> \
### Generate ### Generate
For generation use For generation use:
``` ```
python lora.py --model <path_to_model> \ python lora.py --model <path_to_model> \
@ -121,6 +134,37 @@ training and validation loss at a few points over the course of training.
The model trains at around 475 tokens per second on an M2 Ultra. The model trains at around 475 tokens per second on an M2 Ultra.
## Fuse and Upload
You can generate a fused model with the low-rank adapters included using the
`fuse.py` script. This script also optionally allows you to upload the fused
model to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community).
To generate the fused model run:
```
python fuse.py
```
This will by default load the base model from `mlx_model/`, the adapters from
`adapters.npz`, and save the fused model in the path `lora_fused_model/`. All
of these are configurable. You can see the list of options with:
```
python fuse.py --help
```
To upload a fused model, supply the `--upload-name` and `--hf-path` arguments
to `fuse.py`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```
python fuse.py --upload-name My-4-bit-model --hf-path mistralai/Mistral-7B-v0.1
```
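Under the hood, `fuse.py` folds each low-rank update back into its base weight (see `LoRALinear.to_linear` in `models.py`). A minimal sketch of that math with illustrative shapes:
```
import mlx.core as mx

out_dims, in_dims, rank, scale = 4096, 4096, 8, 20.0

weight = mx.random.normal((out_dims, in_dims))  # base linear weight
lora_a = mx.random.normal((in_dims, rank))      # low-rank A matrix
lora_b = mx.zeros((rank, out_dims))             # low-rank B matrix

# LoRA computes y = x @ weight.T + scale * (x @ lora_a) @ lora_b.
# Fusing folds the update into the weight so the extra matmuls disappear:
fused_weight = weight + (scale * lora_b.T) @ lora_a.T
```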
## Custom Data ## Custom Data
You can make your own dataset for fine-tuning with LoRA. You can specify the You can make your own dataset for fine-tuning with LoRA. You can specify the
@ -164,7 +208,7 @@ For example, for a machine with 32 GB the following should run reasonably fast:
``` ```
python lora.py \ python lora.py \
--model <path_to_model> \ --model mistralai/Mistral-7B-v0.1 \
--train \ --train \
--batch-size 1 \ --batch-size 1 \
--lora-layers 4 --lora-layers 4
@ -175,6 +219,4 @@ The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA. [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) [^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL. [^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.


@ -2,33 +2,23 @@
import argparse import argparse
import copy import copy
import json
import shutil
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import utils
import torch from mlx.utils import tree_flatten
from mlx.utils import tree_flatten, tree_map, tree_unflatten from models import Model, ModelArgs
from lora import Model, ModelArgs
def quantize(weights, config, args): def quantize(weights, config, args):
quantized_config = copy.deepcopy(config) quantized_config = copy.deepcopy(config)
# Load the model: # Load the model:
model = Model(ModelArgs(**config)) model = Model(ModelArgs.from_dict(config))
weights = tree_map(mx.array, weights) model.load_weights(list(weights.items()))
model.update(tree_unflatten(list(weights.items())))
# Quantize the model: # Quantize the model:
nn.QuantizedLinear.quantize_module( nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
model,
args.q_group_size,
args.q_bits,
)
# Update the config: # Update the config:
quantized_config["quantization"] = { quantized_config["quantization"] = {
@ -42,19 +32,18 @@ def quantize(weights, config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Convert Mistral or Llama models to MLX.", description="Convert Hugging Face model to MLX format"
) )
parser.add_argument( parser.add_argument(
"--torch-path", "--hf-path",
type=str, type=str,
default="mistral-7B-v0.1/", help="Path to the Hugging Face model.",
help="Path to the torch model directory",
) )
parser.add_argument( parser.add_argument(
"--mlx-path", "--mlx-path",
type=str, type=str,
default="mlx_model/", default="mlx_model",
help="The directory to store the mlx model", help="Path to save the MLX model.",
) )
parser.add_argument( parser.add_argument(
"-q", "-q",
@ -74,50 +63,31 @@ if __name__ == "__main__":
type=int, type=int,
default=4, default=4,
) )
args = parser.parse_args() parser.add_argument(
"--dtype",
args = parser.parse_args() help="Type to save the parameters, ignored if -q is given.",
type=str,
torch_path = Path(args.torch_path) choices=["float16", "bfloat16", "float32"],
mlx_path = Path(args.mlx_path) default="float16",
mlx_path.mkdir(parents=True, exist_ok=True) )
parser.add_argument(
# Copy the tokenizer "--upload-name",
tokenizer_path = torch_path / "tokenizer.model" help="The name of model to upload to Hugging Face MLX Community",
if not tokenizer_path.exists(): type=str,
print(f"Make sure there is a file tokenizer.model in {args.torch_path}") default=None,
exit(0)
shutil.copyfile(
str(tokenizer_path),
str(mlx_path / "tokenizer.model"),
) )
# Load the torch model weights to numpy: args = parser.parse_args()
weights = torch.load(str(torch_path / "consolidated.00.pth"))
for k, v in weights.items():
weights[k] = v.to(torch.float16).numpy()
# Standardize the params print("[INFO] Loading")
with open(torch_path / "params.json", "r") as f: weights, config, tokenizer = utils.fetch_from_hub(args.hf_path)
config = json.loads(f.read())
unused = ["multiple_of", "sliding_window"]
for k in unused:
config.pop(k, None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[0]
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if args.quantize: if args.quantize:
print("[INFO] Quantizing") print("[INFO] Quantizing")
weights, config = quantize(weights, config, args) weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights) utils.save_model(args.mlx_path, weights, tokenizer, config)
if args.upload_name is not None:
with open(mlx_path / "config.json", "w") as outfile: utils.upload_to_hub(args.mlx_path, args.upload_name, args.hf_path)
json.dump(config, outfile, indent=4)

lora/fuse.py (new file, +80 lines)

@ -0,0 +1,80 @@
# Copyright © 2023 Apple Inc.
import argparse
from pathlib import Path
import mlx.core as mx
import models
import utils
from mlx.utils import tree_flatten, tree_unflatten
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Fuse fine-tuned LoRA adapters into the base model.")
parser.add_argument(
"--model",
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--save-path",
default="lora_fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
"--adapter-file",
type=str,
default="adapters.npz",
help="Path to the trained adapter weights (npz or safetensors).",
)
parser.add_argument(
"--hf-path",
help=(
"Path to the original Hugging Face model. This is "
"required for upload if --model is a local directory."
),
type=str,
default=None,
)
parser.add_argument(
"--upload-name",
help="The name of model to upload to Hugging Face MLX Community",
type=str,
default=None,
)
print("Loading pretrained model")
args = parser.parse_args()
model, tokenizer, config = models.load(args.model)
# Load adapters and get number of LoRA layers
adapters = list(mx.load(args.adapter_file).items())
lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]])
# Freeze all layers other than LORA linears
model.freeze()
for l in model.model.layers[-lora_layers:]:
l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
model.update(tree_unflatten(adapters))
fused_linears = [
(n, m.to_linear())
for n, m in model.named_modules()
if isinstance(m, models.LoRALinear)
]
model.update_modules(tree_unflatten(fused_linears))
weights = dict(tree_flatten(model.parameters()))
utils.save_model(args.save_path, weights, tokenizer._tokenizer, config)
if args.upload_name is not None:
hf_path = args.hf_path
if not Path(args.model).exists():
# If the model path doesn't exist, assume it's an HF repo
hf_path = args.model
elif hf_path is None:
raise ValueError(
"Must provide original Hugging Face repo to upload local model."
)
utils.upload_to_hub(args.save_path, args.upload_name, hf_path)
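As a quick sanity check, the fused model can be loaded back with the example's `models` module; a minimal sketch assuming the script's default output path and the `lora/` working directory:
```
import models
from mlx.utils import tree_flatten

# Load the fused model written by fuse.py to its default output directory.
model, tokenizer, config = models.load("lora_fused_model")

# Rough check: report the total parameter count of the merged model.
num_params = sum(v.size for _, v in tree_flatten(model.parameters())) / 1e6
print(f"Loaded fused model with {num_params:.1f}M parameters")
```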


@ -5,15 +5,13 @@ import json
import math import math
import time import time
from pathlib import Path from pathlib import Path
from typing import List
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import models
import numpy as np import numpy as np
from mlx.utils import tree_flatten, tree_unflatten from mlx.utils import tree_flatten, tree_unflatten
from models import LoRALinear, Model, ModelArgs
from sentencepiece import SentencePieceProcessor
def build_parser(): def build_parser():
@ -21,7 +19,7 @@ def build_parser():
parser.add_argument( parser.add_argument(
"--model", "--model",
default="mlx_model", default="mlx_model",
help="A path to the model files containing the tokenizer, weights, config.", help="The path to the local model directory or Hugging Face repo.",
) )
# Generation args # Generation args
parser.add_argument( parser.add_argument(
@ -111,34 +109,6 @@ def build_parser():
return parser return parser
class Tokenizer:
def __init__(self, model_path: str):
assert Path(model_path).exists(), model_path
self._model = SentencePieceProcessor(model_file=model_path)
self._sep = ""
assert self._model.vocab_size() == self._model.get_piece_size()
def encode(self, s: str, eos: bool = False) -> List[int]:
toks = [self._model.bos_id(), *self._model.encode(s)]
if eos:
toks.append(self.eos_id)
return toks
@property
def eos_id(self) -> int:
return self._model.eos_id()
def decode(self, t: List[int]) -> str:
out = self._model.decode(t)
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
@property
def vocab_size(self) -> int:
return self._model.vocab_size()
class Dataset: class Dataset:
""" """
Light-weight wrapper to hold lines from a jsonl file Light-weight wrapper to hold lines from a jsonl file
@ -295,56 +265,27 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
def generate(model, prompt, tokenizer, args): def generate(model, prompt, tokenizer, args):
print(args.prompt, end="", flush=True) print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
def generate_step(): prompt = tokenizer.encode(args.prompt)
temp = args.temp
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt[None])
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache)
y = sample(logits.squeeze(1))
yield y
tokens = [] tokens = []
for token, _ in zip(generate_step(), range(args.num_tokens)): skip = 0
tokens.append(token) for token, n in zip(
models.generate(prompt, model, args.temp),
range(args.max_tokens),
):
if token.item() == tokenizer.eos_id:
break
if (len(tokens) % 10) == 0: tokens.append(token.item())
mx.eval(tokens) s = tokenizer.decode(tokens)
s = tokenizer.decode([t.item() for t in tokens]) print(s[skip:], end="", flush=True)
print(s, end="", flush=True) skip = len(s)
tokens = [] print(tokenizer.decode(tokens)[skip:], flush=True)
print("=" * 10)
mx.eval(tokens) if len(tokens) == 0:
s = tokenizer.decode([t.item() for t in tokens]) print("No tokens generated for this prompt")
print(s, flush=True) return
def load_model(folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
model.update(weights)
return model, tokenizer
if __name__ == "__main__": if __name__ == "__main__":
@ -354,13 +295,13 @@ if __name__ == "__main__":
np.random.seed(args.seed) np.random.seed(args.seed)
print("Loading pretrained model") print("Loading pretrained model")
model, tokenizer = load_model(args.model) model, tokenizer, _ = models.load(args.model)
# Freeze all layers other than LORA linears # Freeze all layers other than LORA linears
model.freeze() model.freeze()
for l in model.layers[-args.lora_layers :]: for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.attention.wq = LoRALinear.from_linear(l.attention.wq) l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
l.attention.wv = LoRALinear.from_linear(l.attention.wv) l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M") print(f"Total parameters {p:.3f}M")


@ -1,23 +1,81 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import glob
import inspect
import json
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten import numpy as np
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
@dataclass @dataclass
class ModelArgs: class ModelArgs:
dim: int hidden_size: int
n_layers: int num_hidden_layers: int
head_dim: int intermediate_size: int
hidden_dim: int num_attention_heads: int
n_heads: int rms_norm_eps: float
n_kv_heads: int
norm_eps: float
vocab_size: int vocab_size: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class Tokenizer:
def __init__(self, model_path: str):
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
self._eos = self._tokenizer.eos_token_id
self._bos = self._tokenizer.bos_token_id
def encode(self, s: str, eos: bool = False) -> mx.array:
toks = self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
if eos:
toks = np.concatenate([toks, [self._eos]])
return mx.array(toks)
@property
def eos_id(self) -> int:
return self._eos
def decode(self, t: List[int]) -> str:
return self._tokenizer.decode(t)
class LoRALinear(nn.Module): class LoRALinear(nn.Module):
@ -32,14 +90,58 @@ class LoRALinear(nn.Module):
lora_lin.linear = linear lora_lin.linear = linear
return lora_lin return lora_lin
def to_linear(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
is_quantized = isinstance(linear, nn.QuantizedLinear)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = mx.float16
weight = mx.dequantize(
weight,
linear.scales,
linear.biases,
linear.group_size,
linear.bits,
)
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
if is_quantized:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__( def __init__(
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False self,
input_dims: int,
output_dims: int,
lora_rank: int = 8,
bias: bool = False,
scale: float = 20.0,
): ):
super().__init__() super().__init__()
# Regular linear layer weights # Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias) self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights # Low rank lora weights
scale = 1 / math.sqrt(input_dims) scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform( self.lora_a = mx.random.uniform(
@ -55,7 +157,7 @@ class LoRALinear(nn.Module):
dtype = self.linear.scales.dtype dtype = self.linear.scales.dtype
y = self.linear(x.astype(dtype)) y = self.linear(x.astype(dtype))
z = (x @ self.lora_a) @ self.lora_b z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z return y + self.scale * z
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
@ -75,20 +177,31 @@ class RMSNorm(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.args = args
self.n_heads: int = args.n_heads dim = args.hidden_size
self.n_kv_heads: int = args.n_kv_heads self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = self.n_heads // self.n_kv_heads self.repeats = n_heads // n_kv_heads
self.scale = self.args.head_dim**-0.5 head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True) rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__( def __call__(
self, self,
@ -98,7 +211,7 @@ class Attention(nn.Module):
) -> mx.array: ) -> mx.array:
B, L, D = x.shape B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x) queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation # Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
@ -127,30 +240,29 @@ class Attention(nn.Module):
scores += mask scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values) return self.o_proj(output), (keys, values)
class FeedForward(nn.Module): class MLP(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, dim, hidden_dim):
super().__init__() super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
def __call__(self, x) -> mx.array: def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x)) return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.n_heads = args.n_heads self.num_attention_heads = args.num_attention_heads
self.dim = args.dim self.hidden_size = args.hidden_size
self.attention = Attention(args) self.self_attn = Attention(args)
self.feed_forward = FeedForward(args=args) self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args self.args = args
def __call__( def __call__(
@ -159,31 +271,32 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None, mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None, cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array: ) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache) r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r h = x + r
r = self.feed_forward(self.ffn_norm(h)) r = self.mlp(self.post_attention_layernorm(h))
out = h + r out = h + r
return out, cache return out, cache
class Model(nn.Module): class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.args = args self.args = args
self.vocab_size = args.vocab_size self.vocab_size = args.vocab_size
self.n_layers = args.n_layers self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0 assert self.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] self.layers = [
self.norm = RMSNorm(args.dim, eps=args.norm_eps) TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False) ]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
cache=None, cache=None,
): ):
h = self.tok_embeddings(inputs) h = self.embed_tokens(inputs)
mask = None mask = None
if h.shape[1] > 1: if h.shape[1] > 1:
@ -196,4 +309,70 @@ class Model(nn.Module):
for e, layer in enumerate(self.layers): for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e]) h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
def load(path_or_hf_repo: str):
# If the path exists, load the model from it;
# otherwise download it from the Hugging Face repo and cache it locally.
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
return model, Tokenizer(model_path), config
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y
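For reference, a minimal sketch of driving the `generate` generator above together with the `Tokenizer` wrapper (the prompt and the `mlx_model` path are illustrative):
```
import models

model, tokenizer, _ = models.load("mlx_model")
prompt = tokenizer.encode("Q: What is the capital of France?\nA:")

tokens = []
for token, _ in zip(models.generate(prompt, model, temp=0.0), range(64)):
    if token.item() == tokenizer.eos_id:
        break
    tokens.append(token.item())
print(tokenizer.decode(tokens))
```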


@ -1,4 +1,3 @@
mlx>=0.0.7 mlx>=0.0.7
sentencepiece transformers
torch numpy
numpy

lora/utils.py (new file, +90 lines)

@ -0,0 +1,90 @@
# Copyright © 2023 Apple Inc.
import glob
import json
from pathlib import Path
import mlx.core as mx
import transformers
from huggingface_hub import snapshot_download
def fetch_from_hub(hf_path: str):
model_path = snapshot_download(
repo_id=hf_path,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
weight_files = glob.glob(f"{model_path}/*.safetensors")
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
config = transformers.AutoConfig.from_pretrained(hf_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
hf_path,
)
return weights, config.to_dict(), tokenizer
def upload_to_hub(path: str, name: str, hf_path: str):
import os
from huggingface_hub import HfApi, ModelCard, logging
repo_id = f"mlx-community/{name}"
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {name}
This model was converted to MLX format from [`{hf_path}`](https://huggingface.co/{hf_path}).
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/llms/hf_llm
python generate.py --model {repo_id} --prompt "My name is"
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=repo_id,
repo_type="model",
)
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
max_file_size_bytes = max_file_size_gibibyte << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards
def save_model(save_dir: str, weights, tokenizer, config):
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
# TODO use HF file name scheme for simplicity
mx.save_safetensors(str(save_dir / f"weights.{i:02d}.safetensors"), shard)
tokenizer.save_pretrained(save_dir)
with open(save_dir / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
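For reference, a minimal sketch of the conversion flow these helpers implement, mirroring `convert.py` without quantization (the repo id and upload name are illustrative, and the upload step requires Hugging Face authentication):
```
import utils

# Download the weights, config, and tokenizer, then write an MLX model directory.
weights, config, tokenizer = utils.fetch_from_hub("mistralai/Mistral-7B-v0.1")
utils.save_model("mlx_model", weights, tokenizer, config)

# convert.py additionally casts the weights to --dtype and can quantize with -q.
# Optionally publish the converted model to the MLX community organization:
utils.upload_to_hub("mlx_model", "Mistral-7B-v0.1-mlx", "mistralai/Mistral-7B-v0.1")
```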