Merge branch 'ml-explore:main' into main

This commit is contained in:
Pawel Kowalski 2023-12-15 18:32:31 +01:00 committed by GitHub
commit 4c4317feda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 798 additions and 121 deletions

View File

@ -8,7 +8,7 @@ The [MNIST](mnist) example is a good starting point to learn how to use MLX.
Some more useful examples include: Some more useful examples include:
- [Transformer language model](transformer_lm) training. - [Transformer language model](transformer_lm) training.
- Large scale text generation with [LLaMA](llama) or [Mistral](mistral). - Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral) - Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
- Parameter efficient fine-tuning with [LoRA](lora). - Parameter efficient fine-tuning with [LoRA](lora).
- Generating images with [Stable Diffusion](stable_diffusion). - Generating images with [Stable Diffusion](stable_diffusion).

View File

@ -8,7 +8,7 @@ The `convert.py` script relies on `transformers` to download the weights, and ex
``` ```
python convert.py \ python convert.py \
--bert-model bert-base-uncased --bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz --mlx-model weights/bert-base-uncased.npz
``` ```

View File

@ -8,7 +8,6 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import argparse import argparse
import numpy import numpy
import math
@dataclass @dataclass
@ -35,74 +34,6 @@ model_configs = {
} }
class MultiHeadAttention(nn.Module):
"""
Minor update to the MultiHeadAttention module to ensure that the
projections use bias.
"""
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = nn.Linear(query_input_dims, dims, True)
self.key_proj = nn.Linear(key_input_dims, dims, True)
self.value_proj = nn.Linear(value_input_dims, value_dims, True)
self.out_proj = nn.Linear(value_dims, value_output_dims, True)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
mask = self.convert_mask_to_additive_causal_mask(mask)
mask = mx.expand_dims(mask, (1, 2))
mask = mx.broadcast_to(mask, scores.shape)
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
def convert_mask_to_additive_causal_mask(
self, mask: mx.array, dtype: mx.Dtype = mx.float32
) -> mx.array:
mask = mask == 0
mask = mask.astype(dtype) * -1e9
return mask
class TransformerEncoderLayer(nn.Module): class TransformerEncoderLayer(nn.Module):
""" """
A transformer encoder layer with (the original BERT) post-normalization. A transformer encoder layer with (the original BERT) post-normalization.
@ -117,7 +48,7 @@ class TransformerEncoderLayer(nn.Module):
): ):
super().__init__() super().__init__()
mlp_dims = mlp_dims or dims * 4 mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads) self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps) self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps) self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = nn.Linear(dims, mlp_dims) self.linear1 = nn.Linear(dims, mlp_dims)
@ -187,9 +118,15 @@ class Bert(nn.Module):
self, self,
input_ids: mx.array, input_ids: mx.array,
token_type_ids: mx.array, token_type_ids: mx.array,
attention_mask: Optional[mx.array] = None, attention_mask: mx.array = None,
) -> tuple[mx.array, mx.array]: ) -> tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids) x = self.embeddings(input_ids, token_type_ids)
if attention_mask is not None:
# convert 0's to -infs, 1's to 0's, and make it broadcastable
attention_mask = mx.log(attention_mask)
attention_mask = mx.expand_dims(attention_mask, (1, 2))
y = self.encoder(x, attention_mask) y = self.encoder(x, attention_mask)
return y, mx.tanh(self.pooler(y[:, 0])) return y, mx.tanh(self.pooler(y[:, 0]))

View File

@ -1,3 +1,3 @@
mlx mlx>=0.0.5
transformers transformers
numpy numpy

51
cifar/README.md Normal file
View File

