Merge branch 'ml-explore:main' into main

This commit is contained in:
Pawel Kowalski 2023-12-15 18:32:31 +01:00 committed by GitHub
commit 4c4317feda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 798 additions and 121 deletions

View File

@ -8,7 +8,7 @@ The [MNIST](mnist) example is a good starting point to learn how to use MLX.
Some more useful examples include:
- [Transformer language model](transformer_lm) training.
- Large scale text generation with [LLaMA](llama) or [Mistral](mistral).
- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
- Parameter efficient fine-tuning with [LoRA](lora).
- Generating images with [Stable Diffusion](stable_diffusion).

View File

@ -8,7 +8,7 @@ The `convert.py` script relies on `transformers` to download the weights, and ex
```
python convert.py \
--bert-model bert-base-uncased
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
```

View File

@ -8,7 +8,6 @@ import mlx.core as mx
import mlx.nn as nn
import argparse
import numpy
import math
@dataclass
@ -35,74 +34,6 @@ model_configs = {
}
class MultiHeadAttention(nn.Module):
"""
Minor update to the MultiHeadAttention module to ensure that the
projections use bias.
"""
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = nn.Linear(query_input_dims, dims, True)
self.key_proj = nn.Linear(key_input_dims, dims, True)
self.value_proj = nn.Linear(value_input_dims, value_dims, True)
self.out_proj = nn.Linear(value_dims, value_output_dims, True)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
mask = self.convert_mask_to_additive_causal_mask(mask)
mask = mx.expand_dims(mask, (1, 2))
mask = mx.broadcast_to(mask, scores.shape)
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
def convert_mask_to_additive_causal_mask(
self, mask: mx.array, dtype: mx.Dtype = mx.float32
) -> mx.array:
mask = mask == 0
mask = mask.astype(dtype) * -1e9
return mask
class TransformerEncoderLayer(nn.Module):
"""
A transformer encoder layer with (the original BERT) post-normalization.
@ -117,7 +48,7 @@ class TransformerEncoderLayer(nn.Module):
):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads)
self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = nn.Linear(dims, mlp_dims)
@ -187,9 +118,15 @@ class Bert(nn.Module):
self,
input_ids: mx.array,
token_type_ids: mx.array,
attention_mask: Optional[mx.array] = None,
attention_mask: mx.array = None,
) -> tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
if attention_mask is not None:
# convert 0's to -infs, 1's to 0's, and make it broadcastable
attention_mask = mx.log(attention_mask)
attention_mask = mx.expand_dims(attention_mask, (1, 2))
y = self.encoder(x, attention_mask)
return y, mx.tanh(self.pooler(y[:, 0]))

View File

@ -1,3 +1,3 @@
mlx
mlx>=0.0.5
transformers
numpy

51
cifar/README.md Normal file
View File

@ -0,0 +1,51 @@
# CIFAR and ResNets
An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
configurations in accordance with the original
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
load the dataset.
## Pre-requisites
Install the dependencies:
```
pip install -r requirements.txt
```
## Running the example
Run the example with:
```
python main.py
```
By default the example runs on the GPU. To run on the CPU, use:
```
python main.py --cpu
```
For all available options, run:
```
python main.py --help
```
## Results
After training with the default `resnet20` architecture for 100 epochs, you
should see the following results:
```
Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
Epoch: 99 | Test acc 0.807
```
Note this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in learning rate schedules,
or a `BatchNorm` layer. We intend to update this example once these features
are added.

30
cifar/dataset.py Normal file
View File

@ -0,0 +1,30 @@
import mlx.core as mx
from mlx.data.datasets import load_cifar10
import math
def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root)
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
def normalize(x):
x = x.astype("float32") / 255.0
return (x - mean) / std
tr_iter = (
tr.shuffle()
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
.pad("image", 1, 4, 4, 0.0)
.image_random_crop("image", 32, 32)
.key_transform("image", normalize)
.batch(batch_size)
)
test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
return tr_iter, test_iter

120
cifar/main.py Normal file
View File

@ -0,0 +1,120 @@
import argparse
import time
import resnet
import mlx.nn as nn
import mlx.core as mx
import mlx.optimizers as optim
from dataset import get_cifar10
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
"--arch",
type=str,
default="resnet20",
choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]],
help="model architecture",
)
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
def train_epoch(model, train_iter, optimizer, epoch):
def train_step(model, inp, tgt):
output = model(inp)
loss = mx.mean(nn.losses.cross_entropy(output, tgt))
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
train_step_fn = nn.value_and_grad(model, train_step)
losses = []
accs = []
samples_per_sec = []
for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
tic = time.perf_counter()
(loss, acc), grads = train_step_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()
losses.append(loss)
accs.append(acc)
throughput = x.shape[0] / (toc - tic)
samples_per_sec.append(throughput)
if batch_counter % 10 == 0:
print(
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
f"Train loss {loss:.3f}",
f"Train acc {acc:.3f}",
f"Throughput: {throughput:.2f} images/second",
)
)
)
mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec
def test_epoch(model, test_iter, epoch):
accs = []
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
acc_value = acc.item()
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs))
return mean_acc
def main(args):
mx.random.seed(args.seed)
model = getattr(resnet, args.arch)()
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
optimizer = optim.Adam(learning_rate=args.lr)
train_data, test_data = get_cifar10(args.batch_size)
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
print(
" | ".join(
(
f"Epoch: {epoch}",
f"avg. Train loss {tr_loss.item():.3f}",
f"avg. Train acc {tr_acc.item():.3f}",
f"Throughput: {throughput.item():.2f} images/sec",
)
)
)
test_acc = test_epoch(model, test_data, epoch)
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
train_data.reset()
test_data.reset()
if __name__ == "__main__":
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
main(args)

2
cifar/requirements.txt Normal file
View File

@ -0,0 +1,2 @@
mlx
mlx-data

131
cifar/resnet.py Normal file
View File

@ -0,0 +1,131 @@
"""
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
"""
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = [
"ResNet",
"resnet20",
"resnet32",
"resnet44",
"resnet56",
"resnet110",
"resnet1202",
]
class ShortcutA(nn.Module):
def __init__(self, dims):
super().__init__()
self.dims = dims
def __call__(self, x):
return mx.pad(
x[:, ::2, ::2, :],
pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)],
)
class Block(nn.Module):
"""
Implements a ResNet block with two convolutional layers and a skip connection.
As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
"""
def __init__(self, in_dims, dims, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn1 = nn.LayerNorm(dims)
self.conv2 = nn.Conv2d(
dims, dims, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.LayerNorm(dims)
if stride != 1:
self.shortcut = ShortcutA(dims)
else:
self.shortcut = None
def __call__(self, x):
out = nn.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if self.shortcut is None:
out += x
else:
out += self.shortcut(x)
out = nn.relu(out)
return out
class ResNet(nn.Module):
"""
Creates a ResNet model for CIFAR-10, as specified in the original paper.
"""
def __init__(self, block, num_blocks, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.LayerNorm(16)
self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
self.linear = nn.Linear(64, num_classes)
def _make_layer(self, block, in_dims, dims, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(in_dims, dims, stride))
in_dims = dims
return nn.Sequential(*layers)
def num_params(self):
nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
return nparams
def __call__(self, x):
x = nn.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
x = self.linear(x)
return x
def resnet20(**kwargs):
return ResNet(Block, [3, 3, 3], **kwargs)
def resnet32(**kwargs):
return ResNet(Block, [5, 5, 5], **kwargs)
def resnet44(**kwargs):
return ResNet(Block, [7, 7, 7], **kwargs)
def resnet56(**kwargs):
return ResNet(Block, [9, 9, 9], **kwargs)
def resnet110(**kwargs):
return ResNet(Block, [18, 18, 18], **kwargs)
def resnet1202(**kwargs):
return ResNet(Block, [200, 200, 200], **kwargs)

View File

@ -315,7 +315,7 @@ def load_model(model_path):
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
unused = ["multiple_of", "ffn_dim_multiplie"]
unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
for k in unused:
if k in config:
config.pop(k)

View File

@ -24,7 +24,7 @@ tar -xf mistral-7B-v0.1.tar
```
If you do not have access to the Llama weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
from Meta.
Convert the model with:

View File

@ -2,6 +2,8 @@
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
This example also supports the instruction fine-tuned Mixtral model.[^instruct]
Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
### Setup
@ -16,37 +18,56 @@ brew install git-lfs
Download the models from Hugging Face:
For the base model use:
```
git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
```
After that's done, combine the files:
For the instruction fine-tuned model use:
```
cd mixtral-8x7b-32kseqlen/
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
```
Then run:
```
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
cd $MIXTRAL_MODEL/ && \
git lfs pull --include "consolidated.*.pt" && \
git lfs pull --include "tokenizer.model"
```
Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
MLX can read them:
```
python convert.py --model_path mixtral-8x7b-32kseqlen/
python convert.py --model_path $MIXTRAL_MODEL/
```
The conversion script will save the converted weights in the same location.
After that's done, if you want to clean some stuff up:
```
rm mixtral-8x7b-32kseqlen/*.pth*
```
### Generate
As easy as:
```
python mixtral.py --model_path mixtral-8x7b-32kseqlen/
python mixtral.py --model_path $MIXTRAL_MODEL/
```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
For more options including how to prompt the model, run:
```
python mixtral.py --help
```
For the Instruction model, make sure to follow the prompt format:
```
[INST] Instruction prompt [/INST]
```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
[^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
details

View File

@ -1,23 +1,55 @@
# Copyright © 2023 Apple Inc.
import argparse
import glob
import json
import numpy as np
from pathlib import Path
import torch
def convert(k, v, config):
v = v.to(torch.float16).numpy()
if "block_sparse_moe" not in k:
return [(k, v)]
if "gate" in k:
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
# From: layers.N.block_sparse_moe.w
# To: layers.N.experts.M.w
num_experts = args["moe"]["num_experts"]
key_path = k.split(".")
v = np.split(v, num_experts, axis=0)
if key_path[-1] == "w2":
v = [u.T for u in v]
w_name = key_path.pop()
key_path[-1] = "feed_forward.experts"
return [
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
]
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen/",
default="Mixtral-8x7B-v0.1/",
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
)
args = parser.parse_args()
model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pth"))
np.savez(
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
)
with open("params.json") as fid:
args = json.load(fid)
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
for e, tf in enumerate(torch_files):
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
state = torch.load(tf)
new_state = {}
for k, v in state.items():
new_state.update(convert(k, v, args))
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)

View File

@ -2,6 +2,7 @@
import argparse
from dataclasses import dataclass
import glob
import json
import numpy as np
from pathlib import Path
@ -40,6 +41,26 @@ class RMSNorm(nn.Module):
return self.weight * output
class RoPE(nn.RoPE):
def __init__(self, dims: int, traditional: bool = False):
super().__init__(dims, traditional)
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, base=1000000, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return mx.reshape(rx, shape)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@ -56,7 +77,7 @@ class Attention(nn.Module):
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
self.rope = RoPE(args.head_dim, traditional=True)
def __call__(
self,
@ -125,7 +146,10 @@ class MOEFeedForward(nn.Module):
gates = self.gate(x)
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
@ -181,8 +205,9 @@ class Mixtral(nn.Module):
h = self.tok_embeddings(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
if cache is None:
@ -191,7 +216,7 @@ class Mixtral(nn.Module):
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache
return self.output(self.norm(h[:, T - 1 : T, :])), cache
class Tokenizer:
@ -222,10 +247,13 @@ class Tokenizer:
def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
with open("params.json", "r") as f:
config = json.loads(f.read())
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weight_files = glob.glob(str(model_path / "weights.*.npz"))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mixtral(model_args)
@ -255,7 +283,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model_path",
type=str,
default="mixtral-8x7b-32kseqlen",
default="Mixtral-8x7B-v0.1",
help="The path to the model weights, tokenizer, and config",
)
parser.add_argument(
@ -274,7 +302,7 @@ if __name__ == "__main__":
"--temp",
help="The sampling temperature.",
type=float,
default=1.0,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

1
mixtral/params.json Normal file
View File

@ -0,0 +1 @@
{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}

1
phi2/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
weights.npz

57
phi2/README.md Normal file
View File

@ -0,0 +1,57 @@
# Phi-2
Phi-2 is a 2.7B parameter language model released by Microsoft with
performance that rivals much larger models.[^1] It was trained on a mixture of
GPT-4 outputs and clean web text.
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
precision.
## Setup
Download and convert the model:
```sh
python convert.py
```
This will make the `weights.npz` file which MLX can read.
## Generate
To generate text with the default prompt:
```sh
python phi2.py
```
Should give the output:
```
Answer: Mathematics is like a lighthouse that guides us through the darkness of
uncertainty. Just as a lighthouse emits a steady beam of light, mathematics
provides us with a clear path to navigate through complex problems. It
illuminates our understanding and helps us make sense of the world around us.
Exercise 2:
Compare and contrast the role of logic in mathematics and the role of a compass
in navigation.
Answer: Logic in mathematics is like a compass in navigation. It helps
```
To use your own prompt:
```sh
python phi2.py --prompt <your prompt here> --max_tokens <max_tokens_to_generate>
```
To see a list of options run:
```sh
python phi2.py --help
```
[^1]: For more details on the model see the [blog post](
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)

24
phi2/convert.py Normal file
View File

@ -0,0 +1,24 @@
from transformers import AutoModelForCausalLM
import numpy as np
def replace_key(key: str) -> str:
if "wte.weight" in key:
key = "wte.weight"
if ".mlp" in key:
key = key.replace(".mlp", "")
return key
def convert():
model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights)
if __name__ == "__main__":
convert()

222
phi2/phi2.py Normal file
View File

@ -0,0 +1,222 @@
import argparse
from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
import mlx.core as mx
import mlx.nn as nn
import math
@dataclass
class ModelArgs:
max_sequence_length: int = 2048
num_vocab: int = 51200
model_dim: int = 2560
num_heads: int = 32
num_layers: int = 32
rotary_dim: int = 32
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(rotary_dim, traditional=False)
self.Wqkv = nn.Linear(dims, 3 * dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x, mask=None, cache=None):
qkv = self.Wqkv(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.model_dim
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = LayerNorm(dims)
self.fc1 = nn.Linear(dims, mlp_dims)
self.fc2 = nn.Linear(mlp_dims, dims)
self.act = nn.GELU(approx="precise")
def __call__(self, x, mask, cache):
h = self.ln(x)
attn_h, cache = self.mixer(h, mask, cache)
ff_h = self.fc2(self.act(self.fc1(h)))
return attn_h + ff_h + x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
def __call__(self, x, mask, cache):
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
return x, cache
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
self.ln = LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
def __call__(self, inputs):
return self.linear(self.ln(inputs))
class Phi2(nn.Module):
def __init__(self, config: ModelArgs):
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
x = self.wte(inputs)
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt)
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache=cache)
y = sample(logits.squeeze(1))
yield y
def load_model():
model = Phi2(ModelArgs())
weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script")
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max_tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load_model()
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
prompt = mx.array(prompt)
print("[INFO] Generating with Phi-2...", flush=True)
print(args.prompt, end="", flush=True)
tokens = []
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
if eos_index is not None:
break
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)

5
phi2/requirements.txt Normal file
View File

@ -0,0 +1,5 @@
einops
mlx
numpy
transformers
torch

View File

@ -27,7 +27,7 @@ Usage
------
Although each component in this repository can be used by itself, the fastest
way to get started is by using the `StableDiffusion` class from the `diffusion`
way to get started is by using the `StableDiffusion` class from the `stable_diffusion`
module.
```python

View File

@ -1,7 +1,7 @@
# whisper
# Whisper
Speech recognition with Whisper in MLX. Whisper is a set of open source speech
recognition models from Open AI, ranging from 39 million to 1.5 billion
recognition models from OpenAI, ranging from 39 million to 1.5 billion
parameters[^1].
### Setup
@ -15,7 +15,7 @@ pip install -r requirements.txt
Install [`ffmpeg`](https://ffmpeg.org/):
```
# on MacOS using Homebrew (https://brew.sh/)
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```

View File

@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase):
logits = mlx_model(mels, tokens)
self.assertEqual(logits.dtype, mx.float16)
def test_decode_lang(self):
options = decoding.DecodingOptions(task="lang_id", fp16=False)
result = decoding.decode(self.model, self.mels, options)

View File

@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = {
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
@ -166,7 +166,8 @@ def convert(model, rules=None):
def torch_to_mlx(
torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
torch_model: torch_whisper.Whisper,
dtype: mx.Dtype = mx.float16,
) -> whisper.Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
@ -194,6 +195,6 @@ def torch_to_mlx(
def load_model(
name: str,
download_root: str = None,
dtype : mx.Dtype = mx.float32,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
return torch_to_mlx(load_torch_model(name, download_root), dtype)

View File

@ -43,7 +43,7 @@ class ModelHolder:
model_name = None
@classmethod
def get_model(cls, model: str, dtype : mx.Dtype):
def get_model(cls, model: str, dtype: mx.Dtype):
if cls.model is None or model != cls.model_name:
cls.model = load_model(model, dtype=dtype)
cls.model_name = model

View File

@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
@ -117,13 +118,19 @@ class ResidualAttentionBlock(nn.Module):
if self.cross_attn:
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
return x, (kv, cross_kv)
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
self,
n_mels: int,
n_ctx: int,
n_state: int,
n_head: int,
n_layer: int,
dtype: mx.Dtype = mx.float16,
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
@ -134,8 +141,8 @@ class AudioEncoder(nn.Module):
self.ln_post = LayerNorm(n_state)
def __call__(self, x):
x = nn.gelu(self.conv1(x))
x = nn.gelu(self.conv2(x))
x = nn.gelu(self.conv1(x)).astype(x.dtype)
x = nn.gelu(self.conv2(x)).astype(x.dtype)
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
x = x + self._positional_embedding
@ -148,7 +155,13 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
self,
n_vocab: int,
n_ctx: int,
n_state: int,
n_head: int,
n_layer: int,
dtype: mx.Dtype = mx.float16,
):
super().__init__()
@ -160,7 +173,9 @@ class TextDecoder(nn.Module):
for _ in range(n_layer)
]
self.ln = LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
dtype
)
def __call__(self, x, xa, kv_cache=None):
"""