mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Merge branch 'ml-explore:main' into main
This commit is contained in:
commit
4c4317feda
@ -8,7 +8,7 @@ The [MNIST](mnist) example is a good starting point to learn how to use MLX.
|
|||||||
Some more useful examples include:
|
Some more useful examples include:
|
||||||
|
|
||||||
- [Transformer language model](transformer_lm) training.
|
- [Transformer language model](transformer_lm) training.
|
||||||
- Large scale text generation with [LLaMA](llama) or [Mistral](mistral).
|
- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).
|
||||||
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
|
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
|
||||||
- Parameter efficient fine-tuning with [LoRA](lora).
|
- Parameter efficient fine-tuning with [LoRA](lora).
|
||||||
- Generating images with [Stable Diffusion](stable_diffusion).
|
- Generating images with [Stable Diffusion](stable_diffusion).
|
||||||
|
@ -8,7 +8,7 @@ The `convert.py` script relies on `transformers` to download the weights, and ex
|
|||||||
|
|
||||||
```
|
```
|
||||||
python convert.py \
|
python convert.py \
|
||||||
--bert-model bert-base-uncased
|
--bert-model bert-base-uncased \
|
||||||
--mlx-model weights/bert-base-uncased.npz
|
--mlx-model weights/bert-base-uncased.npz
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import argparse
|
import argparse
|
||||||
import numpy
|
import numpy
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -35,74 +34,6 @@ model_configs = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
|
||||||
"""
|
|
||||||
Minor update to the MultiHeadAttention module to ensure that the
|
|
||||||
projections use bias.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dims: int,
|
|
||||||
num_heads: int,
|
|
||||||
query_input_dims: Optional[int] = None,
|
|
||||||
key_input_dims: Optional[int] = None,
|
|
||||||
value_input_dims: Optional[int] = None,
|
|
||||||
value_dims: Optional[int] = None,
|
|
||||||
value_output_dims: Optional[int] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if (dims % num_heads) != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
query_input_dims = query_input_dims or dims
|
|
||||||
key_input_dims = key_input_dims or dims
|
|
||||||
value_input_dims = value_input_dims or key_input_dims
|
|
||||||
value_dims = value_dims or dims
|
|
||||||
value_output_dims = value_output_dims or dims
|
|
||||||
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.query_proj = nn.Linear(query_input_dims, dims, True)
|
|
||||||
self.key_proj = nn.Linear(key_input_dims, dims, True)
|
|
||||||
self.value_proj = nn.Linear(value_input_dims, value_dims, True)
|
|
||||||
self.out_proj = nn.Linear(value_dims, value_output_dims, True)
|
|
||||||
|
|
||||||
def __call__(self, queries, keys, values, mask=None):
|
|
||||||
queries = self.query_proj(queries)
|
|
||||||
keys = self.key_proj(keys)
|
|
||||||
values = self.value_proj(values)
|
|
||||||
|
|
||||||
num_heads = self.num_heads
|
|
||||||
B, L, D = queries.shape
|
|
||||||
_, S, _ = keys.shape
|
|
||||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
|
||||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
|
|
||||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
|
||||||
scores = (queries * scale) @ keys
|
|
||||||
if mask is not None:
|
|
||||||
mask = self.convert_mask_to_additive_causal_mask(mask)
|
|
||||||
mask = mx.expand_dims(mask, (1, 2))
|
|
||||||
mask = mx.broadcast_to(mask, scores.shape)
|
|
||||||
scores = scores + mask.astype(scores.dtype)
|
|
||||||
scores = mx.softmax(scores, axis=-1)
|
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
|
|
||||||
return self.out_proj(values_hat)
|
|
||||||
|
|
||||||
def convert_mask_to_additive_causal_mask(
|
|
||||||
self, mask: mx.array, dtype: mx.Dtype = mx.float32
|
|
||||||
) -> mx.array:
|
|
||||||
mask = mask == 0
|
|
||||||
mask = mask.astype(dtype) * -1e9
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
A transformer encoder layer with (the original BERT) post-normalization.
|
A transformer encoder layer with (the original BERT) post-normalization.
|
||||||
@ -117,7 +48,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = mlp_dims or dims * 4
|
mlp_dims = mlp_dims or dims * 4
|
||||||
self.attention = MultiHeadAttention(dims, num_heads)
|
self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
|
||||||
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||||
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||||
self.linear1 = nn.Linear(dims, mlp_dims)
|
self.linear1 = nn.Linear(dims, mlp_dims)
|
||||||
@ -187,9 +118,15 @@ class Bert(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: mx.array,
|
input_ids: mx.array,
|
||||||
token_type_ids: mx.array,
|
token_type_ids: mx.array,
|
||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: mx.array = None,
|
||||||
) -> tuple[mx.array, mx.array]:
|
) -> tuple[mx.array, mx.array]:
|
||||||
x = self.embeddings(input_ids, token_type_ids)
|
x = self.embeddings(input_ids, token_type_ids)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# convert 0's to -infs, 1's to 0's, and make it broadcastable
|
||||||
|
attention_mask = mx.log(attention_mask)
|
||||||
|
attention_mask = mx.expand_dims(attention_mask, (1, 2))
|
||||||
|
|
||||||
y = self.encoder(x, attention_mask)
|
y = self.encoder(x, attention_mask)
|
||||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||||
|
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
mlx
|
mlx>=0.0.5
|
||||||
transformers
|
transformers
|
||||||
numpy
|
numpy
|
51
cifar/README.md
Normal file
51
cifar/README.md
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# CIFAR and ResNets
|
||||||
|
|
||||||
|
An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
|
||||||
|
configurations in accordance with the original
|
||||||
|
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
|
||||||
|
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
|
||||||
|
load the dataset.
|
||||||
|
|
||||||
|
## Pre-requisites
|
||||||
|
|
||||||
|
Install the dependencies:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the example
|
||||||
|
|
||||||
|
Run the example with:
|
||||||
|
|
||||||
|
```
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
By default the example runs on the GPU. To run on the CPU, use:
|
||||||
|
|
||||||
|
```
|
||||||
|
python main.py --cpu
|
||||||
|
```
|
||||||
|
|
||||||
|
For all available options, run:
|
||||||
|
|
||||||
|
```
|
||||||
|
python main.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
After training with the default `resnet20` architecture for 100 epochs, you
|
||||||
|
should see the following results:
|
||||||
|
|
||||||
|
```
|
||||||
|
Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
|
||||||
|
Epoch: 99 | Test acc 0.807
|
||||||
|
```
|
||||||
|
|
||||||
|
Note this was run on an M1 Macbook Pro with 16GB RAM.
|
||||||
|
|
||||||
|
At the time of writing, `mlx` doesn't have built-in learning rate schedules,
|
||||||
|
or a `BatchNorm` layer. We intend to update this example once these features
|
||||||
|
are added.
|
30
cifar/dataset.py
Normal file
30
cifar/dataset.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
from mlx.data.datasets import load_cifar10
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def get_cifar10(batch_size, root=None):
|
||||||
|
tr = load_cifar10(root=root)
|
||||||
|
|
||||||
|
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
||||||
|
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
||||||
|
|
||||||
|
def normalize(x):
|
||||||
|
x = x.astype("float32") / 255.0
|
||||||
|
return (x - mean) / std
|
||||||
|
|
||||||
|
tr_iter = (
|
||||||
|
tr.shuffle()
|
||||||
|
.to_stream()
|
||||||
|
.image_random_h_flip("image", prob=0.5)
|
||||||
|
.pad("image", 0, 4, 4, 0.0)
|
||||||
|
.pad("image", 1, 4, 4, 0.0)
|
||||||
|
.image_random_crop("image", 32, 32)
|
||||||
|
.key_transform("image", normalize)
|
||||||
|
.batch(batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
test = load_cifar10(root=root, train=False)
|
||||||
|
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
|
||||||
|
|
||||||
|
return tr_iter, test_iter
|
120
cifar/main.py
Normal file
120
cifar/main.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import resnet
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
from dataset import get_cifar10
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(add_help=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"--arch",
|
||||||
|
type=str,
|
||||||
|
default="resnet20",
|
||||||
|
choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]],
|
||||||
|
help="model architecture",
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
|
||||||
|
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="random seed")
|
||||||
|
parser.add_argument("--cpu", action="store_true", help="use cpu only")
|
||||||
|
|
||||||
|
|
||||||
|
def eval_fn(model, inp, tgt):
|
||||||
|
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(model, train_iter, optimizer, epoch):
|
||||||
|
def train_step(model, inp, tgt):
|
||||||
|
output = model(inp)
|
||||||
|
loss = mx.mean(nn.losses.cross_entropy(output, tgt))
|
||||||
|
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
|
||||||
|
return loss, acc
|
||||||
|
|
||||||
|
train_step_fn = nn.value_and_grad(model, train_step)
|
||||||
|
|
||||||
|
losses = []
|
||||||
|
accs = []
|
||||||
|
samples_per_sec = []
|
||||||
|
|
||||||
|
for batch_counter, batch in enumerate(train_iter):
|
||||||
|
x = mx.array(batch["image"])
|
||||||
|
y = mx.array(batch["label"])
|
||||||
|
tic = time.perf_counter()
|
||||||
|
(loss, acc), grads = train_step_fn(model, x, y)
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
mx.eval(model.parameters(), optimizer.state)
|
||||||
|
toc = time.perf_counter()
|
||||||
|
loss = loss.item()
|
||||||
|
acc = acc.item()
|
||||||
|
losses.append(loss)
|
||||||
|
accs.append(acc)
|
||||||
|
throughput = x.shape[0] / (toc - tic)
|
||||||
|
samples_per_sec.append(throughput)
|
||||||
|
if batch_counter % 10 == 0:
|
||||||
|
print(
|
||||||
|
" | ".join(
|
||||||
|
(
|
||||||
|
f"Epoch {epoch:02d} [{batch_counter:03d}]",
|
||||||
|
f"Train loss {loss:.3f}",
|
||||||
|
f"Train acc {acc:.3f}",
|
||||||
|
f"Throughput: {throughput:.2f} images/second",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mean_tr_loss = mx.mean(mx.array(losses))
|
||||||
|
mean_tr_acc = mx.mean(mx.array(accs))
|
||||||
|
samples_per_sec = mx.mean(mx.array(samples_per_sec))
|
||||||
|
return mean_tr_loss, mean_tr_acc, samples_per_sec
|
||||||
|
|
||||||
|
|
||||||
|
def test_epoch(model, test_iter, epoch):
|
||||||
|
accs = []
|
||||||
|
for batch_counter, batch in enumerate(test_iter):
|
||||||
|
x = mx.array(batch["image"])
|
||||||
|
y = mx.array(batch["label"])
|
||||||
|
acc = eval_fn(model, x, y)
|
||||||
|
acc_value = acc.item()
|
||||||
|
accs.append(acc_value)
|
||||||
|
mean_acc = mx.mean(mx.array(accs))
|
||||||
|
return mean_acc
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
|
model = getattr(resnet, args.arch)()
|
||||||
|
|
||||||
|
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
|
||||||
|
|
||||||
|
optimizer = optim.Adam(learning_rate=args.lr)
|
||||||
|
|
||||||
|
train_data, test_data = get_cifar10(args.batch_size)
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
|
||||||
|
print(
|
||||||
|
" | ".join(
|
||||||
|
(
|
||||||
|
f"Epoch: {epoch}",
|
||||||
|
f"avg. Train loss {tr_loss.item():.3f}",
|
||||||
|
f"avg. Train acc {tr_acc.item():.3f}",
|
||||||
|
f"Throughput: {throughput.item():.2f} images/sec",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_acc = test_epoch(model, test_data, epoch)
|
||||||
|
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
|
||||||
|
|
||||||
|
train_data.reset()
|
||||||
|
test_data.reset()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.cpu:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
main(args)
|
2
cifar/requirements.txt
Normal file
2
cifar/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
mlx
|
||||||
|
mlx-data
|
131
cifar/resnet.py
Normal file
131
cifar/resnet.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
|
||||||
|
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
|
||||||
|
|
||||||
|
There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ResNet",
|
||||||
|
"resnet20",
|
||||||
|
"resnet32",
|
||||||
|
"resnet44",
|
||||||
|
"resnet56",
|
||||||
|
"resnet110",
|
||||||
|
"resnet1202",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ShortcutA(nn.Module):
|
||||||
|
def __init__(self, dims):
|
||||||
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return mx.pad(
|
||||||
|
x[:, ::2, ::2, :],
|
||||||
|
pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements a ResNet block with two convolutional layers and a skip connection.
|
||||||
|
As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_dims, dims, stride=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.bn1 = nn.LayerNorm(dims)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
dims, dims, kernel_size=3, stride=1, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.bn2 = nn.LayerNorm(dims)
|
||||||
|
|
||||||
|
if stride != 1:
|
||||||
|
self.shortcut = ShortcutA(dims)
|
||||||
|
else:
|
||||||
|
self.shortcut = None
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
out = nn.relu(self.bn1(self.conv1(x)))
|
||||||
|
out = self.bn2(self.conv2(out))
|
||||||
|
if self.shortcut is None:
|
||||||
|
out += x
|
||||||
|
else:
|
||||||
|
out += self.shortcut(x)
|
||||||
|
out = nn.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
|
||||||
|
"""
|
||||||
|
Creates a ResNet model for CIFAR-10, as specified in the original paper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, block, num_blocks, num_classes=10):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.LayerNorm(16)
|
||||||
|
|
||||||
|
self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
|
||||||
|
self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
|
||||||
|
self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
|
||||||
|
|
||||||
|
self.linear = nn.Linear(64, num_classes)
|
||||||
|
|
||||||
|
def _make_layer(self, block, in_dims, dims, num_blocks, stride):
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
layers = []
|
||||||
|
for stride in strides:
|
||||||
|
layers.append(block(in_dims, dims, stride))
|
||||||
|
in_dims = dims
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def num_params(self):
|
||||||
|
nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
|
||||||
|
return nparams
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = nn.relu(self.bn1(self.conv1(x)))
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def resnet20(**kwargs):
|
||||||
|
return ResNet(Block, [3, 3, 3], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet32(**kwargs):
|
||||||
|
return ResNet(Block, [5, 5, 5], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet44(**kwargs):
|
||||||
|
return ResNet(Block, [7, 7, 7], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet56(**kwargs):
|
||||||
|
return ResNet(Block, [9, 9, 9], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet110(**kwargs):
|
||||||
|
return ResNet(Block, [18, 18, 18], **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet1202(**kwargs):
|
||||||
|
return ResNet(Block, [200, 200, 200], **kwargs)
|
@ -315,7 +315,7 @@ def load_model(model_path):
|
|||||||
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
||||||
if config.get("vocab_size", -1) < 0:
|
if config.get("vocab_size", -1) < 0:
|
||||||
config["vocab_size"] = weights["output.weight"].shape[-1]
|
config["vocab_size"] = weights["output.weight"].shape[-1]
|
||||||
unused = ["multiple_of", "ffn_dim_multiplie"]
|
unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
|
||||||
for k in unused:
|
for k in unused:
|
||||||
if k in config:
|
if k in config:
|
||||||
config.pop(k)
|
config.pop(k)
|
||||||
|
@ -24,7 +24,7 @@ tar -xf mistral-7B-v0.1.tar
|
|||||||
```
|
```
|
||||||
|
|
||||||
If you do not have access to the Llama weights you will need to [request
|
If you do not have access to the Llama weights you will need to [request
|
||||||
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
||||||
from Meta.
|
from Meta.
|
||||||
|
|
||||||
Convert the model with:
|
Convert the model with:
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
|
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
|
||||||
|
|
||||||
|
This example also supports the instruction fine-tuned Mixtral model.[^instruct]
|
||||||
|
|
||||||
Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
|
Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
@ -16,37 +18,56 @@ brew install git-lfs
|
|||||||
|
|
||||||
Download the models from Hugging Face:
|
Download the models from Hugging Face:
|
||||||
|
|
||||||
|
For the base model use:
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
|
export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
|
||||||
```
|
```
|
||||||
|
|
||||||
After that's done, combine the files:
|
For the instruction fine-tuned model use:
|
||||||
|
|
||||||
```
|
```
|
||||||
cd mixtral-8x7b-32kseqlen/
|
export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
|
||||||
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
|
```
|
||||||
|
|
||||||
|
Then run:
|
||||||
|
|
||||||
|
```
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
|
||||||
|
cd $MIXTRAL_MODEL/ && \
|
||||||
|
git lfs pull --include "consolidated.*.pt" && \
|
||||||
|
git lfs pull --include "tokenizer.model"
|
||||||
```
|
```
|
||||||
|
|
||||||
Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
|
Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
|
||||||
MLX can read them:
|
MLX can read them:
|
||||||
|
|
||||||
```
|
```
|
||||||
python convert.py --model_path mixtral-8x7b-32kseqlen/
|
python convert.py --model_path $MIXTRAL_MODEL/
|
||||||
```
|
```
|
||||||
|
|
||||||
The conversion script will save the converted weights in the same location.
|
The conversion script will save the converted weights in the same location.
|
||||||
|
|
||||||
After that's done, if you want to clean some stuff up:
|
|
||||||
|
|
||||||
```
|
|
||||||
rm mixtral-8x7b-32kseqlen/*.pth*
|
|
||||||
```
|
|
||||||
|
|
||||||
### Generate
|
### Generate
|
||||||
|
|
||||||
As easy as:
|
As easy as:
|
||||||
|
|
||||||
```
|
```
|
||||||
python mixtral.py --model_path mixtral-8x7b-32kseqlen/
|
python mixtral.py --model_path $MIXTRAL_MODEL/
|
||||||
```
|
```
|
||||||
|
|
||||||
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
|
For more options including how to prompt the model, run:
|
||||||
|
|
||||||
|
```
|
||||||
|
python mixtral.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
For the Instruction model, make sure to follow the prompt format:
|
||||||
|
|
||||||
|
```
|
||||||
|
[INST] Instruction prompt [/INST]
|
||||||
|
```
|
||||||
|
|
||||||
|
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
|
||||||
|
[^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
|
||||||
|
details
|
||||||
|
@ -1,23 +1,55 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def convert(k, v, config):
|
||||||
|
v = v.to(torch.float16).numpy()
|
||||||
|
if "block_sparse_moe" not in k:
|
||||||
|
return [(k, v)]
|
||||||
|
if "gate" in k:
|
||||||
|
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
|
||||||
|
|
||||||
|
# From: layers.N.block_sparse_moe.w
|
||||||
|
# To: layers.N.experts.M.w
|
||||||
|
num_experts = args["moe"]["num_experts"]
|
||||||
|
key_path = k.split(".")
|
||||||
|
v = np.split(v, num_experts, axis=0)
|
||||||
|
if key_path[-1] == "w2":
|
||||||
|
v = [u.T for u in v]
|
||||||
|
|
||||||
|
w_name = key_path.pop()
|
||||||
|
key_path[-1] = "feed_forward.experts"
|
||||||
|
return [
|
||||||
|
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
|
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path",
|
"--model_path",
|
||||||
type=str,
|
type=str,
|
||||||
default="mixtral-8x7b-32kseqlen/",
|
default="Mixtral-8x7B-v0.1/",
|
||||||
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
|
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
model_path = Path(args.model_path)
|
model_path = Path(args.model_path)
|
||||||
state = torch.load(str(model_path / "consolidated.00.pth"))
|
|
||||||
np.savez(
|
with open("params.json") as fid:
|
||||||
str(model_path / "weights.npz"),
|
args = json.load(fid)
|
||||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
|
|
||||||
)
|
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
|
||||||
|
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
|
||||||
|
for e, tf in enumerate(torch_files):
|
||||||
|
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
|
||||||
|
state = torch.load(tf)
|
||||||
|
new_state = {}
|
||||||
|
for k, v in state.items():
|
||||||
|
new_state.update(convert(k, v, args))
|
||||||
|
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -40,6 +41,26 @@ class RMSNorm(nn.Module):
|
|||||||
return self.weight * output
|
return self.weight * output
|
||||||
|
|
||||||
|
|
||||||
|
class RoPE(nn.RoPE):
|
||||||
|
def __init__(self, dims: int, traditional: bool = False):
|
||||||
|
super().__init__(dims, traditional)
|
||||||
|
|
||||||
|
def __call__(self, x, offset: int = 0):
|
||||||
|
shape = x.shape
|
||||||
|
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||||
|
N = x.shape[1] + offset
|
||||||
|
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||||
|
N, self.dims, offset=offset, base=1000000, dtype=x.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
rope = (
|
||||||
|
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||||
|
)
|
||||||
|
rx = rope(costheta, sintheta, x)
|
||||||
|
|
||||||
|
return mx.reshape(rx, shape)
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -56,7 +77,7 @@ class Attention(nn.Module):
|
|||||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||||
self.rope = nn.RoPE(args.head_dim, traditional=True)
|
self.rope = RoPE(args.head_dim, traditional=True)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -125,7 +146,10 @@ class MOEFeedForward(nn.Module):
|
|||||||
|
|
||||||
gates = self.gate(x)
|
gates = self.gate(x)
|
||||||
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||||
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
|
scores = mx.softmax(
|
||||||
|
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||||
|
axis=-1,
|
||||||
|
).astype(gates.dtype)
|
||||||
|
|
||||||
y = []
|
y = []
|
||||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||||
@ -181,8 +205,9 @@ class Mixtral(nn.Module):
|
|||||||
h = self.tok_embeddings(inputs)
|
h = self.tok_embeddings(inputs)
|
||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
if h.shape[1] > 1:
|
T = h.shape[1]
|
||||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
if T > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||||
mask = mask.astype(h.dtype)
|
mask = mask.astype(h.dtype)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
@ -191,7 +216,7 @@ class Mixtral(nn.Module):
|
|||||||
for e, layer in enumerate(self.layers):
|
for e, layer in enumerate(self.layers):
|
||||||
h, cache[e] = layer(h, mask, cache[e])
|
h, cache[e] = layer(h, mask, cache[e])
|
||||||
|
|
||||||
return self.output(self.norm(h)), cache
|
return self.output(self.norm(h[:, T - 1 : T, :])), cache
|
||||||
|
|
||||||
|
|
||||||
class Tokenizer:
|
class Tokenizer:
|
||||||
@ -222,10 +247,13 @@ class Tokenizer:
|
|||||||
def load_model(folder: str, dtype=mx.float16):
|
def load_model(folder: str, dtype=mx.float16):
|
||||||
model_path = Path(folder)
|
model_path = Path(folder)
|
||||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||||
with open(model_path / "params.json", "r") as f:
|
with open("params.json", "r") as f:
|
||||||
config = json.loads(f.read())
|
config = json.loads(f.read())
|
||||||
model_args = ModelArgs(**config)
|
model_args = ModelArgs(**config)
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
weight_files = glob.glob(str(model_path / "weights.*.npz"))
|
||||||
|
weights = {}
|
||||||
|
for wf in weight_files:
|
||||||
|
weights.update(mx.load(wf).items())
|
||||||
weights = tree_unflatten(list(weights.items()))
|
weights = tree_unflatten(list(weights.items()))
|
||||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||||
model = Mixtral(model_args)
|
model = Mixtral(model_args)
|
||||||
@ -255,7 +283,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path",
|
"--model_path",
|
||||||
type=str,
|
type=str,
|
||||||
default="mixtral-8x7b-32kseqlen",
|
default="Mixtral-8x7B-v0.1",
|
||||||
help="The path to the model weights, tokenizer, and config",
|
help="The path to the model weights, tokenizer, and config",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -274,7 +302,7 @@ if __name__ == "__main__":
|
|||||||
"--temp",
|
"--temp",
|
||||||
help="The sampling temperature.",
|
help="The sampling temperature.",
|
||||||
type=float,
|
type=float,
|
||||||
default=1.0,
|
default=0.0,
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
|
||||||
|
1
mixtral/params.json
Normal file
1
mixtral/params.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}
|
1
phi2/.gitignore
vendored
Normal file
1
phi2/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
weights.npz
|
57
phi2/README.md
Normal file
57
phi2/README.md
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# Phi-2
|
||||||
|
|
||||||
|
Phi-2 is a 2.7B parameter language model released by Microsoft with
|
||||||
|
performance that rivals much larger models.[^1] It was trained on a mixture of
|
||||||
|
GPT-4 outputs and clean web text.
|
||||||
|
|
||||||
|
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
|
||||||
|
precision.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Download and convert the model:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python convert.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This will make the `weights.npz` file which MLX can read.
|
||||||
|
|
||||||
|
## Generate
|
||||||
|
|
||||||
|
To generate text with the default prompt:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python phi2.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Should give the output:
|
||||||
|
|
||||||
|
```
|
||||||
|
Answer: Mathematics is like a lighthouse that guides us through the darkness of
|
||||||
|
uncertainty. Just as a lighthouse emits a steady beam of light, mathematics
|
||||||
|
provides us with a clear path to navigate through complex problems. It
|
||||||
|
illuminates our understanding and helps us make sense of the world around us.
|
||||||
|
|
||||||
|
Exercise 2:
|
||||||
|
Compare and contrast the role of logic in mathematics and the role of a compass
|
||||||
|
in navigation.
|
||||||
|
|
||||||
|
Answer: Logic in mathematics is like a compass in navigation. It helps
|
||||||
|
```
|
||||||
|
|
||||||
|
To use your own prompt:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python phi2.py --prompt <your prompt here> --max_tokens <max_tokens_to_generate>
|
||||||
|
```
|
||||||
|
|
||||||
|
To see a list of options run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python phi2.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
[^1]: For more details on the model see the [blog post](
|
||||||
|
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
|
||||||
|
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
|
24
phi2/convert.py
Normal file
24
phi2/convert.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def replace_key(key: str) -> str:
|
||||||
|
if "wte.weight" in key:
|
||||||
|
key = "wte.weight"
|
||||||
|
|
||||||
|
if ".mlp" in key:
|
||||||
|
key = key.replace(".mlp", "")
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def convert():
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
|
||||||
|
)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
||||||
|
np.savez("weights.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
convert()
|
222
phi2/phi2.py
Normal file
222
phi2/phi2.py
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
import argparse
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from mlx.utils import tree_unflatten
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs:
|
||||||
|
max_sequence_length: int = 2048
|
||||||
|
num_vocab: int = 51200
|
||||||
|
model_dim: int = 2560
|
||||||
|
num_heads: int = 32
|
||||||
|
num_layers: int = 32
|
||||||
|
rotary_dim: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RoPEAttention(nn.Module):
|
||||||
|
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||||
|
self.Wqkv = nn.Linear(dims, 3 * dims)
|
||||||
|
self.out_proj = nn.Linear(dims, dims)
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
qkv = self.Wqkv(x)
|
||||||
|
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||||
|
|
||||||
|
# Extract some shapes
|
||||||
|
num_heads = self.num_heads
|
||||||
|
B, L, D = queries.shape
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# Add RoPE to the queries and keys and combine them with the cache
|
||||||
|
if cache is not None:
|
||||||
|
key_cache, value_cache = cache
|
||||||
|
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||||
|
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||||
|
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||||
|
values = mx.concatenate([value_cache, values], axis=2)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
queries = queries.astype(mx.float32)
|
||||||
|
keys = keys.astype(mx.float32)
|
||||||
|
|
||||||
|
# Finally perform the attention computation
|
||||||
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
|
||||||
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelBlock(nn.Module):
|
||||||
|
def __init__(self, config: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
dims = config.model_dim
|
||||||
|
mlp_dims = dims * 4
|
||||||
|
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
|
||||||
|
self.ln = LayerNorm(dims)
|
||||||
|
self.fc1 = nn.Linear(dims, mlp_dims)
|
||||||
|
self.fc2 = nn.Linear(mlp_dims, dims)
|
||||||
|
self.act = nn.GELU(approx="precise")
|
||||||
|
|
||||||
|
def __call__(self, x, mask, cache):
|
||||||
|
h = self.ln(x)
|
||||||
|
attn_h, cache = self.mixer(h, mask, cache)
|
||||||
|
ff_h = self.fc2(self.act(self.fc1(h)))
|
||||||
|
return attn_h + ff_h + x, cache
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Module):
|
||||||
|
def __init__(self, config: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
|
||||||
|
|
||||||
|
def __call__(self, x, mask, cache):
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.h)
|
||||||
|
|
||||||
|
for e, layer in enumerate(self.h):
|
||||||
|
x, cache[e] = layer(x, mask, cache[e])
|
||||||
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
|
class OutputHead(nn.Module):
|
||||||
|
def __init__(self, config: ModelArgs) -> None:
|
||||||
|
self.ln = LayerNorm(config.model_dim)
|
||||||
|
self.linear = nn.Linear(config.model_dim, config.num_vocab)
|
||||||
|
|
||||||
|
def __call__(self, inputs):
|
||||||
|
return self.linear(self.ln(inputs))
|
||||||
|
|
||||||
|
|
||||||
|
class Phi2(nn.Module):
|
||||||
|
def __init__(self, config: ModelArgs):
|
||||||
|
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
|
||||||
|
self.transformer = TransformerDecoder(config)
|
||||||
|
self.lm_head = OutputHead(config)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
|
cache: mx.array = None,
|
||||||
|
) -> tuple[mx.array, mx.array]:
|
||||||
|
x = self.wte(inputs)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if x.shape[1] > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
|
y, cache = self.transformer(x, mask, cache)
|
||||||
|
return self.lm_head(y), cache
|
||||||
|
|
||||||
|
|
||||||
|
def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
|
||||||
|
def sample(logits):
|
||||||
|
if temp == 0:
|
||||||
|
return mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
logits, cache = model(prompt)
|
||||||
|
y = sample(logits[:, -1, :])
|
||||||
|
yield y
|
||||||
|
|
||||||
|
while True:
|
||||||
|
logits, cache = model(y[:, None], cache=cache)
|
||||||
|
y = sample(logits.squeeze(1))
|
||||||
|
yield y
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
model = Phi2(ModelArgs())
|
||||||
|
weights = mx.load("weights.npz")
|
||||||
|
model.update(tree_unflatten(list(weights.items())))
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
help="The message to be processed by the model",
|
||||||
|
default="Write a detailed analogy between mathematics and a lighthouse.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temp",
|
||||||
|
help="The sampling temperature.",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
|
model, tokenizer = load_model()
|
||||||
|
|
||||||
|
prompt = tokenizer(
|
||||||
|
args.prompt,
|
||||||
|
return_tensors="np",
|
||||||
|
return_attention_mask=False,
|
||||||
|
)["input_ids"]
|
||||||
|
|
||||||
|
prompt = mx.array(prompt)
|
||||||
|
|
||||||
|
print("[INFO] Generating with Phi-2...", flush=True)
|
||||||
|
print(args.prompt, end="", flush=True)
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
|
if (len(tokens) % 10) == 0:
|
||||||
|
mx.eval(tokens)
|
||||||
|
eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
|
||||||
|
|
||||||
|
if eos_index is not None:
|
||||||
|
tokens = tokens[:eos_index]
|
||||||
|
|
||||||
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
|
print(s, end="", flush=True)
|
||||||
|
tokens = []
|
||||||
|
if eos_index is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
mx.eval(tokens)
|
||||||
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
|
print(s, flush=True)
|
5
phi2/requirements.txt
Normal file
5
phi2/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
einops
|
||||||
|
mlx
|
||||||
|
numpy
|
||||||
|
transformers
|
||||||
|
torch
|
@ -27,7 +27,7 @@ Usage
|
|||||||
------
|
------
|
||||||
|
|
||||||
Although each component in this repository can be used by itself, the fastest
|
Although each component in this repository can be used by itself, the fastest
|
||||||
way to get started is by using the `StableDiffusion` class from the `diffusion`
|
way to get started is by using the `StableDiffusion` class from the `stable_diffusion`
|
||||||
module.
|
module.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# whisper
|
# Whisper
|
||||||
|
|
||||||
Speech recognition with Whisper in MLX. Whisper is a set of open source speech
|
Speech recognition with Whisper in MLX. Whisper is a set of open source speech
|
||||||
recognition models from OpenAI, ranging from 39 million to 1.5 billion
|
recognition models from OpenAI, ranging from 39 million to 1.5 billion
|
||||||
@ -15,7 +15,7 @@ pip install -r requirements.txt
|
|||||||
Install [`ffmpeg`](https://ffmpeg.org/):
|
Install [`ffmpeg`](https://ffmpeg.org/):
|
||||||
|
|
||||||
```
|
```
|
||||||
# on MacOS using Homebrew (https://brew.sh/)
|
# on macOS using Homebrew (https://brew.sh/)
|
||||||
brew install ffmpeg
|
brew install ffmpeg
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase):
|
|||||||
logits = mlx_model(mels, tokens)
|
logits = mlx_model(mels, tokens)
|
||||||
self.assertEqual(logits.dtype, mx.float16)
|
self.assertEqual(logits.dtype, mx.float16)
|
||||||
|
|
||||||
|
|
||||||
def test_decode_lang(self):
|
def test_decode_lang(self):
|
||||||
options = decoding.DecodingOptions(task="lang_id", fp16=False)
|
options = decoding.DecodingOptions(task="lang_id", fp16=False)
|
||||||
result = decoding.decode(self.model, self.mels, options)
|
result = decoding.decode(self.model, self.mels, options)
|
||||||
|
@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = {
|
|||||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
|
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -166,7 +166,8 @@ def convert(model, rules=None):
|
|||||||
|
|
||||||
|
|
||||||
def torch_to_mlx(
|
def torch_to_mlx(
|
||||||
torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
|
torch_model: torch_whisper.Whisper,
|
||||||
|
dtype: mx.Dtype = mx.float16,
|
||||||
) -> whisper.Whisper:
|
) -> whisper.Whisper:
|
||||||
def convert_rblock(model, rules):
|
def convert_rblock(model, rules):
|
||||||
children = dict(model.named_children())
|
children = dict(model.named_children())
|
||||||
|
@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000):
|
|||||||
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
|
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
|
||||||
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
|
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.LayerNorm):
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||||
@ -117,13 +118,19 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
|
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
|
||||||
x += y
|
x += y
|
||||||
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
|
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
|
||||||
return x, (kv, cross_kv)
|
return x, (kv, cross_kv)
|
||||||
|
|
||||||
|
|
||||||
class AudioEncoder(nn.Module):
|
class AudioEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
|
self,
|
||||||
|
n_mels: int,
|
||||||
|
n_ctx: int,
|
||||||
|
n_state: int,
|
||||||
|
n_head: int,
|
||||||
|
n_layer: int,
|
||||||
|
dtype: mx.Dtype = mx.float16,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||||
@ -134,8 +141,8 @@ class AudioEncoder(nn.Module):
|
|||||||
self.ln_post = LayerNorm(n_state)
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
x = nn.gelu(self.conv1(x))
|
x = nn.gelu(self.conv1(x)).astype(x.dtype)
|
||||||
x = nn.gelu(self.conv2(x))
|
x = nn.gelu(self.conv2(x)).astype(x.dtype)
|
||||||
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
|
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
|
||||||
x = x + self._positional_embedding
|
x = x + self._positional_embedding
|
||||||
|
|
||||||
@ -148,7 +155,13 @@ class AudioEncoder(nn.Module):
|
|||||||
|
|
||||||
class TextDecoder(nn.Module):
|
class TextDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
|
self,
|
||||||
|
n_vocab: int,
|
||||||
|
n_ctx: int,
|
||||||
|
n_state: int,
|
||||||
|
n_head: int,
|
||||||
|
n_layer: int,
|
||||||
|
dtype: mx.Dtype = mx.float16,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -160,7 +173,9 @@ class TextDecoder(nn.Module):
|
|||||||
for _ in range(n_layer)
|
for _ in range(n_layer)
|
||||||
]
|
]
|
||||||
self.ln = LayerNorm(n_state)
|
self.ln = LayerNorm(n_state)
|
||||||
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
|
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x, xa, kv_cache=None):
|
def __call__(self, x, xa, kv_cache=None):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user