@ -0,0 +1,51 @@
# CIFAR and ResNets
An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
configurations in accordance with the original
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
load the dataset.
## Pre-requisites
Install the dependencies:
```
pip install -r requirements.txt
```
## Running the example
Run the example with:
```
python main.py
```
By default the example runs on the GPU. To run on the CPU, use:
```
python main.py --cpu
```
For all available options, run:
```
python main.py --help
```
## Results
After training with the default `resnet20` architecture for 100 epochs, you
should see the following results:
```
Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
Epoch: 99 | Test acc 0.807
```
Note this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in learning rate schedules,
or a `BatchNorm` layer. We intend to update this example once these features
are added.

30
cifar/dataset.py Normal file
View File

@ -0,0 +1,30 @@
import mlx.core as mx
from mlx.data.datasets import load_cifar10
import math
def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root)
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
def normalize(x):
x = x.astype("float32") / 255.0
return (x - mean) / std
tr_iter = (
tr.shuffle()
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
.pad("image", 1, 4, 4, 0.0)
.image_random_crop("image", 32, 32)
.key_transform("image", normalize)
.batch(batch_size)
)
test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
return tr_iter, test_iter

120
cifar/main.py Normal file
View File

@ -0,0 +1,120 @@
import argparse
import time
import resnet
import mlx.nn as nn
import mlx.core as mx
import mlx.optimizers as optim
from dataset import get_cifar10
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
"--arch",
type=str,
default="resnet20",
choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]],
help="model architecture",
)
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
def train_epoch(model, train_iter, optimizer, epoch):
def train_step(model, inp, tgt):
output = model(inp)
loss = mx.mean(nn.losses.cross_entropy(output, tgt))
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
train_step_fn = nn.value_and_grad(model, train_step)
losses = []
accs = []
samples_per_sec = []
for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
tic = time.perf_counter()
(loss, acc), grads = train_step_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()
losses.append(loss)
accs.append(acc)
throughput = x.shape[0] / (toc - tic)
samples_per_sec.append(throughput)
if batch_counter % 10 == 0:
print(
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
f"Train loss {loss:.3f}",
f"Train acc {acc:.3f}",
f"Throughput: {throughput:.2f} images/second",
)
)
)
mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec
def test_epoch(model, test_iter, epoch):
accs = []
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
acc_value = acc.item()
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs))
return mean_acc
def main(args):
mx.random.seed(args.seed)
model = getattr(resnet, args.arch)()
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
optimizer = optim.Adam(learning_rate=args.lr)
train_data, test_data = get_cifar10(args.batch_size)
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
print(
" | ".join(
(
f"Epoch: {epoch}",
f"avg. Train loss {tr_loss.item():.3f}",
f"avg. Train acc {tr_acc.item():.3f}",
f"Throughput: {throughput.item():.2f} images/sec",
)
)
)
test_acc = test_epoch(model, test_data, epoch)
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
train_data.reset()
test_data.reset()
if __name__ == "__main__":
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
main(args)

2
cifar/requirements.txt Normal file
View File

@ -0,0 +1,2 @@
mlx
mlx-data

131
cifar/resnet.py Normal file
View File

