mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Merge pull request #53 from ml-explore/mistral_lora
Generalize lora finetuning to Mistral
This commit is contained in:
commit
3a3ea3cfb0
@ -32,7 +32,12 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--bert-model",
|
||||
choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
|
||||
choices=[
|
||||
"bert-base-uncased",
|
||||
"bert-base-cased",
|
||||
"bert-large-uncased",
|
||||
"bert-large-cased",
|
||||
],
|
||||
default="bert-base-uncased",
|
||||
help="The huggingface name of the BERT model to save.",
|
||||
)
|
||||
@ -44,4 +49,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert(args.bert_model, args.mlx_model)
|
||||
convert(args.bert_model, args.mlx_model)
|
||||
|
@ -24,10 +24,17 @@ def run(bert_model: str):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run the BERT model using HuggingFace Transformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert-model",
|
||||
choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
|
||||
choices=[
|
||||
"bert-base-uncased",
|
||||
"bert-base-cased",
|
||||
"bert-large-uncased",
|
||||
"bert-large-cased",
|
||||
],
|
||||
default="bert-base-uncased",
|
||||
help="The huggingface name of the BERT model to save.",
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from transformers import BertTokenizer
|
||||
@ -214,7 +215,7 @@ def run(bert_model: str, mlx_model: str):
|
||||
"A second string",
|
||||
"This is another string.",
|
||||
]
|
||||
|
||||
|
||||
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
||||
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
# LoRA
|
||||
|
||||
This is an example of using MLX to fine-tune a Llama 7B[^llama] model with low
|
||||
rank adaptation (LoRA)[^lora] for a target task.
|
||||
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
|
||||
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
|
||||
task.
|
||||
|
||||
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
|
||||
generate SQL queries from natural language. However, the example is intended to
|
||||
@ -15,19 +16,27 @@ Install the dependencies:
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Next, download and convert the model. If you do not have access to the model
|
||||
weights you will need to [request
|
||||
Next, download and convert the model. The Mistral weights can be downloaded with:
|
||||
|
||||
```
|
||||
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
|
||||
tar -xf mistral-7B-v0.1.tar
|
||||
```
|
||||
|
||||
If you do not have access to the Llama weights you will need to [request
|
||||
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
||||
from Meta.
|
||||
|
||||
Convert the weights with:
|
||||
Convert the model with:
|
||||
|
||||
```
|
||||
python convert.py <path_to_torch_weights> mlx_llama_7B.npz
|
||||
python convert.py <path_to_torch_model> <path_to_mlx_model>
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
#### Fine-tune
|
||||
|
||||
The main script is `lora.py`. To see a full list of options run
|
||||
|
||||
```
|
||||
@ -37,28 +46,34 @@ python lora.py --help
|
||||
To fine-tune a model use:
|
||||
|
||||
```
|
||||
python lora.py --model mlx_llama_7B.npz \
|
||||
--tokenizer tokenizer.model \
|
||||
python lora.py --model <path_to_model> \
|
||||
--train \
|
||||
--iters 600 \
|
||||
--iters 600
|
||||
```
|
||||
|
||||
Note, the model path should have the MLX weights, the tokenizer, and the
|
||||
`params.json` configuration which will all be output by the `convert.py` script.
|
||||
|
||||
By default, the adapter weights are saved in `adapters.npz`. You can specify
|
||||
the output location with `--adapter_file`.
|
||||
|
||||
#### Evaluate
|
||||
|
||||
To compute test set perplexity use
|
||||
|
||||
```
|
||||
python lora.py --model mlx_llama_7B.npz \
|
||||
--tokenizer tokenizer.model \
|
||||
python lora.py --model <path_to_model> \
|
||||
--adapter_file <path_to_adapters.npz> \
|
||||
--test
|
||||
```
|
||||
|
||||
#### Generate
|
||||
|
||||
For generation use
|
||||
|
||||
```
|
||||
python lora.py --model mlx_llama_7B.npz \
|
||||
--tokenizer tokenizer.model \
|
||||
python lora.py --model <path_to_model> \
|
||||
--adapter_file <path_to_adapters.npz> \
|
||||
--num-tokens 50 \
|
||||
--prompt "table: 1-10015132-16
|
||||
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
|
||||
@ -81,10 +96,9 @@ training and validation loss at a few points over the course of training.
|
||||
| 800 | 1.017 | 1.255 |
|
||||
| 1000 | 1.070 | 1.230 |
|
||||
|
||||
After training for 1000 iterations, the validation perplexity reduces to XX.
|
||||
|
||||
The model trains at around 475 tokens per second on an M2 Ultra.
|
||||
|
||||
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
||||
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
|
||||
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
|
||||
|
@ -1,53 +1,59 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from itertools import starmap
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
def map_torch_to_mlx(key, value):
|
||||
if "tok_embedding" in key:
|
||||
key = "embedding.weight"
|
||||
|
||||
elif "norm" in key:
|
||||
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
|
||||
|
||||
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
|
||||
key = key.replace("wq", "query_proj")
|
||||
key = key.replace("wk", "key_proj")
|
||||
key = key.replace("wv", "value_proj")
|
||||
key = key.replace("wo", "out_proj")
|
||||
|
||||
elif "w1" in key or "w2" in key or "w3" in key:
|
||||
# The FFN is a separate submodule in PyTorch
|
||||
key = key.replace("feed_forward.w1", "linear1")
|
||||
key = key.replace("feed_forward.w3", "linear2")
|
||||
key = key.replace("feed_forward.w2", "linear3")
|
||||
|
||||
elif "output" in key:
|
||||
key = key.replace("output", "out_proj")
|
||||
|
||||
elif "rope" in key:
|
||||
return None, None
|
||||
|
||||
return (
|
||||
key,
|
||||
value.numpy()
|
||||
if value.dtype != torch.bfloat16
|
||||
else value.to(torch.float32).numpy(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument("torch_weights")
|
||||
parser.add_argument("output_file")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Mistral or Llama models to MLX.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--torch_model",
|
||||
type=str,
|
||||
default="mistral-7B-v0.1/",
|
||||
help="The torch model directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx_model",
|
||||
type=str,
|
||||
default="mlx-mistral-7B-v0.1/",
|
||||
help="The directory to store the mlx model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
state = torch.load(args.torch_weights)
|
||||
torch_path = Path(args.torch_model)
|
||||
if not os.path.exists(args.mlx_model):
|
||||
os.makedirs(args.mlx_model)
|
||||
mlx_path = Path(args.mlx_model)
|
||||
|
||||
state = torch.load(str(torch_path / "consolidated.00.pth"))
|
||||
np.savez(
|
||||
args.output_file,
|
||||
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
|
||||
str(mlx_path / "weights.npz"),
|
||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
|
||||
)
|
||||
|
||||
# Copy the tokenizer
|
||||
shutil.copyfile(
|
||||
str(torch_path / "tokenizer.model"),
|
||||
str(mlx_path / "tokenizer.model"),
|
||||
)
|
||||
|
||||
# Copy the params
|
||||
with open(torch_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
if "sliding_window" in config:
|
||||
config.pop("sliding_window")
|
||||
if "n_kv_heads" not in config:
|
||||
config["n_kv_heads"] = n_heads
|
||||
if "head_dim" not in config:
|
||||
config["head_dim"] = config["dim"] // n_heads
|
||||
if "hidden_dim" not in config:
|
||||
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
|
||||
with open(mlx_path / "params.json", "w") as outfile:
|
||||
json.dump(config, outfile)
|
||||
|
199
lora/llama.py
199
lora/llama.py
@ -1,199 +0,0 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_linear(linear: nn.Linear, rank: int = 8):
|
||||
input_dims, output_dims = linear.weight.shape
|
||||
lora_lin = LoRALinear(input_dims, output_dims, rank)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def __init__(
|
||||
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, lora_rank),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
|
||||
|
||||
def __call__(self, x):
|
||||
y = self.linear(x)
|
||||
z = (x @ self.lora_a) @ self.lora_b
|
||||
return y + 2.0 * z
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(dims // num_heads, traditional=True)
|
||||
|
||||
self.query_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.key_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.value_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.out_proj = nn.Linear(dims, dims, bias=False)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
# Note that we return the keys and values to possibly be used as a cache
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class LlamaEncoderLayer(nn.Module):
|
||||
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.attention = LlamaAttention(dims, num_heads)
|
||||
|
||||
self.norm1 = nn.RMSNorm(dims)
|
||||
self.norm2 = nn.RMSNorm(dims)
|
||||
|
||||
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
y = self.norm1(x)
|
||||
y, cache = self.attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.norm2(x)
|
||||
a = self.linear1(y)
|
||||
b = self.linear2(y)
|
||||
y = a * mx.sigmoid(a) * b
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(
|
||||
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, dims)
|
||||
self.layers = [
|
||||
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(dims)
|
||||
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
|
||||
x = self.embedding(x)
|
||||
|
||||
for l in self.layers:
|
||||
x, _ = l(x, mask)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return self.out_proj(x)
|
||||
|
||||
def generate(self, x, temp=1.0):
|
||||
cache = []
|
||||
try:
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.embedding.weight.dtype)
|
||||
|
||||
# First we process the prompt x the same was as in __call__ but
|
||||
# save the caches in cache
|
||||
x = self.embedding(x)
|
||||
for l in self.layers:
|
||||
x, c = l(x, mask=mask)
|
||||
# We store the per layer cache in a simple python list
|
||||
cache.append(c)
|
||||
x = self.norm(x)
|
||||
# We only care about the last logits that generate the next token
|
||||
y = self.out_proj(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
|
||||
# y now has size [1]
|
||||
yield y
|
||||
|
||||
# Now we parsed the prompt and generated the first token we
|
||||
# need to feed it back into the model and loop to generate the
|
||||
# rest.
|
||||
while True:
|
||||
# Unsqueezing the last dimension to add a sequence length
|
||||
# dimension of 1
|
||||
x = y[:, None]
|
||||
|
||||
x = self.embedding(x)
|
||||
for i in range(len(cache)):
|
||||
# We are overwriting the arrays in the cache list. When
|
||||
# the computation will happen, MLX will be discarding the
|
||||
# old cache the moment it is not needed anymore.
|
||||
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
|
||||
x = self.norm(x)
|
||||
y = self.out_proj(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
|
||||
yield y
|
||||
|
||||
finally:
|
||||
del cache
|
||||
|
||||
|
||||
def load_model(model_path):
|
||||
weights = mx.load(model_path)
|
||||
mlp_dims, dims = weights["layers.0.linear1.weight"].shape
|
||||
num_heads = dims // 128
|
||||
num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
|
||||
vocab_size = weights["out_proj.weight"].shape[-1]
|
||||
model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
mx.eval(model.parameters())
|
||||
return model
|
128
lora/lora.py
128
lora/lora.py
@ -1,28 +1,32 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
import time
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
from mlx.utils import tree_flatten
|
||||
from mlx.utils import tree_map, tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
from llama import LoRALinear, load_model
|
||||
from models import ModelArgs, Model, LoRALinear
|
||||
import wikisql
|
||||
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
|
||||
parser.add_argument(
|
||||
"--model", required=True, help="The model file containing MLX weights"
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LoRA finetuning with Llama or Mistral"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer", required=True, help="The sentencepiece tokenizer"
|
||||
"--model",
|
||||
required=True,
|
||||
help="A path to the model files containing the tokenizer, weights, config.",
|
||||
)
|
||||
# Generation args
|
||||
parser.add_argument(
|
||||
@ -73,6 +77,12 @@ def build_parser():
|
||||
default=200,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_adapter_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Load path to resume training with the given adapter weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter_file",
|
||||
type=str,
|
||||
@ -94,9 +104,30 @@ def build_parser():
|
||||
return parser
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
assert Path(model_path).exists(), model_path
|
||||
self._model = SentencePieceProcessor(model_file=model_path)
|
||||
self._sep = "▁"
|
||||
assert self._model.vocab_size() == self._model.get_piece_size()
|
||||
|
||||
def encode(self, s: str) -> List[int]:
|
||||
return [self._model.bos_id(), *self._model.encode(s)]
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
out = self._model.decode(t)
|
||||
if t and self._model.id_to_piece(t[0])[0] == self._sep:
|
||||
return " " + out
|
||||
return out
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._model.vocab_size()
|
||||
|
||||
|
||||
def loss(model, inputs, targets, lengths):
|
||||
# Run model on inputs
|
||||
logits = model(inputs)
|
||||
logits, _ = model(inputs)
|
||||
|
||||
# Mask padding tokens
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
@ -117,7 +148,7 @@ def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
|
||||
# Collect batches from dataset
|
||||
for i in range(0, len(indices) - batch_size + 1, batch_size):
|
||||
# Encode batch
|
||||
batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
|
||||
batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
# Pad to the max length
|
||||
@ -195,40 +226,56 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
|
||||
|
||||
|
||||
def generate(model, prompt, tokenizer, args):
|
||||
# Encode prompt
|
||||
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
|
||||
print(args.prompt, end="", flush=True)
|
||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||
|
||||
def generate_step():
|
||||
temp = args.temp
|
||||
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
logits, cache = model(prompt[None])
|
||||
y = sample(logits[:, -1, :])
|
||||
yield y
|
||||
|
||||
while True:
|
||||
logits, cache = model(y[:, None], cache)
|
||||
y = sample(logits.squeeze(1))
|
||||
yield y
|
||||
|
||||
skip = 0
|
||||
prompt_processing = None
|
||||
tokens = []
|
||||
|
||||
# Genertation loop
|
||||
start = time.perf_counter()
|
||||
for token in model.generate(x, args.temp):
|
||||
for token, _ in zip(generate_step(), range(args.num_tokens)):
|
||||
tokens.append(token)
|
||||
|
||||
if len(tokens) == 1:
|
||||
# Actually perform the computation to measure the prompt processing time
|
||||
mx.eval(token)
|
||||
prompt_processing = time.perf_counter() - start
|
||||
|
||||
if len(tokens) >= args.num_tokens:
|
||||
break
|
||||
|
||||
if (len(tokens) % args.write_every) == 0:
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
print(s, end="", flush=True)
|
||||
tokens = []
|
||||
|
||||
mx.eval(tokens)
|
||||
full_gen = time.perf_counter() - start
|
||||
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], end="", flush=True)
|
||||
print()
|
||||
print(f"Prompt processing took: {prompt_processing:.3f} s")
|
||||
print(f"Full generation took: {full_gen:.3f} s")
|
||||
print(s, flush=True)
|
||||
|
||||
|
||||
def load_model(folder: str, dtype=mx.float32):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
model_args = ModelArgs(**config)
|
||||
if config.get("vocab_size", -1) < 0:
|
||||
config["vocab_size"] = tokenizer.vocab_size
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Model(model_args)
|
||||
model.update(weights)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -237,17 +284,14 @@ if __name__ == "__main__":
|
||||
|
||||
np.random.seed(args.seed)
|
||||
|
||||
print("Loading tokenizer")
|
||||
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
|
||||
|
||||
print("Loading pretrained model")
|
||||
model = load_model(args.model)
|
||||
model, tokenizer = load_model(args.model)
|
||||
|
||||
# Freeze all layers other than LORA linears
|
||||
model.freeze()
|
||||
for l in model.layers[16:32]:
|
||||
l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
|
||||
l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
|
||||
l.attention.wq = LoRALinear.from_linear(l.attention.wq)
|
||||
l.attention.wv = LoRALinear.from_linear(l.attention.wv)
|
||||
|
||||
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
|
||||
print(f"Total parameters {p:.3f}M")
|
||||
@ -257,6 +301,11 @@ if __name__ == "__main__":
|
||||
print("Loading datasets")
|
||||
train_set, valid_set, test_set = wikisql.load()
|
||||
|
||||
# Resume training the given adapters.
|
||||
if args.resume_adapter_file is not None:
|
||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||
model.load_weights(args.resume_adapter_file)
|
||||
|
||||
if args.train:
|
||||
print("Training")
|
||||
opt = optim.Adam(learning_rate=args.learning_rate)
|
||||
@ -287,5 +336,4 @@ if __name__ == "__main__":
|
||||
|
||||
if args.prompt is not None:
|
||||
print("Generating")
|
||||
|
||||
generate(model, args.prompt, tokenizer, args)
|
||||
|
193
lora/models.py
Normal file
193
lora/models.py
Normal file
@ -0,0 +1,193 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
n_layers: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_linear(linear: nn.Linear, rank: int = 8):
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
lora_lin = LoRALinear(input_dims, output_dims, rank)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def __init__(
|
||||
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, lora_rank),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
|
||||
|
||||
def __call__(self, x):
|
||||
y = self.linear(x)
|
||||
z = (x @ self.lora_a) @ self.lora_b
|
||||
return y + 2.0 * z
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.n_heads: int = args.n_heads
|
||||
self.n_kv_heads: int = args.n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.args.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = nn.RoPE(args.head_dim, traditional=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
|
||||
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
assert self.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.output(self.norm(h)), cache
|
@ -1,2 +1,3 @@
|
||||
mlx
|
||||
sentencepiece
|
||||
torch
|
||||
|
@ -42,7 +42,7 @@ def wikitext(dataset="2", save_dir="/tmp"):
|
||||
Load the WikiText-* language modeling dataset:
|
||||
https://paperswithcode.com/dataset/wikitext-2
|
||||
https://paperswithcode.com/dataset/wikitext-103
|
||||
|
||||
|
||||
"""
|
||||
if dataset not in ("2", "103"):
|
||||
raise ValueError(f'Dataset must be either "2" or "103", got {dataset}')
|
||||
|
Loading…
Reference in New Issue
Block a user