Merge pull request #53 from ml-explore/mistral_lora

Generalize lora finetuning to Mistral
Awni Hannun 2023-12-09 15:05:29 -08:00 committed by GitHub
commit 3a3ea3cfb0
10 changed files with 377 additions and 301 deletions


@@ -32,7 +32,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
     parser.add_argument(
         "--bert-model",
-        choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
+        choices=[
+            "bert-base-uncased",
+            "bert-base-cased",
+            "bert-large-uncased",
+            "bert-large-cased",
+        ],
         default="bert-base-uncased",
         help="The huggingface name of the BERT model to save.",
     )
@@ -44,4 +49,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
     convert(args.bert_model, args.mlx_model)
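For reference, the converter can also be driven from Python. This is a hypothetical sketch: only the `convert(args.bert_model, args.mlx_model)` call is visible in this diff, so the `convert.py` module name and the output filename are assumptions.

```python
# Hypothetical usage sketch: the module name and output path are assumptions;
# only the convert(bert_model, mlx_model) call appears in the diff above.
from convert import convert

convert("bert-base-uncased", "bert-base-uncased.npz")
```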


@@ -24,10 +24,17 @@ def run(bert_model: str):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.")
+    parser = argparse.ArgumentParser(
+        description="Run the BERT model using HuggingFace Transformers."
+    )
     parser.add_argument(
         "--bert-model",
-        choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
+        choices=[
+            "bert-base-uncased",
+            "bert-base-cased",
+            "bert-large-uncased",
+            "bert-large-cased",
+        ],
         default="bert-base-uncased",
         help="The huggingface name of the BERT model to save.",
     )


@@ -1,3 +1,4 @@
+import numpy as np
 from typing import Optional
 from dataclasses import dataclass
 from transformers import BertTokenizer
@@ -214,7 +215,7 @@ def run(bert_model: str, mlx_model: str):
         "A second string",
         "This is another string.",
     ]
     tokens = tokenizer(batch, return_tensors="np", padding=True)
     tokens = {key: mx.array(v) for key, v in tokens.items()}
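The hunk above shows the pattern this example uses to bridge Hugging Face tokenizers and MLX: request NumPy tensors and wrap each array in `mx.array`. A minimal standalone sketch of that pattern (the input strings are illustrative):

```python
import mlx.core as mx
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = ["A second string", "This is another string."]

# Tokenize to NumPy arrays, then convert each array to an MLX array.
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
print({key: v.shape for key, v in tokens.items()})
```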


@@ -1,7 +1,8 @@
 # LoRA
 
-This is an example of using MLX to fine-tune a Llama 7B[^llama] model with low
-rank adaptation (LoRA)[^lora] for a target task.
+This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
+Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
+task.
 
 In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
 generate SQL queries from natural language. However, the example is intended to
@@ -15,19 +16,27 @@ Install the dependencies:
 pip install -r requirements.txt
 ```
 
-Next, download and convert the model. If you do not have access to the model
-weights you will need to [request
+Next, download and convert the model. The Mistral weights can be downloaded with:
+
+```
+curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
+tar -xf mistral-7B-v0.1.tar
+```
+
+If you do not have access to the Llama weights you will need to [request
 access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
 from Meta.
 
-Convert the weights with:
+Convert the model with:
 
 ```
-python convert.py <path_to_torch_weights> mlx_llama_7B.npz
+python convert.py <path_to_torch_model> <path_to_mlx_model>
 ```
 
 ## Run
 
+#### Fine-tune
+
 The main script is `lora.py`. To see a full list of options run
 
 ```
@@ -37,28 +46,34 @@ python lora.py --help
 To fine-tune a model use:
 
 ```
-python lora.py --model mlx_llama_7B.npz \
-               --tokenizer tokenizer.model \
+python lora.py --model <path_to_model> \
                --train \
-               --iters 600 \
+               --iters 600
 ```
 
+Note, the model path should have the MLX weights, the tokenizer, and the
+`params.json` configuration which will all be output by the `convert.py` script.
+
 By default, the adapter weights are saved in `adapters.npz`. You can specify
 the output location with `--adapter_file`.
 
+#### Evaluate
+
 To compute test set perplexity use
 
 ```
-python lora.py --model mlx_llama_7B.npz \
-               --tokenizer tokenizer.model \
+python lora.py --model <path_to_model> \
+               --adapter_file <path_to_adapters.npz> \
                --test
 ```
 
+#### Generate
+
 For generation use
 
 ```
-python lora.py --model mlx_llama_7B.npz \
-               --tokenizer tokenizer.model \
+python lora.py --model <path_to_model> \
+               --adapter_file <path_to_adapters.npz> \
                --num-tokens 50 \
                --prompt "table: 1-10015132-16
 columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
@@ -81,10 +96,9 @@ training and validation loss at a few points over the course of training.
 | 800  | 1.017 | 1.255 |
 | 1000 | 1.070 | 1.230 |
 
-After training for 1000 iterations, the validation perplexity reduces to XX.
-
 The model trains at around 475 tokens per second on an M2 Ultra.
 
 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
 [^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
 [^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
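The removed line above left the validation perplexity as a placeholder. Assuming the table reports cross-entropy loss in nats (an assumption about the README's units), perplexity follows directly as exp(loss); a minimal sketch:

```python
import math

# Assumption: the table's validation column is cross-entropy loss in nats.
val_loss_at_1000 = 1.230  # from the table above
print(f"validation perplexity ≈ {math.exp(val_loss_at_1000):.2f}")  # ≈ 3.42
```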


@@ -1,53 +1,59 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
-from itertools import starmap
+import json
 import numpy as np
+from pathlib import Path
+import shutil
+import os
 import torch
 
 
-def map_torch_to_mlx(key, value):
-    if "tok_embedding" in key:
-        key = "embedding.weight"
-
-    elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
-
-    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
-        key = key.replace("wq", "query_proj")
-        key = key.replace("wk", "key_proj")
-        key = key.replace("wv", "value_proj")
-        key = key.replace("wo", "out_proj")
-
-    elif "w1" in key or "w2" in key or "w3" in key:
-        # The FFN is a separate submodule in PyTorch
-        key = key.replace("feed_forward.w1", "linear1")
-        key = key.replace("feed_forward.w3", "linear2")
-        key = key.replace("feed_forward.w2", "linear3")
-
-    elif "output" in key:
-        key = key.replace("output", "out_proj")
-
-    elif "rope" in key:
-        return None, None
-
-    return (
-        key,
-        value.numpy()
-        if value.dtype != torch.bfloat16
-        else value.to(torch.float32).numpy(),
-    )
-
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument("torch_weights")
-    parser.add_argument("output_file")
+    parser = argparse.ArgumentParser(
+        description="Convert Mistral or Llama models to MLX.",
+    )
+    parser.add_argument(
+        "--torch_model",
+        type=str,
+        default="mistral-7B-v0.1/",
+        help="The torch model directory",
+    )
+    parser.add_argument(
+        "--mlx_model",
+        type=str,
+        default="mlx-mistral-7B-v0.1/",
+        help="The directory to store the mlx model",
+    )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_weights)
+    torch_path = Path(args.torch_model)
+    if not os.path.exists(args.mlx_model):
+        os.makedirs(args.mlx_model)
+    mlx_path = Path(args.mlx_model)
+
+    state = torch.load(str(torch_path / "consolidated.00.pth"))
     np.savez(
-        args.output_file,
-        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
+        str(mlx_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
+
+    # Copy the tokenizer
+    shutil.copyfile(
+        str(torch_path / "tokenizer.model"),
+        str(mlx_path / "tokenizer.model"),
+    )
+
+    # Copy the params
+    with open(torch_path / "params.json", "r") as f:
+        config = json.loads(f.read())
+    n_heads = config["n_heads"]
+    if "sliding_window" in config:
+        config.pop("sliding_window")
+    if "n_kv_heads" not in config:
+        config["n_kv_heads"] = n_heads
+    if "head_dim" not in config:
+        config["head_dim"] = config["dim"] // n_heads
+    if "hidden_dim" not in config:
+        config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
+    with open(mlx_path / "params.json", "w") as outfile:
+        json.dump(config, outfile)
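After running the new `convert.py`, the output directory should hold `weights.npz`, `tokenizer.model`, and a patched `params.json`. A quick sanity-check sketch (the directory name below is just the script's default `--mlx_model` value and may differ on your machine):

```python
import json
from pathlib import Path

import mlx.core as mx

mlx_path = Path("mlx-mistral-7B-v0.1")  # convert.py's default output directory

# params.json should now include n_kv_heads, head_dim, and hidden_dim.
config = json.loads((mlx_path / "params.json").read_text())
print(config)

# weights.npz keeps the original torch parameter names, stored as float16.
weights = mx.load(str(mlx_path / "weights.npz"))
print(f"{len(weights)} arrays, e.g. {next(iter(weights))}")
```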


@@ -1,199 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
input_dims, output_dims = linear.weight.shape
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
def __init__(
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, lora_rank),
)
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x = self.embedding(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
def generate(self, x, temp=1.0):
cache = []
try:
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
# First we process the prompt x the same was as in __call__ but
# save the caches in cache
x = self.embedding(x)
for l in self.layers:
x, c = l(x, mask=mask)
# We store the per layer cache in a simple python list
cache.append(c)
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
# y now has size [1]
yield y
# Now we parsed the prompt and generated the first token we
# need to feed it back into the model and loop to generate the
# rest.
while True:
# Unsqueezing the last dimension to add a sequence length
# dimension of 1
x = y[:, None]
x = self.embedding(x)
for i in range(len(cache)):
# We are overwriting the arrays in the cache list. When
# the computation will happen, MLX will be discarding the
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
yield y
finally:
del cache
def load_model(model_path):
weights = mx.load(model_path)
mlp_dims, dims = weights["layers.0.linear1.weight"].shape
num_heads = dims // 128
num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
vocab_size = weights["out_proj.weight"].shape[-1]
model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
model.update(tree_unflatten(list(weights.items())))
mx.eval(model.parameters())
return model


@@ -1,28 +1,32 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import json
 import math
 import numpy as np
+from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 import time
+from typing import Optional, Tuple, List
 
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-from mlx.utils import tree_flatten
+from mlx.utils import tree_map, tree_flatten, tree_unflatten
 
-from llama import LoRALinear, load_model
+from models import ModelArgs, Model, LoRALinear
 import wikisql
 
 
 def build_parser():
-    parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
+    parser = argparse.ArgumentParser(
+        description="LoRA finetuning with Llama or Mistral"
+    )
     parser.add_argument(
-        "--model", required=True, help="The model file containing MLX weights"
-    )
-    parser.add_argument(
-        "--tokenizer", required=True, help="The sentencepiece tokenizer"
+        "--model",
+        required=True,
+        help="A path to the model files containing the tokenizer, weights, config.",
     )
     # Generation args
     parser.add_argument(
@@ -73,6 +77,12 @@ def build_parser():
         default=200,
         help="Number of training steps between validations.",
     )
+    parser.add_argument(
+        "--resume_adapter_file",
+        type=str,
+        default=None,
+        help="Load path to resume training with the given adapter weights.",
+    )
     parser.add_argument(
         "--adapter_file",
         type=str,
@@ -94,9 +104,30 @@ def build_parser():
     return parser
 
 
+class Tokenizer:
+    def __init__(self, model_path: str):
+        assert Path(model_path).exists(), model_path
+        self._model = SentencePieceProcessor(model_file=model_path)
+        self._sep = "▁"
+        assert self._model.vocab_size() == self._model.get_piece_size()
+
+    def encode(self, s: str) -> List[int]:
+        return [self._model.bos_id(), *self._model.encode(s)]
+
+    def decode(self, t: List[int]) -> str:
+        out = self._model.decode(t)
+        if t and self._model.id_to_piece(t[0])[0] == self._sep:
+            return " " + out
+        return out
+
+    @property
+    def vocab_size(self) -> int:
+        return self._model.vocab_size()
+
+
 def loss(model, inputs, targets, lengths):
     # Run model on inputs
-    logits = model(inputs)
+    logits, _ = model(inputs)
 
     # Mask padding tokens
     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
@@ -117,7 +148,7 @@ def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
         # Collect batches from dataset
         for i in range(0, len(indices) - batch_size + 1, batch_size):
             # Encode batch
-            batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
+            batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
             lengths = [len(x) for x in batch]
 
             # Pad to the max length
@@ -195,40 +226,56 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 
 def generate(model, prompt, tokenizer, args):
-    # Encode prompt
-    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
+    print(args.prompt, end="", flush=True)
+    prompt = mx.array(tokenizer.encode(args.prompt))
 
-    skip = 0
-    prompt_processing = None
-    tokens = []
+    def generate_step():
+        temp = args.temp
 
-    # Genertation loop
-    start = time.perf_counter()
-    for token in model.generate(x, args.temp):
-        tokens.append(token)
+        def sample(logits):
+            if temp == 0:
+                return mx.argmax(logits, axis=-1)
+            else:
+                return mx.random.categorical(logits * (1 / temp))
 
-        if len(tokens) == 1:
-            # Actually perform the computation to measure the prompt processing time
-            mx.eval(token)
-            prompt_processing = time.perf_counter() - start
+        logits, cache = model(prompt[None])
+        y = sample(logits[:, -1, :])
+        yield y
 
-        if len(tokens) >= args.num_tokens:
-            break
+        while True:
+            logits, cache = model(y[:, None], cache)
+            y = sample(logits.squeeze(1))
+            yield y
 
-        if (len(tokens) % args.write_every) == 0:
+    tokens = []
+    for token, _ in zip(generate_step(), range(args.num_tokens)):
+        tokens.append(token)
+
+        if (len(tokens) % 10) == 0:
             mx.eval(tokens)
             s = tokenizer.decode([t.item() for t in tokens])
-            print(s[skip:], end="", flush=True)
-            skip = len(s)
+            print(s, end="", flush=True)
+            tokens = []
 
     mx.eval(tokens)
-    full_gen = time.perf_counter() - start
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
-    print(f"Prompt processing took: {prompt_processing:.3f} s")
-    print(f"Full generation took: {full_gen:.3f} s")
+    print(s, flush=True)
+
+
+def load_model(folder: str, dtype=mx.float32):
+    model_path = Path(folder)
+    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
+    with open(model_path / "params.json", "r") as f:
+        config = json.loads(f.read())
+        if config.get("vocab_size", -1) < 0:
+            config["vocab_size"] = tokenizer.vocab_size
+        model_args = ModelArgs(**config)
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: p.astype(dtype), weights)
+    model = Model(model_args)
+    model.update(weights)
+    return model, tokenizer
 
 
 if __name__ == "__main__":
@@ -237,17 +284,14 @@ if __name__ == "__main__":
     np.random.seed(args.seed)
 
-    print("Loading tokenizer")
-    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
-
     print("Loading pretrained model")
-    model = load_model(args.model)
+    model, tokenizer = load_model(args.model)
 
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.layers[16:32]:
-        l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
-        l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
+        l.attention.wq = LoRALinear.from_linear(l.attention.wq)
+        l.attention.wv = LoRALinear.from_linear(l.attention.wv)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
@@ -257,6 +301,11 @@ if __name__ == "__main__":
     print("Loading datasets")
     train_set, valid_set, test_set = wikisql.load()
 
+    # Resume training the given adapters.
+    if args.resume_adapter_file is not None:
+        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        model.load_weights(args.resume_adapter_file)
+
     if args.train:
         print("Training")
         opt = optim.Adam(learning_rate=args.learning_rate)
@@ -287,5 +336,4 @@ if __name__ == "__main__":
     if args.prompt is not None:
         print("Generating")
         generate(model, args.prompt, tokenizer, args)
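As a summary of how the new pieces of `lora.py` fit together, the sketch below loads a converted model directory, re-applies the same LoRA wrapping as the script, loads trained adapters with `model.load_weights` (the same mechanism as `--resume_adapter_file`), and takes a single greedy step, which is what `generate` does with `--temp 0`. The model directory and `adapters.npz` path are assumptions, and it presumes the files in this commit are saved as `lora.py` and `models.py`.

```python
import mlx.core as mx

from lora import load_model      # assumes the script above is saved as lora.py
from models import LoRALinear

# Load the converted model directory (weights.npz, tokenizer.model, params.json).
model, tokenizer = load_model("mlx-mistral-7B-v0.1")

# Re-create the LoRA wrapping used during training so adapter names line up,
# then load the trained adapters.
model.freeze()
for l in model.layers[16:32]:
    l.attention.wq = LoRALinear.from_linear(l.attention.wq)
    l.attention.wv = LoRALinear.from_linear(l.attention.wv)
model.load_weights("adapters.npz")

# Encode a prompt and take one greedy step.
prompt = mx.array(tokenizer.encode("table: 1-10015132-16"))[None]
logits, cache = model(prompt)
next_token = mx.argmax(logits[:, -1, :], axis=-1)
print(tokenizer.decode([next_token.item()]))
```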

lora/models.py (new file, 193 lines)

@@ -0,0 +1,193 @@
# Copyright © 2023 Apple Inc.
from dataclasses import dataclass
import math
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
@dataclass
class ModelArgs:
dim: int
n_layers: int
head_dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
output_dims, input_dims = linear.weight.shape
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
def __init__(
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, lora_rank),
)
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
self.repeats = self.n_heads // self.n_kv_heads
self.scale = self.args.head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.attention = Attention(args)
self.feed_forward = FeedForward(args=args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
assert self.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache
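A small self-contained check of the `LoRALinear` wrapper above: `lora_b` starts at zero, so wrapping a layer leaves its output unchanged at initialization, and the forward pass is `linear(x) + 2 * (x @ lora_a) @ lora_b`. A minimal sketch, assuming the file above is saved as `models.py`:

```python
import mlx.core as mx
import mlx.nn as nn

from models import LoRALinear  # assumes the new file above is models.py

base = nn.Linear(8, 8, bias=False)
lora = LoRALinear.from_linear(base, rank=4)

x = mx.random.normal((2, 8))
# lora_b is zeros at init, so the wrapped layer matches the base layer exactly.
print(mx.abs(lora(x) - base(x)).max())  # 0.0

# Only the low-rank factors are new parameters; lora.py freezes everything else.
print(lora.lora_a.shape, lora.lora_b.shape)  # (8, 4) and (4, 8)
```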


@@ -1,2 +1,3 @@
+mlx
 sentencepiece
 torch


@@ -42,7 +42,7 @@ def wikitext(dataset="2", save_dir="/tmp"):
     Load the WikiText-* language modeling dataset:
     https://paperswithcode.com/dataset/wikitext-2
     https://paperswithcode.com/dataset/wikitext-103
     """
     if dataset not in ("2", "103"):
         raise ValueError(f'Dataset must be either "2" or "103", got {dataset}')
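For completeness, the loader above takes the dataset size as a string; anything other than "2" or "103" raises the `ValueError` shown. A one-line usage sketch (the importing module name is an assumption, since the filename is not shown in this diff):

```python
from datasets import wikitext  # hypothetical module name; not shown here

data = wikitext(dataset="2", save_dir="/tmp")  # loads WikiText-2, caching under /tmp
```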