@ -0,0 +1,131 @@
"""
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
"""
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = [
"ResNet",
"resnet20",
"resnet32",
"resnet44",
"resnet56",
"resnet110",
"resnet1202",
]
class ShortcutA(nn.Module):
def __init__(self, dims):
super().__init__()
self.dims = dims
def __call__(self, x):
return mx.pad(
x[:, ::2, ::2, :],
pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)],
)
class Block(nn.Module):
"""
Implements a ResNet block with two convolutional layers and a skip connection.
As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
"""
def __init__(self, in_dims, dims, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn1 = nn.LayerNorm(dims)
self.conv2 = nn.Conv2d(
dims, dims, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.LayerNorm(dims)
if stride != 1:
self.shortcut = ShortcutA(dims)
else:
self.shortcut = None
def __call__(self, x):
out = nn.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if self.shortcut is None:
out += x
else:
out += self.shortcut(x)
out = nn.relu(out)
return out
class ResNet(nn.Module):
"""
Creates a ResNet model for CIFAR-10, as specified in the original paper.
"""
def __init__(self, block, num_blocks, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.LayerNorm(16)
self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
self.linear = nn.Linear(64, num_classes)
def _make_layer(self, block, in_dims, dims, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(in_dims, dims, stride))
in_dims = dims
return nn.Sequential(*layers)
def num_params(self):
nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
return nparams
def __call__(self, x):
x = nn.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
x = self.linear(x)
return x
def resnet20(**kwargs):
return ResNet(Block, [3, 3, 3], **kwargs)
def resnet32(**kwargs):
return ResNet(Block, [5, 5, 5], **kwargs)
def resnet44(**kwargs):
return ResNet(Block, [7, 7, 7], **kwargs)
def resnet56(**kwargs):
return ResNet(Block, [9, 9, 9], **kwargs)
def resnet110(**kwargs):
return ResNet(Block, [18, 18, 18], **kwargs)
def resnet1202(**kwargs):
return ResNet(Block, [200, 200, 200], **kwargs)

View File

@ -315,7 +315,7 @@ def load_model(model_path):
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0: if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1] config["vocab_size"] = weights["output.weight"].shape[-1]
unused = ["multiple_of", "ffn_dim_multiplie"] unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
for k in unused: for k in unused:
if k in config: if k in config:
config.pop(k) config.pop(k)

View File

@ -24,7 +24,7 @@ tar -xf mistral-7B-v0.1.tar
``` ```
If you do not have access to the Llama weights you will need to [request If you do not have access to the Llama weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
from Meta. from Meta.
Convert the model with: Convert the model with:

View File

@ -2,6 +2,8 @@
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon. Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
This example also supports the instruction fine-tuned Mixtral model.[^instruct]
Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run. Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
### Setup ### Setup
@ -16,37 +18,56 @@ brew install git-lfs
Download the models from Hugging Face: Download the models from Hugging Face:
For the base model use:
``` ```
git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
``` ```
After that's done, combine the files: For the instruction fine-tuned model use:
``` ```
cd mixtral-8x7b-32kseqlen/ export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth ```
Then run:
```
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
cd $MIXTRAL_MODEL/ && \
git lfs pull --include "consolidated.*.pt" && \
git lfs pull --include "tokenizer.model"
``` ```
Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
MLX can read them: MLX can read them:
``` ```
python convert.py --model_path mixtral-8x7b-32kseqlen/ python convert.py --model_path $MIXTRAL_MODEL/
``` ```
The conversion script will save the converted weights in the same location. The conversion script will save the converted weights in the same location.
After that's done, if you want to clean some stuff up:
```
rm mixtral-8x7b-32kseqlen/*.pth*
```
### Generate ### Generate
As easy as: As easy as:
``` ```
python mixtral.py --model_path mixtral-8x7b-32kseqlen/ python mixtral.py --model_path $MIXTRAL_MODEL/
``` ```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. For more options including how to prompt the model, run:
```
python mixtral.py --help
```
For the Instruction model, make sure to follow the prompt format:
```
[INST] Instruction prompt [/INST]
```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
[^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
details

View File

@ -1,23 +1,55 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse import argparse
import glob
import json
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import torch import torch
def convert(k, v, config):
v = v.to(torch.float16).numpy()
if "block_sparse_moe" not in k:
return [(k, v)]
if "gate" in k:
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
# From: layers.N.block_sparse_moe.w
# To: layers.N.experts.M.w
num_experts = args["moe"]["num_experts"]
key_path = k.split(".")
v = np.split(v, num_experts, axis=0)
if key_path[-1] == "w2":
v = [u.T for u in v]
w_name = key_path.pop()
key_path[-1] = "feed_forward.experts"
return [
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
]
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.") parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
parser.add_argument( parser.add_argument(
"--model_path", "--model_path",
type=str, type=str,
default="mixtral-8x7b-32kseqlen/", default="Mixtral-8x7B-v0.1/",
help="The path to the Mixtral model. The MLX model weights will also be saved there.", help="The path to the Mixtral model. The MLX model weights will also be saved there.",
) )
args = parser.parse_args() args = parser.parse_args()
model_path = Path(args.model_path) model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pth"))
np.savez( with open("params.json") as fid:
str(model_path / "weights.npz"), args = json.load(fid)
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
) torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
for e, tf in enumerate(torch_files):
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
state = torch.load(tf)
new_state = {}
for k, v in state.items():
new_state.update(convert(k, v, args))
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)

View File

@ -2,6 +2,7 @@
import argparse import argparse
from dataclasses import dataclass from dataclasses import dataclass
import glob
import json import json
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
@ -40,6 +41,26 @@ class RMSNorm(nn.Module):
return self.weight * output return self.weight * output
class RoPE(nn.RoPE):
def __init__(self, dims: int, traditional: bool = False):
super().__init__(dims, traditional)
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, base=1000000, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return mx.reshape(rx, shape)
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
@ -56,7 +77,7 @@ class Attention(nn.Module):
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True) self.rope = RoPE(args.head_dim, traditional=True)
def __call__( def __call__(
self, self,
@ -125,7 +146,10 @@ class MOEFeedForward(nn.Module):
gates = self.gate(x) gates = self.gate(x)
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1) scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
y = [] y = []
for xt, st, it in zip(x, scores, inds.tolist()): for xt, st, it in zip(x, scores, inds.tolist()):
@ -181,8 +205,9 @@ class Mixtral(nn.Module):
h = self.tok_embeddings(inputs) h = self.tok_embeddings(inputs)
mask = None mask = None
if h.shape[1] > 1: T = h.shape[1]
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype) mask = mask.astype(h.dtype)
if cache is None: if cache is None:
@ -191,7 +216,7 @@ class Mixtral(nn.Module):
for e, layer in enumerate(self.layers): for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e]) h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache return self.output(self.norm(h[:, T - 1 : T, :])), cache
class Tokenizer: class Tokenizer:
@ -222,10 +247,13 @@ class Tokenizer:
def load_model(folder: str, dtype=mx.float16): def load_model(folder: str, dtype=mx.float16):
model_path = Path(folder) model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model")) tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f: with open("params.json", "r") as f:
config = json.loads(f.read()) config = json.loads(f.read())
model_args = ModelArgs(**config) model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz")) weight_files = glob.glob(str(model_path / "weights.*.npz"))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
weights = tree_unflatten(list(weights.items())) weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights) weights = tree_map(lambda p: p.astype(dtype), weights)
model = Mixtral(model_args) model = Mixtral(model_args)
@ -255,7 +283,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--model_path", "--model_path",
type=str, type=str,
default="mixtral-8x7b-32kseqlen", default="Mixtral-8x7B-v0.1",
help="The path to the model weights, tokenizer, and config", help="The path to the model weights, tokenizer, and config",
) )
parser.add_argument( parser.add_argument(
@ -274,7 +302,7 @@ if __name__ == "__main__":
"--temp", "--temp",
help="The sampling temperature.", help="The sampling temperature.",
type=float, type=float,
default=1.0, default=0.0,
) )
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

