mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Merge branch 'ml-explore:main' into main
This commit is contained in:
commit
4c4317feda
@ -8,7 +8,7 @@ The [MNIST](mnist) example is a good starting point to learn how to use MLX.
|
||||
Some more useful examples include:
|
||||
|
||||
- [Transformer language model](transformer_lm) training.
|
||||
- Large scale text generation with [LLaMA](llama) or [Mistral](mistral).
|
||||
- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).
|
||||
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
|
||||
- Parameter efficient fine-tuning with [LoRA](lora).
|
||||
- Generating images with [Stable Diffusion](stable_diffusion).
|
||||
|
@ -8,7 +8,7 @@ The `convert.py` script relies on `transformers` to download the weights, and ex
|
||||
|
||||
```
|
||||
python convert.py \
|
||||
--bert-model bert-base-uncased
|
||||
--bert-model bert-base-uncased \
|
||||
--mlx-model weights/bert-base-uncased.npz
|
||||
```
|
||||
|
||||
|
@ -8,7 +8,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import argparse
|
||||
import numpy
|
||||
import math
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -35,74 +34,6 @@ model_configs = {
|
||||
}
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
Minor update to the MultiHeadAttention module to ensure that the
|
||||
projections use bias.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
num_heads: int,
|
||||
query_input_dims: Optional[int] = None,
|
||||
key_input_dims: Optional[int] = None,
|
||||
value_input_dims: Optional[int] = None,
|
||||
value_dims: Optional[int] = None,
|
||||
value_output_dims: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if (dims % num_heads) != 0:
|
||||
raise ValueError(
|
||||
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
|
||||
)
|
||||
|
||||
query_input_dims = query_input_dims or dims
|
||||
key_input_dims = key_input_dims or dims
|
||||
value_input_dims = value_input_dims or key_input_dims
|
||||
value_dims = value_dims or dims
|
||||
value_output_dims = value_output_dims or dims
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.query_proj = nn.Linear(query_input_dims, dims, True)
|
||||
self.key_proj = nn.Linear(key_input_dims, dims, True)
|
||||
self.value_proj = nn.Linear(value_input_dims, value_dims, True)
|
||||
self.out_proj = nn.Linear(value_dims, value_output_dims, True)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys
|
||||
if mask is not None:
|
||||
mask = self.convert_mask_to_additive_causal_mask(mask)
|
||||
mask = mx.expand_dims(mask, (1, 2))
|
||||
mask = mx.broadcast_to(mask, scores.shape)
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat)
|
||||
|
||||
def convert_mask_to_additive_causal_mask(
|
||||
self, mask: mx.array, dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
mask = mask == 0
|
||||
mask = mask.astype(dtype) * -1e9
|
||||
return mask
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
A transformer encoder layer with (the original BERT) post-normalization.
|
||||
@ -117,7 +48,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
mlp_dims = mlp_dims or dims * 4
|
||||
self.attention = MultiHeadAttention(dims, num_heads)
|
||||
self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
|
||||
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||
self.linear1 = nn.Linear(dims, mlp_dims)
|
||||
@ -187,9 +118,15 @@ class Bert(nn.Module):
|
||||
self,
|
||||
input_ids: mx.array,
|
||||
token_type_ids: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
attention_mask: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
x = self.embeddings(input_ids, token_type_ids)
|
||||
|
||||
if attention_mask is not None:
|
||||
# convert 0's to -infs, 1's to 0's, and make it broadcastable
|
||||
attention_mask = mx.log(attention_mask)
|
||||
attention_mask = mx.expand_dims(attention_mask, (1, 2))
|
||||
|
||||
y = self.encoder(x, attention_mask)
|
||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||
|
||||
|
@ -1,3 +1,3 @@
|
||||
mlx
|
||||
mlx>=0.0.5
|
||||
transformers
|
||||
numpy
|
||||
numpy
|
||||
|
51
cifar/README.md
Normal file
51
cifar/README.md
Normal file
@ -0,0 +1,51 @@
|
||||
# CIFAR and ResNets
|
||||
|
||||
An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
|
||||
configurations in accordance with the original
|
||||
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
|
||||
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
|
||||
load the dataset.
|
||||
|
||||
## Pre-requisites
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Running the example
|
||||
|
||||
Run the example with:
|
||||
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
By default the example runs on the GPU. To run on the CPU, use:
|
||||
|
||||
```
|
||||
python main.py --cpu
|
||||
```
|
||||
|
||||
For all available options, run:
|
||||
|
||||
```
|
||||
python main.py --help
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
After training with the default `resnet20` architecture for 100 epochs, you
|
||||
should see the following results:
|
||||
|
||||
```
|
||||
Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec
|
||||
Epoch: 99 | Test acc 0.807
|
||||
```
|
||||
|
||||
Note this was run on an M1 Macbook Pro with 16GB RAM.
|
||||
|
||||
At the time of writing, `mlx` doesn't have built-in learning rate schedules,
|
||||
or a `BatchNorm` layer. We intend to update this example once these features
|
||||
are added.
|
30
cifar/dataset.py
Normal file
30
cifar/dataset.py
Normal file
@ -0,0 +1,30 @@
|
||||
import mlx.core as mx
|
||||
from mlx.data.datasets import load_cifar10
|
||||
import math
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root)
|
||||
|
||||
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
||||
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
||||
|
||||
def normalize(x):
|
||||
x = x.astype("float32") / 255.0
|
||||
return (x - mean) / std
|
||||
|
||||
tr_iter = (
|
||||
tr.shuffle()
|
||||
.to_stream()
|
||||
.image_random_h_flip("image", prob=0.5)
|
||||
.pad("image", 0, 4, 4, 0.0)
|
||||
.pad("image", 1, 4, 4, 0.0)
|
||||
.image_random_crop("image", 32, 32)
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
|
||||
|
||||
return tr_iter, test_iter
|
120
cifar/main.py
Normal file
120
cifar/main.py
Normal file
@ -0,0 +1,120 @@
|
||||
import argparse
|
||||
import time
|
||||
import resnet
|
||||
import mlx.nn as nn
|
||||
import mlx.core as mx
|
||||
import mlx.optimizers as optim
|
||||
from dataset import get_cifar10
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(add_help=True)
|
||||
parser.add_argument(
|
||||
"--arch",
|
||||
type=str,
|
||||
default="resnet20",
|
||||
choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]],
|
||||
help="model architecture",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
|
||||
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
|
||||
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
|
||||
parser.add_argument("--seed", type=int, default=0, help="random seed")
|
||||
parser.add_argument("--cpu", action="store_true", help="use cpu only")
|
||||
|
||||
|
||||
def eval_fn(model, inp, tgt):
|
||||
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
|
||||
|
||||
|
||||
def train_epoch(model, train_iter, optimizer, epoch):
|
||||
def train_step(model, inp, tgt):
|
||||
output = model(inp)
|
||||
loss = mx.mean(nn.losses.cross_entropy(output, tgt))
|
||||
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
|
||||
return loss, acc
|
||||
|
||||
train_step_fn = nn.value_and_grad(model, train_step)
|
||||
|
||||
losses = []
|
||||
accs = []
|
||||
samples_per_sec = []
|
||||
|
||||
for batch_counter, batch in enumerate(train_iter):
|
||||
x = mx.array(batch["image"])
|
||||
y = mx.array(batch["label"])
|
||||
tic = time.perf_counter()
|
||||
(loss, acc), grads = train_step_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
toc = time.perf_counter()
|
||||
loss = loss.item()
|
||||
acc = acc.item()
|
||||
losses.append(loss)
|
||||
accs.append(acc)
|
||||
throughput = x.shape[0] / (toc - tic)
|
||||
samples_per_sec.append(throughput)
|
||||
if batch_counter % 10 == 0:
|
||||
print(
|
||||
" | ".join(
|
||||
(
|
||||
f"Epoch {epoch:02d} [{batch_counter:03d}]",
|
||||
f"Train loss {loss:.3f}",
|
||||
f"Train acc {acc:.3f}",
|
||||
f"Throughput: {throughput:.2f} images/second",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
mean_tr_loss = mx.mean(mx.array(losses))
|
||||
mean_tr_acc = mx.mean(mx.array(accs))
|
||||
samples_per_sec = mx.mean(mx.array(samples_per_sec))
|
||||
return mean_tr_loss, mean_tr_acc, samples_per_sec
|
||||
|
||||
|
||||
def test_epoch(model, test_iter, epoch):
|
||||
accs = []
|
||||
for batch_counter, batch in enumerate(test_iter):
|
||||
x = mx.array(batch["image"])
|
||||
y = mx.array(batch["label"])
|
||||
acc = eval_fn(model, x, y)
|
||||
acc_value = acc.item()
|
||||
accs.append(acc_value)
|
||||
mean_acc = mx.mean(mx.array(accs))
|
||||
return mean_acc
|
||||
|
||||
|
||||
def main(args):
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model = getattr(resnet, args.arch)()
|
||||
|
||||
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
|
||||
|
||||
optimizer = optim.Adam(learning_rate=args.lr)
|
||||
|
||||
train_data, test_data = get_cifar10(args.batch_size)
|
||||
for epoch in range(args.epochs):
|
||||
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
|
||||
print(
|
||||
" | ".join(
|
||||
(
|
||||
f"Epoch: {epoch}",
|
||||
f"avg. Train loss {tr_loss.item():.3f}",
|
||||
f"avg. Train acc {tr_acc.item():.3f}",
|
||||
f"Throughput: {throughput.item():.2f} images/sec",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
test_acc = test_epoch(model, test_data, epoch)
|
||||
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
|
||||
|
||||
train_data.reset()
|
||||
test_data.reset()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.cpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
main(args)
|
2
cifar/requirements.txt
Normal file
2
cifar/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
||||
mlx
|
||||
mlx-data
|
131
cifar/resnet.py
Normal file
131
cifar/resnet.py
Normal file
@ -0,0 +1,131 @@
|
||||
"""
|
||||
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
|
||||
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
|
||||
|
||||
There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ResNet",
|
||||
"resnet20",
|
||||
"resnet32",
|
||||
"resnet44",
|
||||
"resnet56",
|
||||
"resnet110",
|
||||
"resnet1202",
|
||||
]
|
||||
|
||||
|
||||
class ShortcutA(nn.Module):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.pad(
|
||||
x[:, ::2, ::2, :],
|
||||
pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)],
|
||||
)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""
|
||||
Implements a ResNet block with two convolutional layers and a skip connection.
|
||||
As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details)
|
||||
"""
|
||||
|
||||
def __init__(self, in_dims, dims, stride=1):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
|
||||
)
|
||||
self.bn1 = nn.LayerNorm(dims)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
dims, dims, kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
self.bn2 = nn.LayerNorm(dims)
|
||||
|
||||
if stride != 1:
|
||||
self.shortcut = ShortcutA(dims)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
def __call__(self, x):
|
||||
out = nn.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
if self.shortcut is None:
|
||||
out += x
|
||||
else:
|
||||
out += self.shortcut(x)
|
||||
out = nn.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
"""
|
||||
Creates a ResNet model for CIFAR-10, as specified in the original paper.
|
||||
"""
|
||||
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.LayerNorm(16)
|
||||
|
||||
self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)
|
||||
|
||||
self.linear = nn.Linear(64, num_classes)
|
||||
|
||||
def _make_layer(self, block, in_dims, dims, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(in_dims, dims, stride))
|
||||
in_dims = dims
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def num_params(self):
|
||||
nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
|
||||
return nparams
|
||||
|
||||
def __call__(self, x):
|
||||
x = nn.relu(self.bn1(self.conv1(x)))
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def resnet20(**kwargs):
|
||||
return ResNet(Block, [3, 3, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet32(**kwargs):
|
||||
return ResNet(Block, [5, 5, 5], **kwargs)
|
||||
|
||||
|
||||
def resnet44(**kwargs):
|
||||
return ResNet(Block, [7, 7, 7], **kwargs)
|
||||
|
||||
|
||||
def resnet56(**kwargs):
|
||||
return ResNet(Block, [9, 9, 9], **kwargs)
|
||||
|
||||
|
||||
def resnet110(**kwargs):
|
||||
return ResNet(Block, [18, 18, 18], **kwargs)
|
||||
|
||||
|
||||
def resnet1202(**kwargs):
|
||||
return ResNet(Block, [200, 200, 200], **kwargs)
|
@ -315,7 +315,7 @@ def load_model(model_path):
|
||||
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
||||
if config.get("vocab_size", -1) < 0:
|
||||
config["vocab_size"] = weights["output.weight"].shape[-1]
|
||||
unused = ["multiple_of", "ffn_dim_multiplie"]
|
||||
unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"]
|
||||
for k in unused:
|
||||
if k in config:
|
||||
config.pop(k)
|
||||
|
@ -24,7 +24,7 @@ tar -xf mistral-7B-v0.1.tar
|
||||
```
|
||||
|
||||
If you do not have access to the Llama weights you will need to [request
|
||||
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
||||
access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
||||
from Meta.
|
||||
|
||||
Convert the model with:
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.
|
||||
|
||||
This example also supports the instruction fine-tuned Mixtral model.[^instruct]
|
||||
|
||||
Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run.
|
||||
|
||||
### Setup
|
||||
@ -16,37 +18,56 @@ brew install git-lfs
|
||||
|
||||
Download the models from Hugging Face:
|
||||
|
||||
For the base model use:
|
||||
|
||||
```
|
||||
git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
|
||||
export MIXTRAL_MODEL=Mixtral-8x7B-v0.1
|
||||
```
|
||||
|
||||
After that's done, combine the files:
|
||||
For the instruction fine-tuned model use:
|
||||
|
||||
```
|
||||
cd mixtral-8x7b-32kseqlen/
|
||||
cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth
|
||||
export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1
|
||||
```
|
||||
|
||||
Then run:
|
||||
|
||||
```
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/
|
||||
cd $MIXTRAL_MODEL/ && \
|
||||
git lfs pull --include "consolidated.*.pt" && \
|
||||
git lfs pull --include "tokenizer.model"
|
||||
```
|
||||
|
||||
Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so
|
||||
MLX can read them:
|
||||
|
||||
```
|
||||
python convert.py --model_path mixtral-8x7b-32kseqlen/
|
||||
python convert.py --model_path $MIXTRAL_MODEL/
|
||||
```
|
||||
|
||||
The conversion script will save the converted weights in the same location.
|
||||
|
||||
After that's done, if you want to clean some stuff up:
|
||||
|
||||
```
|
||||
rm mixtral-8x7b-32kseqlen/*.pth*
|
||||
```
|
||||
|
||||
### Generate
|
||||
|
||||
As easy as:
|
||||
|
||||
```
|
||||
python mixtral.py --model_path mixtral-8x7b-32kseqlen/
|
||||
python mixtral.py --model_path $MIXTRAL_MODEL/
|
||||
```
|
||||
|
||||
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.
|
||||
For more options including how to prompt the model, run:
|
||||
|
||||
```
|
||||
python mixtral.py --help
|
||||
```
|
||||
|
||||
For the Instruction model, make sure to follow the prompt format:
|
||||
|
||||
```
|
||||
[INST] Instruction prompt [/INST]
|
||||
```
|
||||
|
||||
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details.
|
||||
[^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more
|
||||
details
|
||||
|
@ -1,23 +1,55 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
|
||||
def convert(k, v, config):
|
||||
v = v.to(torch.float16).numpy()
|
||||
if "block_sparse_moe" not in k:
|
||||
return [(k, v)]
|
||||
if "gate" in k:
|
||||
return [(k.replace("block_sparse_moe", "feed_forward"), v)]
|
||||
|
||||
# From: layers.N.block_sparse_moe.w
|
||||
# To: layers.N.experts.M.w
|
||||
num_experts = args["moe"]["num_experts"]
|
||||
key_path = k.split(".")
|
||||
v = np.split(v, num_experts, axis=0)
|
||||
if key_path[-1] == "w2":
|
||||
v = [u.T for u in v]
|
||||
|
||||
w_name = key_path.pop()
|
||||
key_path[-1] = "feed_forward.experts"
|
||||
return [
|
||||
(".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v)
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mixtral-8x7b-32kseqlen/",
|
||||
default="Mixtral-8x7B-v0.1/",
|
||||
help="The path to the Mixtral model. The MLX model weights will also be saved there.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
model_path = Path(args.model_path)
|
||||
state = torch.load(str(model_path / "consolidated.00.pth"))
|
||||
np.savez(
|
||||
str(model_path / "weights.npz"),
|
||||
**{k: v.to(torch.float16).numpy() for k, v in state.items()},
|
||||
)
|
||||
|
||||
with open("params.json") as fid:
|
||||
args = json.load(fid)
|
||||
|
||||
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
|
||||
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
|
||||
for e, tf in enumerate(torch_files):
|
||||
print(f"[INFO] Converting file {e + 1}/{len(torch_files)}")
|
||||
state = torch.load(tf)
|
||||
new_state = {}
|
||||
for k, v in state.items():
|
||||
new_state.update(convert(k, v, args))
|
||||
np.savez(str(model_path / f"weights.{e}.npz"), **new_state)
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
@ -40,6 +41,26 @@ class RMSNorm(nn.Module):
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class RoPE(nn.RoPE):
|
||||
def __init__(self, dims: int, traditional: bool = False):
|
||||
super().__init__(dims, traditional)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, base=1000000, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||
)
|
||||
rx = rope(costheta, sintheta, x)
|
||||
|
||||
return mx.reshape(rx, shape)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -56,7 +77,7 @@ class Attention(nn.Module):
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = nn.RoPE(args.head_dim, traditional=True)
|
||||
self.rope = RoPE(args.head_dim, traditional=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -125,7 +146,10 @@ class MOEFeedForward(nn.Module):
|
||||
|
||||
gates = self.gate(x)
|
||||
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
|
||||
scores = mx.softmax(
|
||||
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||
axis=-1,
|
||||
).astype(gates.dtype)
|
||||
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
@ -181,8 +205,9 @@ class Mixtral(nn.Module):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
@ -191,7 +216,7 @@ class Mixtral(nn.Module):
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.output(self.norm(h)), cache
|
||||
return self.output(self.norm(h[:, T - 1 : T, :])), cache
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
@ -222,10 +247,13 @@ class Tokenizer:
|
||||
def load_model(folder: str, dtype=mx.float16):
|
||||
model_path = Path(folder)
|
||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
with open("params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
model_args = ModelArgs(**config)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
weight_files = glob.glob(str(model_path / "weights.*.npz"))
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model = Mixtral(model_args)
|
||||
@ -255,7 +283,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mixtral-8x7b-32kseqlen",
|
||||
default="Mixtral-8x7B-v0.1",
|
||||
help="The path to the model weights, tokenizer, and config",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -274,7 +302,7 @@ if __name__ == "__main__":
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=1.0,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
|
1
mixtral/params.json
Normal file
1
mixtral/params.json
Normal file
@ -0,0 +1 @@
|
||||
{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}}
|
1
phi2/.gitignore
vendored
Normal file
1
phi2/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
weights.npz
|
57
phi2/README.md
Normal file
57
phi2/README.md
Normal file
@ -0,0 +1,57 @@
|
||||
# Phi-2
|
||||
|
||||
Phi-2 is a 2.7B parameter language model released by Microsoft with
|
||||
performance that rivals much larger models.[^1] It was trained on a mixture of
|
||||
GPT-4 outputs and clean web text.
|
||||
|
||||
Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit
|
||||
precision.
|
||||
|
||||
## Setup
|
||||
|
||||
Download and convert the model:
|
||||
|
||||
```sh
|
||||
python convert.py
|
||||
```
|
||||
|
||||
This will make the `weights.npz` file which MLX can read.
|
||||
|
||||
## Generate
|
||||
|
||||
To generate text with the default prompt:
|
||||
|
||||
```sh
|
||||
python phi2.py
|
||||
```
|
||||
|
||||
Should give the output:
|
||||
|
||||
```
|
||||
Answer: Mathematics is like a lighthouse that guides us through the darkness of
|
||||
uncertainty. Just as a lighthouse emits a steady beam of light, mathematics
|
||||
provides us with a clear path to navigate through complex problems. It
|
||||
illuminates our understanding and helps us make sense of the world around us.
|
||||
|
||||
Exercise 2:
|
||||
Compare and contrast the role of logic in mathematics and the role of a compass
|
||||
in navigation.
|
||||
|
||||
Answer: Logic in mathematics is like a compass in navigation. It helps
|
||||
```
|
||||
|
||||
To use your own prompt:
|
||||
|
||||
```sh
|
||||
python phi2.py --prompt <your prompt here> --max_tokens <max_tokens_to_generate>
|
||||
```
|
||||
|
||||
To see a list of options run:
|
||||
|
||||
```sh
|
||||
python phi2.py --help
|
||||
```
|
||||
|
||||
[^1]: For more details on the model see the [blog post](
|
||||
https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/)
|
||||
and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2)
|
24
phi2/convert.py
Normal file
24
phi2/convert.py
Normal file
@ -0,0 +1,24 @@
|
||||
from transformers import AutoModelForCausalLM
|
||||
import numpy as np
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
if "wte.weight" in key:
|
||||
key = "wte.weight"
|
||||
|
||||
if ".mlp" in key:
|
||||
key = key.replace(".mlp", "")
|
||||
return key
|
||||
|
||||
|
||||
def convert():
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
||||
np.savez("weights.npz", **weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
222
phi2/phi2.py
Normal file
222
phi2/phi2.py
Normal file
@ -0,0 +1,222 @@
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from mlx.utils import tree_unflatten
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
max_sequence_length: int = 2048
|
||||
num_vocab: int = 51200
|
||||
model_dim: int = 2560
|
||||
num_heads: int = 32
|
||||
num_layers: int = 32
|
||||
rotary_dim: int = 32
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||
self.Wqkv = nn.Linear(dims, 3 * dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.Wqkv(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
keys = keys.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
dims = config.model_dim
|
||||
mlp_dims = dims * 4
|
||||
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
|
||||
self.ln = LayerNorm(dims)
|
||||
self.fc1 = nn.Linear(dims, mlp_dims)
|
||||
self.fc2 = nn.Linear(mlp_dims, dims)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.ln(x)
|
||||
attn_h, cache = self.mixer(h, mask, cache)
|
||||
ff_h = self.fc2(self.act(self.fc1(h)))
|
||||
return attn_h + ff_h + x, cache
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for e, layer in enumerate(self.h):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
return x, cache
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
self.ln = LayerNorm(config.model_dim)
|
||||
self.linear = nn.Linear(config.model_dim, config.num_vocab)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(self.ln(inputs))
|
||||
|
||||
|
||||
class Phi2(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
|
||||
self.transformer = TransformerDecoder(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
x = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y, cache = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y), cache
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
logits, cache = model(prompt)
|
||||
y = sample(logits[:, -1, :])
|
||||
yield y
|
||||
|
||||
while True:
|
||||
logits, cache = model(y[:, None], cache=cache)
|
||||
y = sample(logits.squeeze(1))
|
||||
yield y
|
||||
|
||||
|
||||
def load_model():
|
||||
model = Phi2(ModelArgs())
|
||||
weights = mx.load("weights.npz")
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
default="Write a detailed analogy between mathematics and a lighthouse.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
help="The sampling temperature.",
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load_model()
|
||||
|
||||
prompt = tokenizer(
|
||||
args.prompt,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)["input_ids"]
|
||||
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
print("[INFO] Generating with Phi-2...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
|
||||
tokens = []
|
||||
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
|
||||
|
||||
if eos_index is not None:
|
||||
tokens = tokens[:eos_index]
|
||||
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, end="", flush=True)
|
||||
tokens = []
|
||||
if eos_index is not None:
|
||||
break
|
||||
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, flush=True)
|
5
phi2/requirements.txt
Normal file
5
phi2/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
einops
|
||||
mlx
|
||||
numpy
|
||||
transformers
|
||||
torch
|
@ -27,7 +27,7 @@ Usage
|
||||
------
|
||||
|
||||
Although each component in this repository can be used by itself, the fastest
|
||||
way to get started is by using the `StableDiffusion` class from the `diffusion`
|
||||
way to get started is by using the `StableDiffusion` class from the `stable_diffusion`
|
||||
module.
|
||||
|
||||
```python
|
||||
|
@ -1,7 +1,7 @@
|
||||
# whisper
|
||||
# Whisper
|
||||
|
||||
Speech recognition with Whisper in MLX. Whisper is a set of open source speech
|
||||
recognition models from Open AI, ranging from 39 million to 1.5 billion
|
||||
recognition models from OpenAI, ranging from 39 million to 1.5 billion
|
||||
parameters[^1].
|
||||
|
||||
### Setup
|
||||
@ -15,7 +15,7 @@ pip install -r requirements.txt
|
||||
Install [`ffmpeg`](https://ffmpeg.org/):
|
||||
|
||||
```
|
||||
# on MacOS using Homebrew (https://brew.sh/)
|
||||
# on macOS using Homebrew (https://brew.sh/)
|
||||
brew install ffmpeg
|
||||
```
|
||||
|
||||
|
@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase):
|
||||
logits = mlx_model(mels, tokens)
|
||||
self.assertEqual(logits.dtype, mx.float16)
|
||||
|
||||
|
||||
def test_decode_lang(self):
|
||||
options = decoding.DecodingOptions(task="lang_id", fp16=False)
|
||||
result = decoding.decode(self.model, self.mels, options)
|
||||
|
@ -112,7 +112,7 @@ class DecodingOptions:
|
||||
max_initial_timestamp: Optional[float] = 1.0
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = {
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
}
|
||||
|
||||
|
||||
@ -166,7 +166,8 @@ def convert(model, rules=None):
|
||||
|
||||
|
||||
def torch_to_mlx(
|
||||
torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
|
||||
torch_model: torch_whisper.Whisper,
|
||||
dtype: mx.Dtype = mx.float16,
|
||||
) -> whisper.Whisper:
|
||||
def convert_rblock(model, rules):
|
||||
children = dict(model.named_children())
|
||||
@ -194,6 +195,6 @@ def torch_to_mlx(
|
||||
def load_model(
|
||||
name: str,
|
||||
download_root: str = None,
|
||||
dtype : mx.Dtype = mx.float32,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> whisper.Whisper:
|
||||
return torch_to_mlx(load_torch_model(name, download_root), dtype)
|
||||
|
@ -43,7 +43,7 @@ class ModelHolder:
|
||||
model_name = None
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str, dtype : mx.Dtype):
|
||||
def get_model(cls, model: str, dtype: mx.Dtype):
|
||||
if cls.model is None or model != cls.model_name:
|
||||
cls.model = load_model(model, dtype=dtype)
|
||||
cls.model_name = model
|
||||
|
@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000):
|
||||
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
|
||||
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||
@ -117,13 +118,19 @@ class ResidualAttentionBlock(nn.Module):
|
||||
if self.cross_attn:
|
||||
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
|
||||
x += y
|
||||
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
|
||||
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
|
||||
return x, (kv, cross_kv)
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
|
||||
self,
|
||||
n_mels: int,
|
||||
n_ctx: int,
|
||||
n_state: int,
|
||||
n_head: int,
|
||||
n_layer: int,
|
||||
dtype: mx.Dtype = mx.float16,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
@ -134,8 +141,8 @@ class AudioEncoder(nn.Module):
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def __call__(self, x):
|
||||
x = nn.gelu(self.conv1(x))
|
||||
x = nn.gelu(self.conv2(x))
|
||||
x = nn.gelu(self.conv1(x)).astype(x.dtype)
|
||||
x = nn.gelu(self.conv2(x)).astype(x.dtype)
|
||||
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
|
||||
x = x + self._positional_embedding
|
||||
|
||||
@ -148,7 +155,13 @@ class AudioEncoder(nn.Module):
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
|
||||
self,
|
||||
n_vocab: int,
|
||||
n_ctx: int,
|
||||
n_state: int,
|
||||
n_head: int,
|
||||
n_layer: int,
|
||||
dtype: mx.Dtype = mx.float16,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -160,7 +173,9 @@ class TextDecoder(nn.Module):
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
self.ln = LayerNorm(n_state)
|
||||
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
|
||||
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
|
||||
dtype
|
||||
)
|
||||
|
||||
def __call__(self, x, xa, kv_cache=None):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user