Move lora example to use the same model format / conversion as hf_llm (#252)

* Hugging Face the lora example to allow more models

* fixes

* comments

* more readme nits

* fusion + works better for qlora

* nits

* comments
Awni Hannun 2024-01-09 11:14:52 -08:00 committed by GitHub
parent bbd7172eef
commit 7b258f33ac
10 changed files with 521 additions and 224 deletions

View File

@ -60,7 +60,7 @@ You can convert (change the data type or quantize) models using the
`convert.py` script. This script takes a Hugging Face repo as input and outputs
a model directory (which you can optionally also upload to Hugging Face).
For example, to make 4-bit quantized a model, run:
For example, to make a 4-bit quantized model, run:
```
python convert.py --hf-path <hf_repo> -q
@ -73,5 +73,5 @@ python convert.py --help
```
You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name``
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.

View File

@ -39,7 +39,6 @@ def generate(
tic = time.time()
tokens.append(token.item())
# if (n + 1) % 10 == 0:
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)

View File

@ -10,7 +10,6 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@ -250,9 +249,7 @@ def load(path_or_hf_repo: str):
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained(
model_path,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer

View File

@ -1,8 +1,9 @@
# Fine-Tuning with LoRA or QLoRA
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
task. The example also supports quantized LoRA (QLoRA).[^qlora]
This is an example of using MLX to fine-tune an LLM with low rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style
models available on Hugging Face.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
@ -11,11 +12,13 @@ be general should you wish to use a custom dataset.
## Contents
* [Setup](#Setup)
* [Convert](#convert)
* [Run](#Run)
* [Fine-tune](#Fine-tune)
* [Evaluate](#Evaluate)
* [Generate](#Generate)
* [Results](#Results)
* [Fuse and Upload](#Fuse-and-Upload)
* [Custom Data](#Custom-Data)
* [Memory Issues](#Memory-Issues)
@ -28,36 +31,49 @@ Install the dependencies:
pip install -r requirements.txt
```
Next, download and convert the model. The Mistral weights can be downloaded with:
### Convert
This step is optional. It is only needed if you want to quantize the model (for
QLoRA) or change its default data type.
You convert models using the `convert.py` script. This script takes a Hugging
Face repo as input and outputs a model directory (which you can optionally also
upload to Hugging Face).
To make a 4-bit quantized model, run:
```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar
python convert.py --hf-path <hf_repo> -q
```
If you do not have access to the Llama weights you will need to [request
access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
from Meta.
Convert the model with:
For example, the following will make a 4-bit quantized Mistral 7B and by default
store it in `mlx_model`:
```
python convert.py \
--torch-path <path_to_torch_model> \
--mlx-path <path_to_mlx_model>
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q
```
If you wish to use QLoRA, then convert the model with 4-bit quantization using
the `-q` option.
For more options run:
```
python convert.py --help
```
You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.
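For example, to convert, quantize, and upload a model in one invocation (the
upload name below is just illustrative), you can run something like:
```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q --upload-name Mistral-7B-v0.1-4bit-mlx
```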
## Run
The main script is `lora.py`. To see a full list of options run
The main script is `lora.py`. To see a full list of options run:
```
python lora.py --help
```
Note: in the following, the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.
### Fine-tune
To fine-tune a model use:
@ -71,9 +87,6 @@ python lora.py --model <path_to_model> \
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
Note, the model path should have the MLX weights, the tokenizer, and the
`config.json` which will all be output by the `convert.py` script.
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter-file`.
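For instance, a fine-tuning run that writes the adapters to a custom file might
look like this (the model path and file name are placeholders):
```
python lora.py --model <path_to_model> \
               --train \
               --adapter-file my_adapters.npz
```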
@ -82,7 +95,7 @@ You can resume fine-tuning with an existing adapter with `--resume-adapter-file
### Evaluate
To compute test set perplexity use
To compute test set perplexity use:
```
python lora.py --model <path_to_model> \
@ -92,7 +105,7 @@ python lora.py --model <path_to_model> \
### Generate
For generation use
For generation use:
```
python lora.py --model <path_to_model> \
@ -121,6 +134,37 @@ training and validation loss at a few points over the course of training.
The model trains at around 475 tokens per second on an M2 Ultra.
## Fuse and Upload
You can generate a fused model with the low-rank adapters included using the
`fuse.py` script. This script also optionally allows you to upload the fused
model to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community).
To generate the fused model run:
```
python fuse.py
```
This will by default load the base model from `mlx_model/`, the adapters from
`adapters.npz`, and save the fused model in the path `lora_fused_model/`. All
of these are configurable. You can see the list of options with:
```
python fuse.py --help
```
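As an explicit sketch, the following is equivalent to running `python fuse.py`
with its defaults:
```
python fuse.py --model mlx_model \
               --adapter-file adapters.npz \
               --save-path lora_fused_model
```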
To upload a fused model, supply the `--upload-name` and `--hf-path` arguments
to `fuse.py`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```
python fuse.py --upload-name My-4-bit-model --hf-path mistralai/Mistral-7B-v0.1
```
## Custom Data
You can make your own dataset for fine-tuning with LoRA. You can specify the
@ -164,7 +208,7 @@ For example, for a machine with 32 GB the following should run reasonably fast:
```
python lora.py \
--model <path_to_model> \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--lora-layers 4
@ -175,6 +219,4 @@ The above command on an M1 Max with 32 GB runs at about 250 tokens-per-second.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.

View File

@ -2,33 +2,23 @@
import argparse
import copy
import json
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from lora import Model, ModelArgs
import utils
from mlx.utils import tree_flatten
from models import Model, ModelArgs
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = Model(ModelArgs(**config))
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
model = Model(ModelArgs.from_dict(config))
model.load_weights(list(weights.items()))
# Quantize the model:
nn.QuantizedLinear.quantize_module(
model,
args.q_group_size,
args.q_bits,
)
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
@ -42,19 +32,18 @@ def quantize(weights, config, args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert Mistral or Llama models to MLX.",
description="Convert Hugging Face model to MLX format"
)
parser.add_argument(
"--torch-path",
"--hf-path",
type=str,
default="mistral-7B-v0.1/",
help="Path to the torch model directory",
help="Path to the Hugging Face model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model/",
help="The directory to store the mlx model",
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"-q",
@ -74,50 +63,31 @@ if __name__ == "__main__":
type=int,
default=4,
)
args = parser.parse_args()
args = parser.parse_args()
torch_path = Path(args.torch_path)
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
# Copy the tokenizer
tokenizer_path = torch_path / "tokenizer.model"
if not tokenizer_path.exists():
print(f"Make sure there is a file tokenizer.model in {args.torch_path}")
exit(0)
shutil.copyfile(
str(tokenizer_path),
str(mlx_path / "tokenizer.model"),
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-name",
help="The name of model to upload to Hugging Face MLX Community",
type=str,
default=None,
)
# Load the torch model weights to numpy:
weights = torch.load(str(torch_path / "consolidated.00.pth"))
for k, v in weights.items():
weights[k] = v.to(torch.float16).numpy()
args = parser.parse_args()
# Standardize the params
with open(torch_path / "params.json", "r") as f:
config = json.loads(f.read())
unused = ["multiple_of", "sliding_window"]
for k in unused:
config.pop(k, None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[0]
print("[INFO] Loading")
weights, config, tokenizer = utils.fetch_from_hub(args.hf_path)
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
np.savez(str(mlx_path / "weights.npz"), **weights)
with open(mlx_path / "config.json", "w") as outfile:
json.dump(config, outfile, indent=4)
utils.save_model(args.mlx_path, weights, tokenizer, config)
if args.upload_name is not None:
utils.upload_to_hub(args.mlx_path, args.upload_name, args.hf_path)

lora/fuse.py (new file, 80 lines)
View File

@ -0,0 +1,80 @@
# Copyright © 2023 Apple Inc.
import argparse
from pathlib import Path
import mlx.core as mx
import models
import utils
from mlx.utils import tree_flatten, tree_unflatten
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--save-path",
default="lora_fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
"--adapter-file",
type=str,
default="adapters.npz",
help="Path to the trained adapter weights (npz or safetensors).",
)
parser.add_argument(
"--hf-path",
help=(
"Path to the original Hugging Face model. This is "
"required for upload if --model is a local directory."
),
type=str,
default=None,
)
parser.add_argument(
"--upload-name",
help="The name of model to upload to Hugging Face MLX Community",
type=str,
default=None,
)
print("Loading pretrained model")
args = parser.parse_args()
model, tokenizer, config = models.load(args.model)
# Load adapters and get number of LoRA layers
adapters = list(mx.load(args.adapter_file).items())
lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]])
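# Each adapted layer contributes exactly one q_proj.lora_a entry, so this count recovers the number of layers fine-tuned with LoRA.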
# Freeze all layers other than LORA linears
model.freeze()
for l in model.model.layers[-lora_layers:]:
l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
model.update(tree_unflatten(adapters))
fused_linears = [
(n, m.to_linear())
for n, m in model.named_modules()
if isinstance(m, models.LoRALinear)
]
model.update_modules(tree_unflatten(fused_linears))
weights = dict(tree_flatten(model.parameters()))
utils.save_model(args.save_path, weights, tokenizer._tokenizer, config)
if args.upload_name is not None:
hf_path = args.hf_path
if not Path(args.model).exists():
# If the model path doesn't exist, assume it's an HF repo
hf_path = args.model
elif hf_path is None:
raise ValueError(
"Must provide original Hugging Face repo to upload local model."
)
utils.upload_to_hub(args.save_path, args.upload_name, hf_path)

View File

@ -5,15 +5,13 @@ import json
import math
import time
from pathlib import Path
from typing import List
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import models
import numpy as np
from mlx.utils import tree_flatten, tree_unflatten
from models import LoRALinear, Model, ModelArgs
from sentencepiece import SentencePieceProcessor
def build_parser():
@ -21,7 +19,7 @@ def build_parser():
parser.add_argument(
"--model",
default="mlx_model",
help="A path to the model files containing the tokenizer, weights, config.",
help="The path to the local model directory or Hugging Face repo.",
)
# Generation args
parser.add_argument(
@ -111,34 +109,6 @@ def build_parser():
return parser
class Tokenizer:
def __init__(self, model_path: str):
assert Path(model_path).exists(), model_path
self._model = SentencePieceProcessor(model_file=model_path)
self._sep = ""
assert self._model.vocab_size() == self._model.get_piece_size()
def encode(self, s: str, eos: bool = False) -> List[int]:
toks = [self._model.bos_id(), *self._model.encode(s)]
if eos:
toks.append(self.eos_id)
return toks
@property
def eos_id(self) -> int:
return self._model.eos_id()
def decode(self, t: List[int]) -> str:
out = self._model.decode(t)
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
@property
def vocab_size(self) -> int:
return self._model.vocab_size()
class Dataset:
"""
Light-weight wrapper to hold lines from a jsonl file
@ -295,56 +265,27 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
def generate(model, prompt, tokenizer, args):
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
def generate_step():
temp = args.temp
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt[None])
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache)
y = sample(logits.squeeze(1))
yield y
prompt = tokenizer.encode(args.prompt)
tokens = []
for token, _ in zip(generate_step(), range(args.num_tokens)):
tokens.append(token)
skip = 0
for token, n in zip(
models.generate(prompt, model, args.temp),
range(args.max_tokens),
):
if token == tokenizer.eos_token_id:
break
if (len(tokens) % 10) == 0:
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
def load_model(folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
model.update(weights)
return model, tokenizer
tokens.append(token.item())
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)
print("=" * 10)
if len(tokens) == 0:
print("No tokens generated for this prompt")
return
if __name__ == "__main__":
@ -354,13 +295,13 @@ if __name__ == "__main__":
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load_model(args.model)
model, tokenizer, _ = models.load(args.model)
# Freeze all layers other than LORA linears
model.freeze()
for l in model.layers[-args.lora_layers :]:
l.attention.wq = LoRALinear.from_linear(l.attention.wq)
l.attention.wv = LoRALinear.from_linear(l.attention.wv)
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
l.self_attn.q_proj = models.LoRALinear.from_linear(l.self_attn.q_proj)
l.self_attn.v_proj = models.LoRALinear.from_linear(l.self_attn.v_proj)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")

View File

@ -1,23 +1,81 @@
# Copyright © 2023 Apple Inc.
import glob
import inspect
import json
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
import numpy as np
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
dim: int
n_layers: int
head_dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class Tokenizer:
def __init__(self, model_path: str):
self._tokenizer = AutoTokenizer.from_pretrained(model_path)
self._eos = self._tokenizer.eos_token_id
self._bos = self._tokenizer.bos_token_id
def encode(self, s: str, eos: bool = False) -> mx.array:
toks = self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
if eos:
toks = np.concatenate([toks, [self._eos]])
return mx.array(toks)
@property
def eos_id(self) -> int:
return self._eos
def decode(self, t: List[int]) -> str:
return self._tokenizer.decode(t)
class LoRALinear(nn.Module):
@ -32,14 +90,58 @@ class LoRALinear(nn.Module):
lora_lin.linear = linear
return lora_lin
def to_linear(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
is_quantized = isinstance(linear, nn.QuantizedLinear)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = mx.float16
weight = mx.dequantize(
weight,
linear.scales,
linear.biases,
linear.group_size,
linear.bits,
)
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
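# Fold the low-rank update into the dense weight: W_fused = W + scale * (lora_b.T @ lora_a.T).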
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
if is_quantized:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
self,
input_dims: int,
output_dims: int,
lora_rank: int = 8,
bias: bool = False,
scale: float = 20.0,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
@ -55,7 +157,7 @@ class LoRALinear(nn.Module):
dtype = self.linear.scales.dtype
y = self.linear(x.astype(dtype))
z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z
return y + self.scale * z
class RMSNorm(nn.Module):
@ -75,20 +177,31 @@ class RMSNorm(nn.Module):
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = self.n_heads // self.n_kv_heads
self.repeats = n_heads // n_kv_heads
self.scale = self.args.head_dim**-0.5
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
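# Linear RoPE scaling compresses positions by the scaling factor, which corresponds to scale = 1 / factor in nn.RoPE.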
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
@ -98,7 +211,7 @@ class Attention(nn.Module):
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
@ -127,30 +240,29 @@ class Attention(nn.Module):
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
return self.o_proj(output), (keys, values)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.attention = Attention(args)
self.feed_forward = FeedForward(args=args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
@ -159,31 +271,32 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
class Model(nn.Module):
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.tok_embeddings(inputs)
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
@ -196,4 +309,70 @@ class Model(nn.Module):
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache
return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
def load(path_or_hf_repo: str):
# If the path exists, try to load the model from it,
# otherwise download it from the Hugging Face repo and cache it.
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
return model, Tokenizer(model_path), config
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
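# A temperature of 0 gives deterministic greedy decoding; larger values sample from a temperature-scaled categorical distribution.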
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y

View File

@ -1,4 +1,3 @@
mlx>=0.0.7
sentencepiece
torch
transformers
numpy

lora/utils.py (new file, 90 lines)
View File

@ -0,0 +1,90 @@
# Copyright © 2023 Apple Inc.
import glob
import json
from pathlib import Path
import mlx.core as mx
import transformers
from huggingface_hub import snapshot_download
def fetch_from_hub(hf_path: str):
model_path = snapshot_download(
repo_id=hf_path,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
weight_files = glob.glob(f"{model_path}/*.safetensors")
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
config = transformers.AutoConfig.from_pretrained(hf_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
hf_path,
)
return weights, config.to_dict(), tokenizer
def upload_to_hub(path: str, name: str, hf_path: str):
import os
from huggingface_hub import HfApi, ModelCard, logging
repo_id = f"mlx-community/{name}"
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {name}
This model was converted to MLX format from [`{hf_path}`](https://huggingface.co/{hf_path}).
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/llms/hf_llm
python generate.py --model {repo_id} --prompt "My name is"
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=repo_id,
repo_type="model",
)
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
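# Greedily pack weights into shards so each safetensors file stays under the size limit.
# The left shift converts GiB to bytes (1 GiB = 2**30 bytes).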
max_file_size_bytes = max_file_size_gibibyte << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards
def save_model(save_dir: str, weights, tokenizer, config):
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
# TODO use HF file name scheme for simplicity
mx.save_safetensors(str(save_dir / f"weights.{i:02d}.safetensors"), shard)
tokenizer.save_pretrained(save_dir)
with open(save_dir / "config.json", "w") as fid:
json.dump(config, fid, indent=4)