1
mixtral/params.json Normal file
View File

@ -0,0 +1 @@
{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}

1
phi2/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
weights.npz

57
phi2/README.md Normal file
View File

@ -0,0 +1,57 @@
# Phi-2
Phi-2 is a 2.7B parameter language model released by Microsoft with
performance that rivals much larger models.[^1] It was trained on a mixture of
GPT-4 outputs and clean web text.
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
precision.
## Setup
Download and convert the model:
```sh
python convert.py
```
This will make the `weights.npz` file which MLX can read.
## Generate
To generate text with the default prompt:
```sh
python phi2.py
```
Should give the output:
```
Answer: Mathematics is like a lighthouse that guides us through the darkness of
uncertainty. Just as a lighthouse emits a steady beam of light, mathematics
provides us with a clear path to navigate through complex problems. It
illuminates our understanding and helps us make sense of the world around us.
Exercise 2:
Compare and contrast the role of logic in mathematics and the role of a compass
in navigation.
Answer: Logic in mathematics is like a compass in navigation. It helps
```
To use your own prompt:
```sh
python phi2.py --prompt <your prompt here> --max_tokens <max_tokens_to_generate>
```
To see a list of options run:
```sh
python phi2.py --help
```
[^1]: For more details on the model see the [blog post](
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)

24
phi2/convert.py Normal file
View File

@ -0,0 +1,24 @@
from transformers import AutoModelForCausalLM
import numpy as np
def replace_key(key: str) -> str:
if "wte.weight" in key:
key = "wte.weight"
if ".mlp" in key:
key = key.replace(".mlp", "")
return key
def convert():
model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights)
if __name__ == "__main__":
convert()

222
phi2/phi2.py Normal file
View File

@ -0,0 +1,222 @@
import argparse
from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
import mlx.core as mx
import mlx.nn as nn
import math
@dataclass
class ModelArgs:
max_sequence_length: int = 2048
num_vocab: int = 51200
model_dim: int = 2560
num_heads: int = 32
num_layers: int = 32
rotary_dim: int = 32
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(rotary_dim, traditional=False)
self.Wqkv = nn.Linear(dims, 3 * dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x, mask=None, cache=None):
qkv = self.Wqkv(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.model_dim
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = LayerNorm(dims)
self.fc1 = nn.Linear(dims, mlp_dims)
self.fc2 = nn.Linear(mlp_dims, dims)
self.act = nn.GELU(approx="precise")
def __call__(self, x, mask, cache):
h = self.ln(x)
attn_h, cache = self.mixer(h, mask, cache)
ff_h = self.fc2(self.act(self.fc1(h)))
return attn_h + ff_h + x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
def __call__(self, x, mask, cache):
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
return x, cache
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
self.ln = LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
def __call__(self, inputs):
return self.linear(self.ln(inputs))
class Phi2(nn.Module):
def __init__(self, config: ModelArgs):
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
x = self.wte(inputs)
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt)
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache=cache)
y = sample(logits.squeeze(1))
yield y
def load_model():
model = Phi2(ModelArgs())
weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script")
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max_tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load_model()
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
prompt = mx.array(prompt)
print("[INFO] Generating with Phi-2...", flush=True)
print(args.prompt, end="", flush=True)
tokens = []
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
if eos_index is not None:
break
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)

5
phi2/requirements.txt Normal file
View File

@ -0,0 +1,5 @@
einops
mlx
numpy
transformers
torch

View File

@ -27,7 +27,7 @@ Usage
------ ------
Although each component in this repository can be used by itself, the fastest Although each component in this repository can be used by itself, the fastest
way to get started is by using the `StableDiffusion` class from the `diffusion` way to get started is by using the `StableDiffusion` class from the `stable_diffusion`
module. module.
```python ```python

View File

@ -1,7 +1,7 @@
# whisper # Whisper
Speech recognition with Whisper in MLX. Whisper is a set of open source speech Speech recognition with Whisper in MLX. Whisper is a set of open source speech
recognition models from Open AI, ranging from 39 million to 1.5 billion recognition models from OpenAI, ranging from 39 million to 1.5 billion
parameters[^1]. parameters[^1].
### Setup ### Setup
@ -15,7 +15,7 @@ pip install -r requirements.txt
Install [`ffmpeg`](https://ffmpeg.org/): Install [`ffmpeg`](https://ffmpeg.org/):
``` ```
# on MacOS using Homebrew (https://brew.sh/) # on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg brew install ffmpeg
``` ```

View File

@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase):
logits = mlx_model(mels, tokens) logits = mlx_model(mels, tokens)
self.assertEqual(logits.dtype, mx.float16) self.assertEqual(logits.dtype, mx.float16)
def test_decode_lang(self): def test_decode_lang(self):
options = decoding.DecodingOptions(task="lang_id", fp16=False) options = decoding.DecodingOptions(task="lang_id", fp16=False)
result = decoding.decode(self.model, self.mels, options) result = decoding.decode(self.model, self.mels, options)

View File

@ -112,7 +112,7 @@ class DecodingOptions:
max_initial_timestamp: Optional[float] = 1.0 max_initial_timestamp: Optional[float] = 1.0
# implementation details # implementation details
fp16: bool = True # use fp16 for most of the calculation fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = {
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00" "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
} }
@ -166,7 +166,8 @@ def convert(model, rules=None):
def torch_to_mlx( def torch_to_mlx(
torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16, torch_model: torch_whisper.Whisper,
dtype: mx.Dtype = mx.float16,
) -> whisper.Whisper: ) -> whisper.Whisper:
def convert_rblock(model, rules): def convert_rblock(model, rules):
children = dict(model.named_children()) children = dict(model.named_children())
@ -194,6 +195,6 @@ def torch_to_mlx(
def load_model( def load_model(
name: str, name: str,
download_root: str = None, download_root: str = None,
dtype : mx.Dtype = mx.float32, dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper: ) -> whisper.Whisper:
return torch_to_mlx(load_torch_model(name, download_root), dtype) return torch_to_mlx(load_torch_model(name, download_root), dtype)

View File

@ -43,7 +43,7 @@ class ModelHolder:
model_name = None model_name = None
@classmethod @classmethod
def get_model(cls, model: str, dtype : mx.Dtype): def get_model(cls, model: str, dtype: mx.Dtype):
if cls.model is None or model != cls.model_name: if cls.model is None or model != cls.model_name:
cls.model = load_model(model, dtype=dtype) cls.model = load_model(model, dtype=dtype)
cls.model_name = model cls.model_name = model

View File

@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :] scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1) return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
class LayerNorm(nn.LayerNorm): class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype) return super().__call__(x.astype(mx.float32)).astype(x.dtype)
@ -117,13 +118,19 @@ class ResidualAttentionBlock(nn.Module):
if self.cross_attn: if self.cross_attn:
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
x += y x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
return x, (kv, cross_kv) return x, (kv, cross_kv)
class AudioEncoder(nn.Module): class AudioEncoder(nn.Module):
def __init__( def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, self,
n_mels: int,
n_ctx: int,
n_state: int,
n_head: int,
n_layer: int,
dtype: mx.Dtype = mx.float16,
): ):
super().__init__() super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
@ -134,8 +141,8 @@ class AudioEncoder(nn.Module):
self.ln_post = LayerNorm(n_state) self.ln_post = LayerNorm(n_state)
def __call__(self, x): def __call__(self, x):
x = nn.gelu(self.conv1(x)) x = nn.gelu(self.conv1(x)).astype(x.dtype)
x = nn.gelu(self.conv2(x)) x = nn.gelu(self.conv2(x)).astype(x.dtype)
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
x = x + self._positional_embedding x = x + self._positional_embedding
@ -148,7 +155,13 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module): class TextDecoder(nn.Module):
def __init__( def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, self,
n_vocab: int,
n_ctx: int,
n_state: int,
n_head: int,
n_layer: int,
dtype: mx.Dtype = mx.float16,
): ):
super().__init__() super().__init__()
@ -160,7 +173,9 @@ class TextDecoder(nn.Module):
for _ in range(n_layer) for _ in range(n_layer)
] ]
self.ln = LayerNorm(n_state) self.ln = LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype) self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
dtype
)
def __call__(self, x, xa, kv_cache=None): def __call__(self, x, xa, kv_cache=None):
""" """