From 5d4838b02e3f261176b1bd9723a7299b0740a15b Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 12:07:33 -0500 Subject: [PATCH 01/33] Updating BERT model to take advantage of bias param in MultiHeadAttention --- bert/model.py | 90 +++++++++++---------------------------------------- 1 file changed, 18 insertions(+), 72 deletions(-) diff --git a/bert/model.py b/bert/model.py index 446919b1..d4dccfac 100644 --- a/bert/model.py +++ b/bert/model.py @@ -7,7 +7,6 @@ import mlx.core as mx import mlx.nn as nn import argparse import numpy -import math @dataclass @@ -34,74 +33,6 @@ model_configs = { } -class MultiHeadAttention(nn.Module): - """ - Minor update to the MultiHeadAttention module to ensure that the - projections use bias. - """ - - def __init__( - self, - dims: int, - num_heads: int, - query_input_dims: Optional[int] = None, - key_input_dims: Optional[int] = None, - value_input_dims: Optional[int] = None, - value_dims: Optional[int] = None, - value_output_dims: Optional[int] = None, - ): - super().__init__() - - if (dims % num_heads) != 0: - raise ValueError( - f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0" - ) - - query_input_dims = query_input_dims or dims - key_input_dims = key_input_dims or dims - value_input_dims = value_input_dims or key_input_dims - value_dims = value_dims or dims - value_output_dims = value_output_dims or dims - - self.num_heads = num_heads - self.query_proj = nn.Linear(query_input_dims, dims, True) - self.key_proj = nn.Linear(key_input_dims, dims, True) - self.value_proj = nn.Linear(value_input_dims, value_dims, True) - self.out_proj = nn.Linear(value_dims, value_output_dims, True) - - def __call__(self, queries, keys, values, mask=None): - queries = self.query_proj(queries) - keys = self.key_proj(keys) - values = self.value_proj(values) - - num_heads = self.num_heads - B, L, D = queries.shape - _, S, _ = keys.shape - queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) - values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - - # Dimensions are [batch x num heads x sequence x hidden dim] - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys - if mask is not None: - mask = self.convert_mask_to_additive_causal_mask(mask) - mask = mx.expand_dims(mask, (1, 2)) - mask = mx.broadcast_to(mask, scores.shape) - scores = scores + mask.astype(scores.dtype) - scores = mx.softmax(scores, axis=-1) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.out_proj(values_hat) - - def convert_mask_to_additive_causal_mask( - self, mask: mx.array, dtype: mx.Dtype = mx.float32 - ) -> mx.array: - mask = mask == 0 - mask = mask.astype(dtype) * -1e9 - return mask - - class TransformerEncoderLayer(nn.Module): """ A transformer encoder layer with (the original BERT) post-normalization. @@ -116,7 +47,7 @@ class TransformerEncoderLayer(nn.Module): ): super().__init__() mlp_dims = mlp_dims or dims * 4 - self.attention = MultiHeadAttention(dims, num_heads) + self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True) self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps) self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps) self.linear1 = nn.Linear(dims, mlp_dims) @@ -186,11 +117,26 @@ class Bert(nn.Module): self, input_ids: mx.array, token_type_ids: mx.array, - attention_mask: Optional[mx.array] = None, + attention_mask: mx.array = None, ) -> tuple[mx.array, mx.array]: x = self.embeddings(input_ids, token_type_ids) + + if attention_mask is not None: + # convert 0's to -infs, 1's to 0's, and make it broadcastable + attention_mask = self.convert_mask_to_additive_causal_mask(attention_mask) + attention_mask = mx.expand_dims(attention_mask, (1, 2)) + y = self.encoder(x, attention_mask) return y, mx.tanh(self.pooler(y[:, 0])) + + + def convert_mask_to_additive_causal_mask( + self, mask: mx.array, dtype: mx.Dtype = mx.float32 + ) -> mx.array: + mask = mask == 0 + mask = mask.astype(dtype) * -1e9 + return mask + def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]: @@ -214,7 +160,7 @@ def run(bert_model: str, mlx_model: str): "A second string", "This is another string.", ] - + tokens = tokenizer(batch, return_tensors="np", padding=True) tokens = {key: mx.array(v) for key, v in tokens.items()} From a577abc31320f3c03f90dc20201421a375a73ae5 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 21:21:24 -0500 Subject: [PATCH 02/33] Cleaner masking code --- bert/model.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/bert/model.py b/bert/model.py index d4dccfac..4666a78d 100644 --- a/bert/model.py +++ b/bert/model.py @@ -123,20 +123,11 @@ class Bert(nn.Module): if attention_mask is not None: # convert 0's to -infs, 1's to 0's, and make it broadcastable - attention_mask = self.convert_mask_to_additive_causal_mask(attention_mask) + attention_mask = mx.log(attention_mask) attention_mask = mx.expand_dims(attention_mask, (1, 2)) y = self.encoder(x, attention_mask) return y, mx.tanh(self.pooler(y[:, 0])) - - - def convert_mask_to_additive_causal_mask( - self, mask: mx.array, dtype: mx.Dtype = mx.float32 - ) -> mx.array: - mask = mask == 0 - mask = mask.astype(dtype) * -1e9 - return mask - def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]: From f37e777243029872d5783f8db1183771d818d59e Mon Sep 17 00:00:00 2001 From: Sarthak Yadav Date: Tue, 12 Dec 2023 19:01:06 +0100 Subject: [PATCH 03/33] added CIFAR10 + ResNet example --- cifar/README.md | 31 ++++++++++ cifar/dataset.py | 39 +++++++++++++ cifar/main.py | 108 ++++++++++++++++++++++++++++++++++ cifar/requirements.txt | 3 + cifar/resnet.py | 129 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 310 insertions(+) create mode 100644 cifar/README.md create mode 100644 cifar/dataset.py create mode 100644 cifar/main.py create mode 100644 cifar/requirements.txt create mode 100644 cifar/resnet.py diff --git a/cifar/README.md b/cifar/README.md new file mode 100644 index 00000000..0d793853 --- /dev/null +++ b/cifar/README.md @@ -0,0 +1,31 @@ +# CIFAR and ResNets + +* This example shows how to run ResNets on CIFAR10 dataset, in accordance with the original [paper](https://arxiv.org/abs/1512.03385). +* Also illustrates how to use `mlx-data` to download and load the dataset. + + +## Pre-requisites +* Install the dependencies: + +``` +pip install -r requirements.txt +``` + +## Running the example +Run the example with: + +``` +python main.py +``` + +By default the example runs on the GPU. To run on the CPU, use: + +``` +python main.py --cpu_only +``` + +For all available options, run: + +``` +python main.py --help +``` diff --git a/cifar/dataset.py b/cifar/dataset.py new file mode 100644 index 00000000..f4a3cd63 --- /dev/null +++ b/cifar/dataset.py @@ -0,0 +1,39 @@ +import mlx.core as mx +from mlx.data.datasets import load_cifar10 +import math + + +def get_cifar10(batch_size, root=None): + + tr = load_cifar10(root=root) + num_tr_samples = tr.size() + + mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) + std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) + + tr_iter = ( + tr.shuffle() + .to_stream() + .image_random_h_flip("image", prob=0.5) + .pad("image", 0, 4, 4, 0.0) + .pad("image", 1, 4, 4, 0.0) + .image_random_crop("image", 32, 32) + .key_transform("image", lambda x: (x.astype("float32") / 255.0)) + .key_transform("image", lambda x: (x - mean) / std) + .batch(batch_size) + ) + + test = load_cifar10(root=root, train=False) + num_test_samples = test.size() + + test_iter = ( + test.to_stream() + .key_transform("image", lambda x: (x.astype("float32") / 255.0)) + .key_transform("image", lambda x: (x - mean) / std) + .batch(batch_size) + ) + + num_tr_steps_per_epoch = num_tr_samples // batch_size + num_test_steps_per_epoch = num_test_samples // batch_size + + return tr_iter, test_iter, num_tr_steps_per_epoch, num_test_steps_per_epoch diff --git a/cifar/main.py b/cifar/main.py new file mode 100644 index 00000000..5272733a --- /dev/null +++ b/cifar/main.py @@ -0,0 +1,108 @@ +import argparse +import resnet +import numpy as np +import mlx.nn as nn +import mlx.core as mx +import mlx.optimizers as optim +from dataset import get_cifar10 + + +parser = argparse.ArgumentParser(add_help=True) +parser.add_argument( + "--arch", + type=str, + default="resnet20", + help="model architecture [resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202]", +) +parser.add_argument("--batch_size", type=int, default=128, help="batch size") +parser.add_argument("--epochs", type=int, default=100, help="number of epochs") +parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") +parser.add_argument("--seed", type=int, default=0, help="random seed") +parser.add_argument("--cpu_only", action="store_true", help="use cpu only") + + +def loss_fn(model, inp, tgt): + return mx.mean(nn.losses.cross_entropy(model(inp), tgt)) + + +def eval_fn(model, inp, tgt): + return mx.mean(mx.argmax(model(inp), axis=1) == tgt) + + +def train_epoch(model, train_iter, optimizer, epoch): + def train_step(model, inp, tgt): + output = model(inp) + loss = mx.mean(nn.losses.cross_entropy(output, tgt)) + acc = mx.mean(mx.argmax(output, axis=1) == tgt) + return loss, acc + + train_step_fn = nn.value_and_grad(model, train_step) + + losses = [] + accs = [] + + for batch_counter, batch in enumerate(train_iter): + x = mx.array(batch["image"]) + y = mx.array(batch["label"]) + (loss, acc), grads = train_step_fn(model, x, y) + optimizer.update(model, grads) + mx.eval(model.parameters(), optimizer.state) + + loss_value = loss.item() + acc_value = acc.item() + losses.append(loss_value) + accs.append(acc_value) + + if batch_counter % 10 == 0: + print( + f"Epoch {epoch:02d}[{batch_counter:03d}]: tr_loss {loss_value:.3f}, tr_acc {acc_value:.3f}" + ) + + mean_tr_loss = np.mean(np.array(losses)) + mean_tr_acc = np.mean(np.array(accs)) + return mean_tr_loss, mean_tr_acc + + +def test_epoch(model, test_iter, epoch): + accs = [] + for batch_counter, batch in enumerate(test_iter): + x = mx.array(batch["image"]) + y = mx.array(batch["label"]) + acc = eval_fn(model, x, y) + acc_value = acc.item() + accs.append(acc_value) + mean_acc = np.mean(np.array(accs)) + + return mean_acc + + +def main(args): + np.random.seed(args.seed) + mx.random.seed(args.seed) + + model = resnet.__dict__[args.arch]() + + print("num_params: {:0.04f} M".format(model.num_params() / 1e6)) + mx.eval(model.parameters()) + + optimizer = optim.Adam(learning_rate=args.lr) + + for epoch in range(args.epochs): + # get data every epoch + # or set .repeat() on the data stream appropriately + train_data, test_data, tr_batches, _ = get_cifar10(args.batch_size) + + epoch_tr_loss, epoch_tr_acc = train_epoch(model, train_data, optimizer, epoch) + print( + f"Epoch {epoch}: avg. tr_loss {epoch_tr_loss:.3f}, avg. tr_acc {epoch_tr_acc:.3f}" + ) + + epoch_test_acc = test_epoch(model, test_data, epoch) + print(f"Epoch {epoch}: Test_acc {epoch_test_acc:.3f}") + + +if __name__ == "__main__": + args = parser.parse_args() + if args.cpu_only: + mx.set_default_device(mx.cpu) + main(args) diff --git a/cifar/requirements.txt b/cifar/requirements.txt new file mode 100644 index 00000000..c4c2e575 --- /dev/null +++ b/cifar/requirements.txt @@ -0,0 +1,3 @@ +mlx +mlx-data +numpy \ No newline at end of file diff --git a/cifar/resnet.py b/cifar/resnet.py new file mode 100644 index 00000000..3d88397b --- /dev/null +++ b/cifar/resnet.py @@ -0,0 +1,129 @@ +""" +Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385]. +Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202. + +There's no BatchNorm is mlx==0.0.4, using LayerNorm instead. + +Authors: + Sarthak Yadav, 2023 +""" + +from typing import Any +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten + + +__all__ = [ + "ResNet", + "resnet20", + "resnet32", + "resnet44", + "resnet56", + "resnet110", + "resnet1202", +] + + +class ShortcutA(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def __call__(self, x): + return mx.pad( + x[:, ::2, ::2, :], + pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)], + ) + + +class Block(nn.Module): + expansion = 1 + + def __init__(self, in_dims, dims, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d( + in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn1 = nn.LayerNorm(dims) + + self.conv2 = nn.Conv2d( + dims, dims, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.LayerNorm(dims) + + if stride != 1 or in_dims != dims: + self.shortcut = ShortcutA(dims) + else: + self.shortcut = None + + def __call__(self, x): + + out = nn.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + if self.shortcut is None: + out += x + else: + out += self.shortcut(x) + out = nn.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.LayerNorm(16) + self.in_dims = 16 + + self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) + + self.linear = nn.Linear(64, num_classes) + + def _make_layer(self, block, dims, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_dims, dims, stride)) + self.in_dims = dims * block.expansion + return nn.Sequential(*layers) + + def num_params(self): + nparams = sum(x.size for k, x in tree_flatten(self.parameters())) + return nparams + + def __call__(self, x): + x = nn.relu(self.bn1(self.conv1(x))) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1) + x = self.linear(x) + return x + + +def resnet20(**kwargs): + return ResNet(Block, [3, 3, 3], **kwargs) + + +def resnet32(**kwargs): + return ResNet(Block, [5, 5, 5], **kwargs) + + +def resnet44(**kwargs): + return ResNet(Block, [7, 7, 7], **kwargs) + + +def resnet56(**kwargs): + return ResNet(Block, [9, 9, 9], **kwargs) + + +def resnet110(**kwargs): + return ResNet(Block, [18, 18, 18], **kwargs) + + +def resnet1202(**kwargs): + return ResNet(Block, [200, 200, 200], **kwargs) From 2439333a57812cb407ca86c85e8a2d77f0eb9231 Mon Sep 17 00:00:00 2001 From: Sarthak Yadav Date: Tue, 12 Dec 2023 19:07:39 +0100 Subject: [PATCH 04/33] fixed doc for ResNet --- cifar/resnet.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cifar/resnet.py b/cifar/resnet.py index 3d88397b..b89a612b 100644 --- a/cifar/resnet.py +++ b/cifar/resnet.py @@ -39,6 +39,10 @@ class ShortcutA(nn.Module): class Block(nn.Module): expansion = 1 + """ + Implements a ResNet block with two convolutional layers and a skip connection. + As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details) + """ def __init__(self, in_dims, dims, stride=1): super().__init__() @@ -71,6 +75,10 @@ class Block(nn.Module): class ResNet(nn.Module): + """ + Creates a ResNet model for CIFAR-10, as specified in the original paper. + """ + def __init__(self, block, num_blocks, num_classes=10): super().__init__() self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) From 5515c2a75b2b3210df25abd585d917766513cca0 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Wed, 13 Dec 2023 10:12:10 +0100 Subject: [PATCH 05/33] fix "request access" form url for Llama models --- lora/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora/README.md b/lora/README.md index 09911bd3..a6819950 100644 --- a/lora/README.md +++ b/lora/README.md @@ -24,7 +24,7 @@ tar -xf mistral-7B-v0.1.tar ``` If you do not have access to the Llama weights you will need to [request -access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) +access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) from Meta. Convert the model with: From 4b1a06c0cbf8641443b15d9164fb4113997c937c Mon Sep 17 00:00:00 2001 From: bofenghuang Date: Wed, 13 Dec 2023 11:07:47 +0100 Subject: [PATCH 06/33] Fix fp16 --- whisper/whisper/whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index 62e43de3..bca69946 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -117,7 +117,7 @@ class ResidualAttentionBlock(nn.Module): if self.cross_attn: y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) x += y - x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) + x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype)) return x, (kv, cross_kv) @@ -134,8 +134,8 @@ class AudioEncoder(nn.Module): self.ln_post = LayerNorm(n_state) def __call__(self, x): - x = nn.gelu(self.conv1(x)) - x = nn.gelu(self.conv2(x)) + x = nn.gelu(self.conv1(x)).astype(x.dtype) + x = nn.gelu(self.conv2(x)).astype(x.dtype) assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" x = x + self._positional_embedding From 03fe6896de09283a3eec55bcff5a9cc038217693 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Wed, 13 Dec 2023 11:37:02 -0500 Subject: [PATCH 07/33] Fix convert.py instructions for Bert model It just adds the missing backslash. --- bert/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert/README.md b/bert/README.md index cea738df..70bc39a9 100644 --- a/bert/README.md +++ b/bert/README.md @@ -8,7 +8,7 @@ The `convert.py` script relies on `transformers` to download the weights, and ex ``` python convert.py \ - --bert-model bert-base-uncased + --bert-model bert-base-uncased \ --mlx-model weights/bert-base-uncased.npz ``` From 1505e49a6254df2df012b06f6489ff5016e8b8dd Mon Sep 17 00:00:00 2001 From: jbax3 <61852880+jbax3@users.noreply.github.com> Date: Wed, 13 Dec 2023 15:51:27 -0600 Subject: [PATCH 08/33] Update README.md to fix git-lfs command --- mixtral/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mixtral/README.md b/mixtral/README.md index 417759e1..b56ee767 100644 --- a/mixtral/README.md +++ b/mixtral/README.md @@ -17,7 +17,7 @@ brew install git-lfs Download the models from Hugging Face: ``` -git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen +git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen ``` After that's done, combine the files: From 9f4e63acbf57637ed13e04f2c1d5c0f627df3e41 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Wed, 13 Dec 2023 17:48:07 -0500 Subject: [PATCH 09/33] Update to mlx>=0.0.5 --- bert/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bert/requirements.txt b/bert/requirements.txt index 24266334..a6b564c5 100644 --- a/bert/requirements.txt +++ b/bert/requirements.txt @@ -1,3 +1,3 @@ -mlx +mlx>=0.0.5 transformers -numpy \ No newline at end of file +numpy From cbae83e011431e983061322e38a0dd2d691140b7 Mon Sep 17 00:00:00 2001 From: "Stv.X" Date: Thu, 14 Dec 2023 08:15:26 +0800 Subject: [PATCH 10/33] Corrected spelling of terms in whisper/README.md --- whisper/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/whisper/README.md b/whisper/README.md index 602941a0..7df1382f 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -1,7 +1,7 @@ -# whisper +# Whisper Speech recognition with Whisper in MLX. Whisper is a set of open source speech -recognition models from Open AI, ranging from 39 million to 1.5 billion +recognition models from OpenAI, ranging from 39 million to 1.5 billion parameters[^1]. ### Setup @@ -15,7 +15,7 @@ pip install -r requirements.txt Install [`ffmpeg`](https://ffmpeg.org/): ``` -# on MacOS using Homebrew (https://brew.sh/) +# on macOS using Homebrew (https://brew.sh/) brew install ffmpeg ``` From a466cc51917270308cc6376909dd8d7a3598cf85 Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Wed, 13 Dec 2023 22:22:56 -0500 Subject: [PATCH 11/33] phi-2 draft --- phi2/README.md | 24 +++++ phi2/__init__.py | 0 phi2/convert.py | 67 ++++++++++++ phi2/hf_model.py | 23 +++++ phi2/model.py | 232 ++++++++++++++++++++++++++++++++++++++++++ phi2/phi2_outputs.txt | 63 ++++++++++++ 6 files changed, 409 insertions(+) create mode 100644 phi2/README.md create mode 100644 phi2/__init__.py create mode 100644 phi2/convert.py create mode 100644 phi2/hf_model.py create mode 100644 phi2/model.py create mode 100644 phi2/phi2_outputs.txt diff --git a/phi2/README.md b/phi2/README.md new file mode 100644 index 00000000..c38f8a74 --- /dev/null +++ b/phi2/README.md @@ -0,0 +1,24 @@ +# Phi-2 + +Phi-2 is a 2.7B parameter model released by Microsoft and trained on a mixture of GPT-4 outputs and clean web-text. +Its performance theoretically rivals much, much stronger models. + +## Downloading and Converting Weights + +To download and convert the model: + +```sh +python phi2/convert.py +``` + +That will fill in `weights/phi-2.npz`. + +## Running the Model + +🚧 (Not yet done) To run the model: + +```sh +python phi2/generate.py +``` + +Layer-by-layer forward pass outputs are currently shown in the outputs.txt files. diff --git a/phi2/__init__.py b/phi2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/phi2/convert.py b/phi2/convert.py new file mode 100644 index 00000000..cd2f77aa --- /dev/null +++ b/phi2/convert.py @@ -0,0 +1,67 @@ +from transformers import AutoModelForCausalLM + +import numpy + + +def split_attention_matrix(state_dict, key) -> dict: + # "transformer.h.0.mixer" + _, model_dim = state_dict[key + ".weight"].shape + # (3 * model_dim, model_dim) + Wqkv_weight_key = key + ".weight" + Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :] + Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :] + Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :] + + # (3 * model_dim) + Wqkv_bias_key = key + ".bias" + Wq_bias = state_dict[Wqkv_bias_key][:model_dim] + Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim] + Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :] + + out_key = key.replace("mixer.Wqkv", "self_attention") + + return { + out_key + ".query_proj.weight": Wq_weight, + out_key + ".query_proj.bias": Wq_bias, + out_key + ".key_proj.weight": Wk_weight, + out_key + ".key_proj.bias": Wk_bias, + out_key + ".value_proj.weight": Wv_weight, + out_key + ".value_proj.bias": Wv_bias, + } + + +def replace_key(key: str) -> str: + if "wte.weight" in key: + key = "wte.weight" + + if ".mlp" in key: + key = key.replace(".mlp", "") + + if ".mixer.out_proj" in key: + key = key.replace(".mixer", ".self_attention") + + return key + + +def convert(): + model = AutoModelForCausalLM.from_pretrained( + "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True + ) + state_dict = model.state_dict() + keys = list(state_dict.keys()) + + for key in keys: + if ".mixer.Wqkv.weight" not in key: + continue + key_stub = key.rstrip(".weight") + state_dict.update(split_attention_matrix(state_dict, key_stub)) + + del state_dict[key_stub + ".weight"] + del state_dict[key_stub + ".bias"] + + weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} + numpy.savez("weights/phi-2.npz", **weights) + + +if __name__ == "__main__": + convert() diff --git a/phi2/hf_model.py b/phi2/hf_model.py new file mode 100644 index 00000000..d09ff108 --- /dev/null +++ b/phi2/hf_model.py @@ -0,0 +1,23 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + + +if __name__ == "__main__": + model = AutoModelForCausalLM.from_pretrained( + "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + + inputs = tokenizer( + '''def print_prime(n): + """ + Print all primes between 1 and n + """''', + return_tensors="pt", + return_attention_mask=False, + ) + + print(model(**inputs)) + + # outputs = model.generate(**inputs, max_length=200) + # text = tokenizer.batch_decode(outputs)[0] + # print(text) diff --git a/phi2/model.py b/phi2/model.py new file mode 100644 index 00000000..991bf193 --- /dev/null +++ b/phi2/model.py @@ -0,0 +1,232 @@ +from typing import Optional +from dataclasses import dataclass +from mlx.utils import tree_unflatten, tree_map +from transformers import AutoTokenizer + +import mlx.core as mx +import mlx.nn as nn +import math + + +@dataclass +class ModelArgs: + max_sequence_length: int = 2048 + num_vocab: int = 51200 + model_dim: int = 2560 + num_heads: int = 32 + num_layers: int = 32 + rotary_dim: int = 32 + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def __call__(self, input: mx.array) -> mx.array: + return ( + 0.5 + * input + * ( + 1.0 + + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3))) + ) + ) + + +class RoPEAttention(nn.Module): + def __init__(self, dims: int, num_heads: int, bias: bool = True): + super().__init__() + + self.num_heads = num_heads + + self.rope = nn.RoPE(dims // num_heads, traditional=True) + self.query_proj = nn.Linear(dims, dims, bias=bias) + self.key_proj = nn.Linear(dims, dims, bias=bias) + self.value_proj = nn.Linear(dims, dims, bias=bias) + self.out_proj = nn.Linear(dims, dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None, cache=None): + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + # Extract some shapes + num_heads = self.num_heads + B, L, D = queries.shape + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + + # Add RoPE to the queries and keys and combine them with the cache + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + # Finally perform the attention computation + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores = scores + mask + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + # Note that we return the keys and values to possibly be used as a cache + return self.out_proj(values_hat), (keys, values) + + +class ParallelBlock(nn.Module): + def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + super().__init__() + mlp_dims = mlp_dims or dims * 4 + self.self_attention = RoPEAttention(dims, num_heads, bias=True) + self.ln = nn.LayerNorm(dims) + self.fc1 = nn.Linear(dims, mlp_dims) + self.fc2 = nn.Linear(mlp_dims, dims) + self.act = NewGELUActivation() + + def __call__(self, x, x_mask): + residual = x + hidden_states = self.ln(x) + attn_outputs, _ = self.self_attention( + hidden_states, hidden_states, hidden_states, x_mask + ) + ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states))) + + hidden_states = attn_outputs + ff_hidden_states + residual + + return hidden_states + + +class TransformerDecoder(nn.Module): + def __init__( + self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None + ): + super().__init__() + self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)] + + def __call__(self, x, x_mask): + for layer in self.h: + x = layer(x, x_mask) + return x + + +class Phi2(nn.Module): + def __init__(self, config: ModelArgs): + self.wte = nn.Embedding(config.num_vocab, config.model_dim) + self.transformer = TransformerDecoder( + num_layers=config.num_layers, + dims=config.model_dim, + num_heads=config.num_heads, + ) + + self.lm_head = LanguageModelingHead(config) + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array = None, + ) -> tuple[mx.array, mx.array]: + x = self.wte(input_ids) + + if attention_mask is not None: + # convert 0's to -infs, 1's to 0's, and make it broadcastable + attention_mask = mx.log(attention_mask) + attention_mask = mx.expand_dims(attention_mask, (1, 2)) + else: + attention_mask = nn.MultiHeadAttention.create_additive_causal_mask( + x.shape[1] + ) + + y = self.transformer(x, attention_mask) + return self.lm_head(y) + + def generate(self, input_ids, temp=1.0): + cache = input_ids.tolist() + + # Make an additive causal mask. We will need that to process the prompt. + mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1]) + mask = mask.astype(self.wte.weight.dtype) + + # First we process the prompt x the same way as in __call__ but + # save the caches in cache + x = self.wte(input_ids) + # for l in self.layers: + # x, c = l(x, mask=mask) + # cache.append(c) # <--- we store the per layer cache in a + # simple python list + x = self.transformer(x, mask) + y = self.lm_head(x[:, -1]) # <--- we only care about the last logits + # that generate the next token + y = mx.random.categorical(y * (1 / temp)) + + # y now has size [1] + # Since MLX is lazily evaluated nothing is computed yet. + # Calling y.item() would force the computation to happen at + # this point but we can also choose not to do that and let the + # user choose when to start the computation. + yield y + cache += [y.item()] + + # Now we parsed the prompt and generated the first token we + # need to feed it back into the model and loop to generate the + # rest. + while True: + # Unsqueezing the last dimension to add a sequence length + # dimension of 1 + x = self.wte(mx.array(cache)) + x = self.transformer(x, mask) + y = self.lm_head(x[:, -1]) + y = mx.random.categorical(y * (1 / temp)) + cache += [y[0].item()] + + yield y + + +class LanguageModelingHead(nn.Module): + def __init__(self, config: ModelArgs) -> None: + self.ln = nn.LayerNorm(config.model_dim) + self.linear = nn.Linear(config.model_dim, config.num_vocab) + + def __call__(self, inputs): + return self.linear(self.ln(inputs)) + + +if __name__ == "__main__": + model = Phi2(ModelArgs()) + + weights = mx.load("weights/phi-2.npz") + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: mx.array(p), weights) + + model.update(weights) + + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + tokens = tokenizer( + '''def print_prime(n): + """ + Print all primes between 1 and n + """''', + return_tensors="np", + return_attention_mask=False, + ) + + tokens = {key: mx.array(v) for key, v in tokens.items()} + + print( + '''def print_prime(n): + """ + Print all primes between 1 and n + """''' + ) + for output in model.generate(**tokens): + print(tokenizer.decode(output.item())) diff --git a/phi2/phi2_outputs.txt b/phi2/phi2_outputs.txt new file mode 100644 index 00000000..4f27e44b --- /dev/null +++ b/phi2/phi2_outputs.txt @@ -0,0 +1,63 @@ +(HF) Output of Embeddings + +tensor([[[-0.0353, 0.0045, 0.0208, ..., -0.0117, 0.0041, 0.0075], + [-0.0172, 0.0236, -0.0051, ..., 0.0141, 0.0115, 0.0058], + [-0.0148, 0.0043, -0.0252, ..., 0.0179, 0.0025, -0.0008], + ..., + [ 0.0003, 0.0051, 0.0002, ..., 0.0043, 0.0075, 0.0049], + [-0.0110, 0.0472, 0.0030, ..., 0.0098, -0.0075, 0.0146], + [-0.0085, -0.0219, -0.0016, ..., -0.0059, 0.0109, -0.0016]]], + device='cuda:0', dtype=torch.float16, grad_fn=) + +(MLX) Output of Embeddings + +array([[[-0.0352783, 0.00445175, 0.020813, ..., -0.0117188, 0.00411606, 0.00748444], + [-0.0171509, 0.0236053, -0.00508881, ..., 0.0141144, 0.0115204, 0.00582504], + [-0.0147858, 0.00426102, -0.0252075, ..., 0.0179443, 0.0024662, -0.00076437], + ..., + [0.000337124, 0.00508499, 0.000193119, ..., 0.00427628, 0.00753403, 0.00492477], + [-0.0110092, 0.0472107, 0.00295448, ..., 0.00982666, -0.00747681, 0.0145721], + [-0.00852203, -0.0218964, -0.00161839, ..., -0.00592422, 0.0108643, -0.00162697]]], dtype=float16) + +(HF) Output of First Attention Layer + +tensor([[[-0.2000, 0.4849, 0.9863, ..., -0.2209, 0.1355, 0.3469], + [ 0.4922, -0.3865, 0.8428, ..., 0.5894, -0.0069, -0.5278], + [ 0.0902, 0.1028, 0.6826, ..., 0.1394, -0.8145, -0.1880], + ..., + [ 0.2380, 0.0555, -0.3005, ..., 0.0372, -0.0895, 0.0255], + [ 0.2512, 0.1949, 0.3401, ..., 0.3625, -0.3103, -0.1064], + [-0.0905, 0.0665, 0.5210, ..., -0.0767, -0.2460, -0.1449]]], + device='cuda:0', dtype=torch.float16, grad_fn=) +torch.Size([1, 23, 2560]) + +(MLX) Output of First Attention Layer + +array([[[-0.199973, 0.485224, 0.987237, ..., -0.220847, 0.13511, 0.346074], + [0.44883, -0.271683, 0.877478, ..., 0.653217, -0.0929724, -0.711176], + [-0.233398, 5.7824e-05, 0.435001, ..., 0.0504494, -0.623998, -0.438785], + ..., + [0.123587, -0.237459, -0.447518, ..., 0.0653363, -0.0767153, -0.341505], + [0.187798, 0.331209, 0.0827338, ..., 0.529453, -0.582141, -0.165316], + [-0.413614, 0.134572, 0.685769, ..., 0.0796088, 0.0217719, -0.118885]]], dtype=float32) +[1, 23, 2560] + +(HF) Overall Output of Inputs: + +tensor([[[ 6.4688, 5.1016, 1.9658, ..., -2.9043, -2.9043, -2.9043], + [ 5.2188, 6.4414, 5.1914, ..., -0.1852, -0.1862, -0.1866], + [ 4.3516, 5.3281, 5.9922, ..., -0.3689, -0.3699, -0.3696], + ..., + [10.4141, 11.7031, 12.5859, ..., 0.7778, 0.7769, 0.7754], + [10.7188, 11.7891, 13.3125, ..., 1.6123, 1.6113, 1.6104], + [10.8047, 12.0234, 12.4375, ..., 0.2321, 0.2314, 0.2317]]], + +(MLX) Overall Output of Inputs: + +array([[[6.46632, 5.10102, 1.96306, ..., -2.90427, -2.90341, -2.90392], + [4.5092, 5.90938, 4.98036, ..., -0.411165, -0.412062, -0.412547], + [4.34246, 5.7794, 6.13245, ..., -0.40106, -0.402052, -0.401838], + ..., + [6.61827, 10.4022, 12.1672, ..., 0.602787, 0.602138, 0.600666], + [7.96546, 12.9569, 14.7947, ..., -0.347764, -0.348587, -0.34937], + [8.22272, 10.6631, 11.5968, ..., -1.12037, -1.12025, -1.12152]]], dtype=float32) \ No newline at end of file From 0ce7618bc91f134cc209f12da5997e9298592d47 Mon Sep 17 00:00:00 2001 From: Nolan Date: Wed, 13 Dec 2023 20:51:39 -0800 Subject: [PATCH 12/33] Fix typo in stable_diffusion README --- stable_diffusion/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_diffusion/README.md b/stable_diffusion/README.md index 400a50f7..5e44cb1a 100644 --- a/stable_diffusion/README.md +++ b/stable_diffusion/README.md @@ -27,7 +27,7 @@ Usage ------ Although each component in this repository can be used by itself, the fastest -way to get started is by using the `StableDiffusion` class from the `diffusion` +way to get started is by using the `StableDiffusion` class from the `stable_diffusion` module. ```python From 88d7b67e6e8dee7a2c128d69223ac0f551aab7a6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 13 Dec 2023 22:26:33 -0800 Subject: [PATCH 13/33] add cache + generation, clean up some stuff --- phi2/.gitignore | 1 + phi2/convert.py | 2 +- phi2/model.py | 177 ++++++++++++++++-------------------------- phi2/requirements.txt | 3 + 4 files changed, 70 insertions(+), 113 deletions(-) create mode 100644 phi2/.gitignore create mode 100644 phi2/requirements.txt diff --git a/phi2/.gitignore b/phi2/.gitignore new file mode 100644 index 00000000..258ec872 --- /dev/null +++ b/phi2/.gitignore @@ -0,0 +1 @@ +weights.npz diff --git a/phi2/convert.py b/phi2/convert.py index cd2f77aa..3c821f69 100644 --- a/phi2/convert.py +++ b/phi2/convert.py @@ -60,7 +60,7 @@ def convert(): del state_dict[key_stub + ".bias"] weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} - numpy.savez("weights/phi-2.npz", **weights) + numpy.savez("weights.npz", **weights) if __name__ == "__main__": diff --git a/phi2/model.py b/phi2/model.py index 991bf193..5253a266 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -7,7 +7,6 @@ import mlx.core as mx import mlx.nn as nn import math - @dataclass class ModelArgs: max_sequence_length: int = 2048 @@ -18,23 +17,6 @@ class ModelArgs: rotary_dim: int = 32 -class NewGELUActivation(nn.Module): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see - the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 - """ - - def __call__(self, input: mx.array) -> mx.array: - return ( - 0.5 - * input - * ( - 1.0 - + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3))) - ) - ) - - class RoPEAttention(nn.Module): def __init__(self, dims: int, num_heads: int, bias: bool = True): super().__init__() @@ -77,6 +59,7 @@ class RoPEAttention(nn.Module): scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores = scores + mask + scores = mx.softmax(scores, axis=-1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -92,19 +75,13 @@ class ParallelBlock(nn.Module): self.ln = nn.LayerNorm(dims) self.fc1 = nn.Linear(dims, mlp_dims) self.fc2 = nn.Linear(mlp_dims, dims) - self.act = NewGELUActivation() + self.act = nn.GELU(approx="precise") - def __call__(self, x, x_mask): - residual = x - hidden_states = self.ln(x) - attn_outputs, _ = self.self_attention( - hidden_states, hidden_states, hidden_states, x_mask - ) - ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states))) - - hidden_states = attn_outputs + ff_hidden_states + residual - - return hidden_states + def __call__(self, x, mask, cache): + h = self.ln(x) + attn_h, cache = self.self_attention(h, h, h, mask, cache) + ff_h = self.fc2(self.act(self.fc1(h))) + return attn_h + ff_h + x, cache class TransformerDecoder(nn.Module): @@ -114,10 +91,22 @@ class TransformerDecoder(nn.Module): super().__init__() self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)] - def __call__(self, x, x_mask): - for layer in self.h: - x = layer(x, x_mask) - return x + def __call__(self, x, mask, cache): + if cache is None: + cache = [None] * len(self.h) + + for e, layer in enumerate(self.h): + x, cache[e] = layer(x, mask, cache[e]) + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: ModelArgs) -> None: + self.ln = nn.LayerNorm(config.model_dim) + self.linear = nn.Linear(config.model_dim, config.num_vocab) + + def __call__(self, inputs): + return self.linear(self.ln(inputs)) class Phi2(nn.Module): @@ -128,77 +117,40 @@ class Phi2(nn.Module): dims=config.model_dim, num_heads=config.num_heads, ) - - self.lm_head = LanguageModelingHead(config) + self.lm_head = OutputHead(config) def __call__( self, - input_ids: mx.array, - attention_mask: mx.array = None, + inputs: mx.array, + mask: mx.array = None, + cache: mx.array = None, ) -> tuple[mx.array, mx.array]: - x = self.wte(input_ids) + x = self.wte(inputs) - if attention_mask is not None: - # convert 0's to -infs, 1's to 0's, and make it broadcastable - attention_mask = mx.log(attention_mask) - attention_mask = mx.expand_dims(attention_mask, (1, 2)) + mask = None + if x.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(x.dtype) + + y, cache = self.transformer(x, mask, cache) + return self.lm_head(y), cache + + +def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) else: - attention_mask = nn.MultiHeadAttention.create_additive_causal_mask( - x.shape[1] - ) + return mx.random.categorical(logits * (1 / temp)) - y = self.transformer(x, attention_mask) - return self.lm_head(y) + logits, cache = model(prompt) + y = sample(logits[:, -1, :]) + yield y - def generate(self, input_ids, temp=1.0): - cache = input_ids.tolist() - - # Make an additive causal mask. We will need that to process the prompt. - mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1]) - mask = mask.astype(self.wte.weight.dtype) - - # First we process the prompt x the same way as in __call__ but - # save the caches in cache - x = self.wte(input_ids) - # for l in self.layers: - # x, c = l(x, mask=mask) - # cache.append(c) # <--- we store the per layer cache in a - # simple python list - x = self.transformer(x, mask) - y = self.lm_head(x[:, -1]) # <--- we only care about the last logits - # that generate the next token - y = mx.random.categorical(y * (1 / temp)) - - # y now has size [1] - # Since MLX is lazily evaluated nothing is computed yet. - # Calling y.item() would force the computation to happen at - # this point but we can also choose not to do that and let the - # user choose when to start the computation. + while True: + logits, cache = model(y[:, None], cache=cache) + y = sample(logits.squeeze(1)) yield y - cache += [y.item()] - - # Now we parsed the prompt and generated the first token we - # need to feed it back into the model and loop to generate the - # rest. - while True: - # Unsqueezing the last dimension to add a sequence length - # dimension of 1 - x = self.wte(mx.array(cache)) - x = self.transformer(x, mask) - y = self.lm_head(x[:, -1]) - y = mx.random.categorical(y * (1 / temp)) - cache += [y[0].item()] - - yield y - - -class LanguageModelingHead(nn.Module): - def __init__(self, config: ModelArgs) -> None: - self.ln = nn.LayerNorm(config.model_dim) - self.linear = nn.Linear(config.model_dim, config.num_vocab) - - def __call__(self, inputs): - return self.linear(self.ln(inputs)) if __name__ == "__main__": @@ -206,27 +158,28 @@ if __name__ == "__main__": weights = mx.load("weights/phi-2.npz") weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: mx.array(p), weights) + weights = tree_map(lambda p: mx.array(p, mx.float32), weights) model.update(weights) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) - tokens = tokenizer( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''', + prompt = tokenizer("Write a detailed analogy between mathematics and a lighthouse.", return_tensors="np", return_attention_mask=False, - ) + )["input_ids"] - tokens = {key: mx.array(v) for key, v in tokens.items()} + prompt = mx.array(prompt) + + tokens_per_eval = 1 + max_tokens = 100 + + tokens = [] + for token, _ in zip(generate(prompt, model), range(max_tokens)): + tokens.append(token) + + if (len(tokens) % tokens_per_eval) == 0: + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] - print( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''' - ) - for output in model.generate(**tokens): - print(tokenizer.decode(output.item())) diff --git a/phi2/requirements.txt b/phi2/requirements.txt new file mode 100644 index 00000000..6a11f8d2 --- /dev/null +++ b/phi2/requirements.txt @@ -0,0 +1,3 @@ +einops +mlx +transformers From 15a6c155a815266c937d14402e8cf2608796aa76 Mon Sep 17 00:00:00 2001 From: Sarthak Yadav Date: Thu, 14 Dec 2023 09:05:04 +0100 Subject: [PATCH 14/33] simplified ResNet, expanded README with throughput and performance --- cifar/README.md | 28 +++++++++++++++++++++++---- cifar/dataset.py | 2 +- cifar/main.py | 43 ++++++++++++++++++++++-------------------- cifar/requirements.txt | 3 +-- cifar/resnet.py | 16 +++++++--------- 5 files changed, 56 insertions(+), 36 deletions(-) diff --git a/cifar/README.md b/cifar/README.md index 0d793853..abb2c0f5 100644 --- a/cifar/README.md +++ b/cifar/README.md @@ -1,11 +1,10 @@ # CIFAR and ResNets -* This example shows how to run ResNets on CIFAR10 dataset, in accordance with the original [paper](https://arxiv.org/abs/1512.03385). -* Also illustrates how to use `mlx-data` to download and load the dataset. +An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations in accordance with the original [paper](https://arxiv.org/abs/1512.03385) are available. Also illustrates how to use `mlx-data` to download and load the dataset. ## Pre-requisites -* Install the dependencies: +Install the dependencies: ``` pip install -r requirements.txt @@ -21,7 +20,7 @@ python main.py By default the example runs on the GPU. To run on the CPU, use: ``` -python main.py --cpu_only +python main.py --cpu ``` For all available options, run: @@ -29,3 +28,24 @@ For all available options, run: ``` python main.py --help ``` + + +## Throughput + +On the tested device (M1 Macbook Pro, 16GB RAM), I get the following throughput with a `batch_size=256`: +``` +Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 415.39 images/sec +``` + +When training on just the CPU (with the `--cpu` argument), the throughput is significantly lower (almost 30x!): +``` +Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 13.5 images/sec +``` + +## Results +After training for 100 epochs, the following results were observed: +``` +Epoch: 99 | avg. tr_loss 0.320 | avg. tr_acc 0.888 | Train Throughput: 416.77 images/sec +Epoch: 99 | test_acc 0.807 +``` +At the time of writing, `mlx` doesn't have in-built `schedulers`, nor a `BatchNorm` layer. We'll revisit this example for exact reproduction once these features are added. \ No newline at end of file diff --git a/cifar/dataset.py b/cifar/dataset.py index f4a3cd63..29f558d1 100644 --- a/cifar/dataset.py +++ b/cifar/dataset.py @@ -36,4 +36,4 @@ def get_cifar10(batch_size, root=None): num_tr_steps_per_epoch = num_tr_samples // batch_size num_test_steps_per_epoch = num_test_samples // batch_size - return tr_iter, test_iter, num_tr_steps_per_epoch, num_test_steps_per_epoch + return tr_iter, test_iter diff --git a/cifar/main.py b/cifar/main.py index 5272733a..29b0cbc7 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -1,6 +1,6 @@ import argparse +import time import resnet -import numpy as np import mlx.nn as nn import mlx.core as mx import mlx.optimizers as optim @@ -14,11 +14,11 @@ parser.add_argument( default="resnet20", help="model architecture [resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202]", ) -parser.add_argument("--batch_size", type=int, default=128, help="batch size") +parser.add_argument("--batch_size", type=int, default=256, help="batch size") parser.add_argument("--epochs", type=int, default=100, help="number of epochs") parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") parser.add_argument("--seed", type=int, default=0, help="random seed") -parser.add_argument("--cpu_only", action="store_true", help="use cpu only") +parser.add_argument("--cpu", action="store_true", help="use cpu only") def loss_fn(model, inp, tgt): @@ -40,27 +40,30 @@ def train_epoch(model, train_iter, optimizer, epoch): losses = [] accs = [] + samples_per_sec = [] for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["image"]) y = mx.array(batch["label"]) + tic = time.perf_counter() (loss, acc), grads = train_step_fn(model, x, y) optimizer.update(model, grads) mx.eval(model.parameters(), optimizer.state) - + toc = time.perf_counter() loss_value = loss.item() acc_value = acc.item() losses.append(loss_value) accs.append(acc_value) - + samples_per_sec.append(x.shape[0] / (toc - tic)) if batch_counter % 10 == 0: print( - f"Epoch {epoch:02d}[{batch_counter:03d}]: tr_loss {loss_value:.3f}, tr_acc {acc_value:.3f}" + f"Epoch {epoch:02d} [{batch_counter:03d}] | tr_loss {loss_value:.3f} | tr_acc {acc_value:.3f} | Throughput: {x.shape[0] / (toc - tic):.2f} images/second" ) - mean_tr_loss = np.mean(np.array(losses)) - mean_tr_acc = np.mean(np.array(accs)) - return mean_tr_loss, mean_tr_acc + mean_tr_loss = mx.mean(mx.array(losses)) + mean_tr_acc = mx.mean(mx.array(accs)) + samples_per_sec = mx.mean(mx.array(samples_per_sec)) + return mean_tr_loss, mean_tr_acc, samples_per_sec def test_epoch(model, test_iter, epoch): @@ -71,13 +74,11 @@ def test_epoch(model, test_iter, epoch): acc = eval_fn(model, x, y) acc_value = acc.item() accs.append(acc_value) - mean_acc = np.mean(np.array(accs)) - + mean_acc = mx.mean(mx.array(accs)) return mean_acc def main(args): - np.random.seed(args.seed) mx.random.seed(args.seed) model = resnet.__dict__[args.arch]() @@ -87,22 +88,24 @@ def main(args): optimizer = optim.Adam(learning_rate=args.lr) + train_data, test_data = get_cifar10(args.batch_size) for epoch in range(args.epochs): - # get data every epoch - # or set .repeat() on the data stream appropriately - train_data, test_data, tr_batches, _ = get_cifar10(args.batch_size) - - epoch_tr_loss, epoch_tr_acc = train_epoch(model, train_data, optimizer, epoch) + epoch_tr_loss, epoch_tr_acc, train_throughput = train_epoch( + model, train_data, optimizer, epoch + ) print( - f"Epoch {epoch}: avg. tr_loss {epoch_tr_loss:.3f}, avg. tr_acc {epoch_tr_acc:.3f}" + f"Epoch: {epoch} | avg. tr_loss {epoch_tr_loss.item():.3f} | avg. tr_acc {epoch_tr_acc.item():.3f} | Train Throughput: {train_throughput.item():.2f} images/sec" ) epoch_test_acc = test_epoch(model, test_data, epoch) - print(f"Epoch {epoch}: Test_acc {epoch_test_acc:.3f}") + print(f"Epoch: {epoch} | test_acc {epoch_test_acc.item():.3f}") + + train_data.reset() + test_data.reset() if __name__ == "__main__": args = parser.parse_args() - if args.cpu_only: + if args.cpu: mx.set_default_device(mx.cpu) main(args) diff --git a/cifar/requirements.txt b/cifar/requirements.txt index c4c2e575..6ff78a64 100644 --- a/cifar/requirements.txt +++ b/cifar/requirements.txt @@ -1,3 +1,2 @@ mlx -mlx-data -numpy \ No newline at end of file +mlx-data \ No newline at end of file diff --git a/cifar/resnet.py b/cifar/resnet.py index b89a612b..6eeadda6 100644 --- a/cifar/resnet.py +++ b/cifar/resnet.py @@ -38,7 +38,6 @@ class ShortcutA(nn.Module): class Block(nn.Module): - expansion = 1 """ Implements a ResNet block with two convolutional layers and a skip connection. As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details) @@ -57,7 +56,7 @@ class Block(nn.Module): ) self.bn2 = nn.LayerNorm(dims) - if stride != 1 or in_dims != dims: + if stride != 1: self.shortcut = ShortcutA(dims) else: self.shortcut = None @@ -83,20 +82,19 @@ class ResNet(nn.Module): super().__init__() self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.LayerNorm(16) - self.in_dims = 16 - self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) + self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2) self.linear = nn.Linear(64, num_classes) - def _make_layer(self, block, dims, num_blocks, stride): + def _make_layer(self, block, in_dims, dims, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: - layers.append(block(self.in_dims, dims, stride)) - self.in_dims = dims * block.expansion + layers.append(block(in_dims, dims, stride)) + in_dims = dims return nn.Sequential(*layers) def num_params(self): From f691e00e5a5d84c196abd093cb6f8db88ccff6df Mon Sep 17 00:00:00 2001 From: Burak Budanur Date: Thu, 14 Dec 2023 14:02:11 +0100 Subject: [PATCH 15/33] Corrected the typo in 'ffn_dim_multiplier' in and added 'rope_theta' to the list unused. Without these, llama examples did not run. --- llama/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama/llama.py b/llama/llama.py index db9c8db3..9b8157b7 100644 --- a/llama/llama.py +++ b/llama/llama.py @@ -315,7 +315,7 @@ def load_model(model_path): config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] if config.get("vocab_size", -1) < 0: config["vocab_size"] = weights["output.weight"].shape[-1] - unused = ["multiple_of", "ffn_dim_multiplie"] + unused = ["multiple_of", "ffn_dim_multiplier", 'rope_theta'] for k in unused: if k in config: config.pop(k) From 29b7a973421222e52d182c24564c611955dcdfe4 Mon Sep 17 00:00:00 2001 From: Sarthak Yadav Date: Thu, 14 Dec 2023 16:28:00 +0100 Subject: [PATCH 16/33] updated header --- cifar/resnet.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cifar/resnet.py b/cifar/resnet.py index 6eeadda6..22b8a31a 100644 --- a/cifar/resnet.py +++ b/cifar/resnet.py @@ -3,9 +3,6 @@ Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv. Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202. There's no BatchNorm is mlx==0.0.4, using LayerNorm instead. - -Authors: - Sarthak Yadav, 2023 """ from typing import Any From a8d41491472ffb67081f32ecc4853de7ba1c367c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:08:28 -0800 Subject: [PATCH 17/33] fix fp16 + nits --- phi2/model.py | 97 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/phi2/model.py b/phi2/model.py index 5253a266..52bda27e 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -1,12 +1,14 @@ +import argparse from typing import Optional from dataclasses import dataclass -from mlx.utils import tree_unflatten, tree_map +from mlx.utils import tree_unflatten from transformers import AutoTokenizer import mlx.core as mx import mlx.nn as nn import math + @dataclass class ModelArgs: max_sequence_length: int = 2048 @@ -17,17 +19,22 @@ class ModelArgs: rotary_dim: int = 32 +class LayerNorm(nn.LayerNorm): + def __call__(self, x: mx.array) -> mx.array: + return super().__call__(x.astype(mx.float32)).astype(x.dtype) + + class RoPEAttention(nn.Module): - def __init__(self, dims: int, num_heads: int, bias: bool = True): + def __init__(self, dims: int, num_heads: int, rotary_dim: int): super().__init__() self.num_heads = num_heads - self.rope = nn.RoPE(dims // num_heads, traditional=True) - self.query_proj = nn.Linear(dims, dims, bias=bias) - self.key_proj = nn.Linear(dims, dims, bias=bias) - self.value_proj = nn.Linear(dims, dims, bias=bias) - self.out_proj = nn.Linear(dims, dims, bias=bias) + self.rope = nn.RoPE(rotary_dim, traditional=False) + self.query_proj = nn.Linear(dims, dims) + self.key_proj = nn.Linear(dims, dims) + self.value_proj = nn.Linear(dims, dims) + self.out_proj = nn.Linear(dims, dims) def __call__(self, queries, keys, values, mask=None, cache=None): queries = self.query_proj(queries) @@ -54,25 +61,28 @@ class RoPEAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) + queries = queries.astype(mx.float32) + keys = keys.astype(mx.float32) + # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores = scores + mask - scores = mx.softmax(scores, axis=-1) + scores = mx.softmax(scores, axis=-1).astype(values.dtype) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - # Note that we return the keys and values to possibly be used as a cache return self.out_proj(values_hat), (keys, values) class ParallelBlock(nn.Module): - def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + def __init__(self, config: ModelArgs): super().__init__() - mlp_dims = mlp_dims or dims * 4 - self.self_attention = RoPEAttention(dims, num_heads, bias=True) - self.ln = nn.LayerNorm(dims) + dims = config.model_dim + mlp_dims = dims * 4 + self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim) + self.ln = LayerNorm(dims) self.fc1 = nn.Linear(dims, mlp_dims) self.fc2 = nn.Linear(mlp_dims, dims) self.act = nn.GELU(approx="precise") @@ -85,11 +95,9 @@ class ParallelBlock(nn.Module): class TransformerDecoder(nn.Module): - def __init__( - self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None - ): + def __init__(self, config: ModelArgs): super().__init__() - self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)] + self.h = [ParallelBlock(config) for i in range(config.num_layers)] def __call__(self, x, mask, cache): if cache is None: @@ -102,7 +110,7 @@ class TransformerDecoder(nn.Module): class OutputHead(nn.Module): def __init__(self, config: ModelArgs) -> None: - self.ln = nn.LayerNorm(config.model_dim) + self.ln = LayerNorm(config.model_dim) self.linear = nn.Linear(config.model_dim, config.num_vocab) def __call__(self, inputs): @@ -112,11 +120,7 @@ class OutputHead(nn.Module): class Phi2(nn.Module): def __init__(self, config: ModelArgs): self.wte = nn.Embedding(config.num_vocab, config.model_dim) - self.transformer = TransformerDecoder( - num_layers=config.num_layers, - dims=config.model_dim, - num_heads=config.num_heads, - ) + self.transformer = TransformerDecoder(config) self.lm_head = OutputHead(config) def __call__( @@ -153,33 +157,58 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): yield y -if __name__ == "__main__": +def load_model(): model = Phi2(ModelArgs()) - weights = mx.load("weights/phi-2.npz") + weights = mx.load("weights.npz") weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: mx.array(p, mx.float32), weights) - model.update(weights) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) - prompt = tokenizer("Write a detailed analogy between mathematics and a lighthouse.", + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Phi-2 inference script") + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="Write a detailed analogy between mathematics and a lighthouse.", + ) + parser.add_argument( + "--max_tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model() + + prompt = tokenizer( + args.prompt, return_tensors="np", return_attention_mask=False, )["input_ids"] prompt = mx.array(prompt) - tokens_per_eval = 1 - max_tokens = 100 - tokens = [] - for token, _ in zip(generate(prompt, model), range(max_tokens)): + for token, _ in zip(generate(prompt, model), range(args.max_tokens)): tokens.append(token) - if (len(tokens) % tokens_per_eval) == 0: + if (len(tokens) % args.tokens_per_eval) == 0: mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] - From 1613e608a90c80d96055acb0455258235cd31d3a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:18:01 -0800 Subject: [PATCH 18/33] fix args, update README, remove extra files --- phi2/README.md | 44 +++++++++++++++++++++++------- phi2/hf_model.py | 23 ---------------- phi2/model.py | 5 +++- phi2/phi2_outputs.txt | 63 ------------------------------------------- 4 files changed, 38 insertions(+), 97 deletions(-) delete mode 100644 phi2/hf_model.py delete mode 100644 phi2/phi2_outputs.txt diff --git a/phi2/README.md b/phi2/README.md index c38f8a74..46a7c589 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,24 +1,48 @@ # Phi-2 -Phi-2 is a 2.7B parameter model released by Microsoft and trained on a mixture of GPT-4 outputs and clean web-text. -Its performance theoretically rivals much, much stronger models. +Phi-2 is a 2.7B parameter model released by Microsoft[^1] and trained on a mixture +of GPT-4 outputs and clean web-text. Its performance rivals +much, much stronger models. -## Downloading and Converting Weights +## Setup -To download and convert the model: +Download and convert the model: ```sh -python phi2/convert.py +python convert.py ``` -That will fill in `weights/phi-2.npz`. +which will make a file `weights.npz`. -## Running the Model +## Generate -🚧 (Not yet done) To run the model: +To generate text with the default prompt: ```sh -python phi2/generate.py +python model.py ``` -Layer-by-layer forward pass outputs are currently shown in the outputs.txt files. +Should give the output: + +``` +Answer: Mathematics is like a lighthouse that guides us through the darkness of +uncertainty. Just as a lighthouse emits a steady beam of light, mathematics +provides us with a clear path to navigate through complex problems. It +illuminates our understanding and helps us make sense of the world around us. + +Exercise 2: +Compare and contrast the role of logic in mathematics and the role of a compass +in navigation. + +Answer: Logic in mathematics is like a compass in navigation. It helps +``` + +To use your own prompt: + +```sh +python model.py --prompt --max_tokens +``` + +[^1]: For more details on the model see the [blog post]( +https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) +and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2) diff --git a/phi2/hf_model.py b/phi2/hf_model.py deleted file mode 100644 index d09ff108..00000000 --- a/phi2/hf_model.py +++ /dev/null @@ -1,23 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer - - -if __name__ == "__main__": - model = AutoModelForCausalLM.from_pretrained( - "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True - ) - tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) - - inputs = tokenizer( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''', - return_tensors="pt", - return_attention_mask=False, - ) - - print(model(**inputs)) - - # outputs = model.generate(**inputs, max_length=200) - # text = tokenizer.batch_decode(outputs)[0] - # print(text) diff --git a/phi2/model.py b/phi2/model.py index 52bda27e..a99d3d5d 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -203,11 +203,14 @@ if __name__ == "__main__": prompt = mx.array(prompt) + print("[INFO] Generating with Phi-2...", flush=True) + print(args.prompt, end="", flush=True) + tokens = [] for token, _ in zip(generate(prompt, model), range(args.max_tokens)): tokens.append(token) - if (len(tokens) % args.tokens_per_eval) == 0: + if (len(tokens) % 10) == 0: mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) diff --git a/phi2/phi2_outputs.txt b/phi2/phi2_outputs.txt deleted file mode 100644 index 4f27e44b..00000000 --- a/phi2/phi2_outputs.txt +++ /dev/null @@ -1,63 +0,0 @@ -(HF) Output of Embeddings - -tensor([[[-0.0353, 0.0045, 0.0208, ..., -0.0117, 0.0041, 0.0075], - [-0.0172, 0.0236, -0.0051, ..., 0.0141, 0.0115, 0.0058], - [-0.0148, 0.0043, -0.0252, ..., 0.0179, 0.0025, -0.0008], - ..., - [ 0.0003, 0.0051, 0.0002, ..., 0.0043, 0.0075, 0.0049], - [-0.0110, 0.0472, 0.0030, ..., 0.0098, -0.0075, 0.0146], - [-0.0085, -0.0219, -0.0016, ..., -0.0059, 0.0109, -0.0016]]], - device='cuda:0', dtype=torch.float16, grad_fn=) - -(MLX) Output of Embeddings - -array([[[-0.0352783, 0.00445175, 0.020813, ..., -0.0117188, 0.00411606, 0.00748444], - [-0.0171509, 0.0236053, -0.00508881, ..., 0.0141144, 0.0115204, 0.00582504], - [-0.0147858, 0.00426102, -0.0252075, ..., 0.0179443, 0.0024662, -0.00076437], - ..., - [0.000337124, 0.00508499, 0.000193119, ..., 0.00427628, 0.00753403, 0.00492477], - [-0.0110092, 0.0472107, 0.00295448, ..., 0.00982666, -0.00747681, 0.0145721], - [-0.00852203, -0.0218964, -0.00161839, ..., -0.00592422, 0.0108643, -0.00162697]]], dtype=float16) - -(HF) Output of First Attention Layer - -tensor([[[-0.2000, 0.4849, 0.9863, ..., -0.2209, 0.1355, 0.3469], - [ 0.4922, -0.3865, 0.8428, ..., 0.5894, -0.0069, -0.5278], - [ 0.0902, 0.1028, 0.6826, ..., 0.1394, -0.8145, -0.1880], - ..., - [ 0.2380, 0.0555, -0.3005, ..., 0.0372, -0.0895, 0.0255], - [ 0.2512, 0.1949, 0.3401, ..., 0.3625, -0.3103, -0.1064], - [-0.0905, 0.0665, 0.5210, ..., -0.0767, -0.2460, -0.1449]]], - device='cuda:0', dtype=torch.float16, grad_fn=) -torch.Size([1, 23, 2560]) - -(MLX) Output of First Attention Layer - -array([[[-0.199973, 0.485224, 0.987237, ..., -0.220847, 0.13511, 0.346074], - [0.44883, -0.271683, 0.877478, ..., 0.653217, -0.0929724, -0.711176], - [-0.233398, 5.7824e-05, 0.435001, ..., 0.0504494, -0.623998, -0.438785], - ..., - [0.123587, -0.237459, -0.447518, ..., 0.0653363, -0.0767153, -0.341505], - [0.187798, 0.331209, 0.0827338, ..., 0.529453, -0.582141, -0.165316], - [-0.413614, 0.134572, 0.685769, ..., 0.0796088, 0.0217719, -0.118885]]], dtype=float32) -[1, 23, 2560] - -(HF) Overall Output of Inputs: - -tensor([[[ 6.4688, 5.1016, 1.9658, ..., -2.9043, -2.9043, -2.9043], - [ 5.2188, 6.4414, 5.1914, ..., -0.1852, -0.1862, -0.1866], - [ 4.3516, 5.3281, 5.9922, ..., -0.3689, -0.3699, -0.3696], - ..., - [10.4141, 11.7031, 12.5859, ..., 0.7778, 0.7769, 0.7754], - [10.7188, 11.7891, 13.3125, ..., 1.6123, 1.6113, 1.6104], - [10.8047, 12.0234, 12.4375, ..., 0.2321, 0.2314, 0.2317]]], - -(MLX) Overall Output of Inputs: - -array([[[6.46632, 5.10102, 1.96306, ..., -2.90427, -2.90341, -2.90392], - [4.5092, 5.90938, 4.98036, ..., -0.411165, -0.412062, -0.412547], - [4.34246, 5.7794, 6.13245, ..., -0.40106, -0.402052, -0.401838], - ..., - [6.61827, 10.4022, 12.1672, ..., 0.602787, 0.602138, 0.600666], - [7.96546, 12.9569, 14.7947, ..., -0.347764, -0.348587, -0.34937], - [8.22272, 10.6631, 11.5968, ..., -1.12037, -1.12025, -1.12152]]], dtype=float32) \ No newline at end of file From 840c0c36c29baec53449100883183789310a2ae1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:27:44 -0800 Subject: [PATCH 19/33] don't drop last tokens --- phi2/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/phi2/model.py b/phi2/model.py index a99d3d5d..38199c6c 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -159,11 +159,8 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): def load_model(): model = Phi2(ModelArgs()) - weights = mx.load("weights.npz") - weights = tree_unflatten(list(weights.items())) - model.update(weights) - + model.update(tree_unflatten(list(weights.items()))) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) return model, tokenizer @@ -215,3 +212,7 @@ if __name__ == "__main__": s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True) From 3d2a23184a3530fa277067148a811b759675e6d8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:34:24 -0800 Subject: [PATCH 20/33] change file name for consistency, update readme. --- phi2/README.md | 17 ++++++++++++----- phi2/{model.py => phi2.py} | 0 2 files changed, 12 insertions(+), 5 deletions(-) rename phi2/{model.py => phi2.py} (100%) diff --git a/phi2/README.md b/phi2/README.md index 46a7c589..aef47cd1 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,8 +1,9 @@ # Phi-2 Phi-2 is a 2.7B parameter model released by Microsoft[^1] and trained on a mixture -of GPT-4 outputs and clean web-text. Its performance rivals -much, much stronger models. +of GPT-4 outputs and clean web-text. Its performance rivals much larger models. + +Phi-2 efficiently runs on an Apple silicon device with 8 GB memory in 16-bit precision. ## Setup @@ -12,14 +13,14 @@ Download and convert the model: python convert.py ``` -which will make a file `weights.npz`. +This will make the `weights.npz` file which MLX can read. ## Generate To generate text with the default prompt: ```sh -python model.py +python phi2.py ``` Should give the output: @@ -40,7 +41,13 @@ Answer: Logic in mathematics is like a compass in navigation. It helps To use your own prompt: ```sh -python model.py --prompt --max_tokens +python phi2.py --prompt --max_tokens +``` + +To see a list of options run: + +```sh +python phi2.py --help ``` [^1]: For more details on the model see the [blog post]( diff --git a/phi2/model.py b/phi2/phi2.py similarity index 100% rename from phi2/model.py rename to phi2/phi2.py From 0c1c500714aef1e05d3b1e032dda48667216fdd3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:37:34 -0800 Subject: [PATCH 21/33] update readme --- phi2/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/phi2/README.md b/phi2/README.md index aef47cd1..198ac30c 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,9 +1,11 @@ # Phi-2 -Phi-2 is a 2.7B parameter model released by Microsoft[^1] and trained on a mixture -of GPT-4 outputs and clean web-text. Its performance rivals much larger models. +Phi-2 is a 2.7B parameter language model released by Microsoft[^1] with +performance that rivals much larger models. It was trained on a mixture of +GPT-4 outputs and clean web text. -Phi-2 efficiently runs on an Apple silicon device with 8 GB memory in 16-bit precision. +Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit +precision. ## Setup From 8f60d60814115659c1d9d6f911c7177a66e077e4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 09:19:44 -0800 Subject: [PATCH 22/33] cleanup conversion to use single qkv matrix --- phi2/README.md | 4 ++-- phi2/__init__.py | 0 phi2/convert.py | 48 ++----------------------------------------- phi2/phi2.py | 15 ++++++-------- phi2/requirements.txt | 1 + 5 files changed, 11 insertions(+), 57 deletions(-) delete mode 100644 phi2/__init__.py diff --git a/phi2/README.md b/phi2/README.md index 198ac30c..f5d80696 100644 --- a/phi2/README.md +++ b/phi2/README.md @@ -1,7 +1,7 @@ # Phi-2 -Phi-2 is a 2.7B parameter language model released by Microsoft[^1] with -performance that rivals much larger models. It was trained on a mixture of +Phi-2 is a 2.7B parameter language model released by Microsoft with +performance that rivals much larger models.[^1] It was trained on a mixture of GPT-4 outputs and clean web text. Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit diff --git a/phi2/__init__.py b/phi2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/phi2/convert.py b/phi2/convert.py index 3c821f69..4c625a6e 100644 --- a/phi2/convert.py +++ b/phi2/convert.py @@ -1,34 +1,5 @@ from transformers import AutoModelForCausalLM - -import numpy - - -def split_attention_matrix(state_dict, key) -> dict: - # "transformer.h.0.mixer" - _, model_dim = state_dict[key + ".weight"].shape - # (3 * model_dim, model_dim) - Wqkv_weight_key = key + ".weight" - Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :] - Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :] - Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :] - - # (3 * model_dim) - Wqkv_bias_key = key + ".bias" - Wq_bias = state_dict[Wqkv_bias_key][:model_dim] - Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim] - Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :] - - out_key = key.replace("mixer.Wqkv", "self_attention") - - return { - out_key + ".query_proj.weight": Wq_weight, - out_key + ".query_proj.bias": Wq_bias, - out_key + ".key_proj.weight": Wk_weight, - out_key + ".key_proj.bias": Wk_bias, - out_key + ".value_proj.weight": Wv_weight, - out_key + ".value_proj.bias": Wv_bias, - } - +import numpy as np def replace_key(key: str) -> str: if "wte.weight" in key: @@ -36,10 +7,6 @@ def replace_key(key: str) -> str: if ".mlp" in key: key = key.replace(".mlp", "") - - if ".mixer.out_proj" in key: - key = key.replace(".mixer", ".self_attention") - return key @@ -48,19 +15,8 @@ def convert(): "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True ) state_dict = model.state_dict() - keys = list(state_dict.keys()) - - for key in keys: - if ".mixer.Wqkv.weight" not in key: - continue - key_stub = key.rstrip(".weight") - state_dict.update(split_attention_matrix(state_dict, key_stub)) - - del state_dict[key_stub + ".weight"] - del state_dict[key_stub + ".bias"] - weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} - numpy.savez("weights.npz", **weights) + np.savez("weights.npz", **weights) if __name__ == "__main__": diff --git a/phi2/phi2.py b/phi2/phi2.py index 38199c6c..7973c33d 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -31,15 +31,12 @@ class RoPEAttention(nn.Module): self.num_heads = num_heads self.rope = nn.RoPE(rotary_dim, traditional=False) - self.query_proj = nn.Linear(dims, dims) - self.key_proj = nn.Linear(dims, dims) - self.value_proj = nn.Linear(dims, dims) + self.Wqkv = nn.Linear(dims, 3 * dims) self.out_proj = nn.Linear(dims, dims) - def __call__(self, queries, keys, values, mask=None, cache=None): - queries = self.query_proj(queries) - keys = self.key_proj(keys) - values = self.value_proj(values) + def __call__(self, x, mask=None, cache=None): + qkv = self.Wqkv(x) + queries, keys, values = mx.split(qkv, 3, axis=-1) # Extract some shapes num_heads = self.num_heads @@ -81,7 +78,7 @@ class ParallelBlock(nn.Module): super().__init__() dims = config.model_dim mlp_dims = dims * 4 - self.self_attention = RoPEAttention(dims, config.num_heads, config.rotary_dim) + self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) self.ln = LayerNorm(dims) self.fc1 = nn.Linear(dims, mlp_dims) self.fc2 = nn.Linear(mlp_dims, dims) @@ -89,7 +86,7 @@ class ParallelBlock(nn.Module): def __call__(self, x, mask, cache): h = self.ln(x) - attn_h, cache = self.self_attention(h, h, h, mask, cache) + attn_h, cache = self.mixer(h, mask, cache) ff_h = self.fc2(self.act(self.fc1(h))) return attn_h + ff_h + x, cache diff --git a/phi2/requirements.txt b/phi2/requirements.txt index 6a11f8d2..3e141ec3 100644 --- a/phi2/requirements.txt +++ b/phi2/requirements.txt @@ -1,3 +1,4 @@ einops mlx +numpy transformers From 5b08da2395191baba7421d7eeaaa9d3ffae02476 Mon Sep 17 00:00:00 2001 From: arpit Date: Thu, 14 Dec 2023 23:40:50 +0530 Subject: [PATCH 23/33] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 37c977ed..7988e37a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ The [MNIST](mnist) example is a good starting point to learn how to use MLX. Some more useful examples include: - [Transformer language model](transformer_lm) training. -- Large scale text generation with [LLaMA](llama) or [Mistral](mistral). +- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2). - Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral) - Parameter efficient fine-tuning with [LoRA](lora). - Generating images with [Stable Diffusion](stable_diffusion). From b1b9b11801e4d86f36ac569e199d70b39f00bfe2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 12:09:10 -0800 Subject: [PATCH 24/33] updates + format --- cifar/README.md | 38 ++++++++++++++++++------------------ cifar/dataset.py | 21 ++++++-------------- cifar/main.py | 51 ++++++++++++++++++++++++++++-------------------- cifar/resnet.py | 1 - 4 files changed, 55 insertions(+), 56 deletions(-) diff --git a/cifar/README.md b/cifar/README.md index abb2c0f5..118aef9e 100644 --- a/cifar/README.md +++ b/cifar/README.md @@ -1,9 +1,13 @@ # CIFAR and ResNets -An example of training a ResNet on CIFAR-10 with MLX. Several ResNet configurations in accordance with the original [paper](https://arxiv.org/abs/1512.03385) are available. Also illustrates how to use `mlx-data` to download and load the dataset. - +An example of training a ResNet on CIFAR-10 with MLX. Several ResNet +configurations in accordance with the original +[paper](https://arxiv.org/abs/1512.03385) are available. The example also +illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to +load the dataset. ## Pre-requisites + Install the dependencies: ``` @@ -11,6 +15,7 @@ pip install -r requirements.txt ``` ## Running the example + Run the example with: ``` @@ -29,23 +34,18 @@ For all available options, run: python main.py --help ``` - -## Throughput - -On the tested device (M1 Macbook Pro, 16GB RAM), I get the following throughput with a `batch_size=256`: -``` -Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 415.39 images/sec -``` - -When training on just the CPU (with the `--cpu` argument), the throughput is significantly lower (almost 30x!): -``` -Epoch: 0 | avg. tr_loss 2.074 | avg. tr_acc 0.216 | Train Throughput: 13.5 images/sec -``` - ## Results -After training for 100 epochs, the following results were observed: + +After training with the default `resnet20` architecture for 100 epochs, you +should see the following results: + ``` -Epoch: 99 | avg. tr_loss 0.320 | avg. tr_acc 0.888 | Train Throughput: 416.77 images/sec -Epoch: 99 | test_acc 0.807 +Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec +Epoch: 99 | Test acc 0.807 ``` -At the time of writing, `mlx` doesn't have in-built `schedulers`, nor a `BatchNorm` layer. We'll revisit this example for exact reproduction once these features are added. \ No newline at end of file + +Note this was run on an M1 Macbook Pro with 16GB RAM. + +At the time of writing, `mlx` doesn't have built-in learning rate schedules, +nor a `BatchNorm` layer. We intend to update this example once these features +are added. diff --git a/cifar/dataset.py b/cifar/dataset.py index 29f558d1..89b10136 100644 --- a/cifar/dataset.py +++ b/cifar/dataset.py @@ -4,13 +4,15 @@ import math def get_cifar10(batch_size, root=None): - tr = load_cifar10(root=root) - num_tr_samples = tr.size() mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) + def normalize(x): + x = x.astype("float32") / 255.0 + return (x - mean) / std + tr_iter = ( tr.shuffle() .to_stream() @@ -18,22 +20,11 @@ def get_cifar10(batch_size, root=None): .pad("image", 0, 4, 4, 0.0) .pad("image", 1, 4, 4, 0.0) .image_random_crop("image", 32, 32) - .key_transform("image", lambda x: (x.astype("float32") / 255.0)) - .key_transform("image", lambda x: (x - mean) / std) + .key_transform("image", normalize) .batch(batch_size) ) test = load_cifar10(root=root, train=False) - num_test_samples = test.size() - - test_iter = ( - test.to_stream() - .key_transform("image", lambda x: (x.astype("float32") / 255.0)) - .key_transform("image", lambda x: (x - mean) / std) - .batch(batch_size) - ) - - num_tr_steps_per_epoch = num_tr_samples // batch_size - num_test_steps_per_epoch = num_test_samples // batch_size + test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size) return tr_iter, test_iter diff --git a/cifar/main.py b/cifar/main.py index 29b0cbc7..26d06a6a 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -12,7 +12,8 @@ parser.add_argument( "--arch", type=str, default="resnet20", - help="model architecture [resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202]", + choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]], + help="model architecture", ) parser.add_argument("--batch_size", type=int, default=256, help="batch size") parser.add_argument("--epochs", type=int, default=100, help="number of epochs") @@ -21,10 +22,6 @@ parser.add_argument("--seed", type=int, default=0, help="random seed") parser.add_argument("--cpu", action="store_true", help="use cpu only") -def loss_fn(model, inp, tgt): - return mx.mean(nn.losses.cross_entropy(model(inp), tgt)) - - def eval_fn(model, inp, tgt): return mx.mean(mx.argmax(model(inp), axis=1) == tgt) @@ -50,17 +47,25 @@ def train_epoch(model, train_iter, optimizer, epoch): optimizer.update(model, grads) mx.eval(model.parameters(), optimizer.state) toc = time.perf_counter() - loss_value = loss.item() - acc_value = acc.item() - losses.append(loss_value) - accs.append(acc_value) - samples_per_sec.append(x.shape[0] / (toc - tic)) + loss = loss.item() + acc = acc.item() + losses.append(loss) + accs.append(acc) + throughput = x.shape[0] / (toc - tic) + samples_per_sec.append(throughput) if batch_counter % 10 == 0: print( - f"Epoch {epoch:02d} [{batch_counter:03d}] | tr_loss {loss_value:.3f} | tr_acc {acc_value:.3f} | Throughput: {x.shape[0] / (toc - tic):.2f} images/second" + " | ".join( + ( + f"Epoch {epoch:02d} [{batch_counter:03d}]", + f"Train loss {loss:.3f}", + f"Train acc {acc:.3f}", + f"Throughput: {throughput:.2f} images/second", + ) + ) ) - mean_tr_loss = mx.mean(mx.array(losses)) + eean_tr_loss = mx.mean(mx.array(losses)) mean_tr_acc = mx.mean(mx.array(accs)) samples_per_sec = mx.mean(mx.array(samples_per_sec)) return mean_tr_loss, mean_tr_acc, samples_per_sec @@ -81,24 +86,28 @@ def test_epoch(model, test_iter, epoch): def main(args): mx.random.seed(args.seed) - model = resnet.__dict__[args.arch]() + model = getattr(resnet, args.arch)() - print("num_params: {:0.04f} M".format(model.num_params() / 1e6)) - mx.eval(model.parameters()) + print("Number of params: {:0.04f} M".format(model.num_params() / 1e6)) optimizer = optim.Adam(learning_rate=args.lr) train_data, test_data = get_cifar10(args.batch_size) for epoch in range(args.epochs): - epoch_tr_loss, epoch_tr_acc, train_throughput = train_epoch( - model, train_data, optimizer, epoch - ) + tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch) print( - f"Epoch: {epoch} | avg. tr_loss {epoch_tr_loss.item():.3f} | avg. tr_acc {epoch_tr_acc.item():.3f} | Train Throughput: {train_throughput.item():.2f} images/sec" + " | ".join( + ( + f"Epoch: {epoch}", + f"avg. Train loss {tr_loss.item():.3f}", + f"avg. Train acc {tr_acc.item():.3f}", + f"Throughput: {throughput.item():.2f} images/sec", + ) + ) ) - epoch_test_acc = test_epoch(model, test_data, epoch) - print(f"Epoch: {epoch} | test_acc {epoch_test_acc.item():.3f}") + test_acc = test_epoch(model, test_data, epoch) + print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}") train_data.reset() test_data.reset() diff --git a/cifar/resnet.py b/cifar/resnet.py index 22b8a31a..758ee3de 100644 --- a/cifar/resnet.py +++ b/cifar/resnet.py @@ -59,7 +59,6 @@ class Block(nn.Module): self.shortcut = None def __call__(self, x): - out = nn.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) if self.shortcut is None: From b9439ce74e3040ad7e91f49327720f4a0b0aa912 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 12:14:01 -0800 Subject: [PATCH 25/33] typo / nits --- cifar/README.md | 2 +- cifar/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cifar/README.md b/cifar/README.md index 118aef9e..d6bdaf9a 100644 --- a/cifar/README.md +++ b/cifar/README.md @@ -47,5 +47,5 @@ Epoch: 99 | Test acc 0.807 Note this was run on an M1 Macbook Pro with 16GB RAM. At the time of writing, `mlx` doesn't have built-in learning rate schedules, -nor a `BatchNorm` layer. We intend to update this example once these features +or a `BatchNorm` layer. We intend to update this example once these features are added. diff --git a/cifar/main.py b/cifar/main.py index 26d06a6a..829417b1 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -65,7 +65,7 @@ def train_epoch(model, train_iter, optimizer, epoch): ) ) - eean_tr_loss = mx.mean(mx.array(losses)) + mean_tr_loss = mx.mean(mx.array(losses)) mean_tr_acc = mx.mean(mx.array(accs)) samples_per_sec = mx.mean(mx.array(samples_per_sec)) return mean_tr_loss, mean_tr_acc, samples_per_sec From 9b887cef08afd6fbee503bbd6c5c6efdf96a79ab Mon Sep 17 00:00:00 2001 From: Fahad Nadeem Date: Fri, 15 Dec 2023 03:09:33 +0500 Subject: [PATCH 26/33] minor dep fix in phi --- phi2/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/phi2/requirements.txt b/phi2/requirements.txt index 3e141ec3..2251ee12 100644 --- a/phi2/requirements.txt +++ b/phi2/requirements.txt @@ -2,3 +2,4 @@ einops mlx numpy transformers +torch From 078fed3d8d8cb24c1eda31f0009edf327659b914 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 15:30:32 -0800 Subject: [PATCH 27/33] use official HF for mixtral --- mixtral/README.md | 24 ++++++++---------------- mixtral/convert.py | 44 ++++++++++++++++++++++++++++++++++++++------ mixtral/mixtral.py | 10 +++++++--- mixtral/params.json | 1 + 4 files changed, 54 insertions(+), 25 deletions(-) create mode 100644 mixtral/params.json diff --git a/mixtral/README.md b/mixtral/README.md index b56ee767..a90f7abf 100644 --- a/mixtral/README.md +++ b/mixtral/README.md @@ -17,36 +17,28 @@ brew install git-lfs Download the models from Hugging Face: ``` -git-lfs clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen -``` - -After that's done, combine the files: -``` -cd mixtral-8x7b-32kseqlen/ -cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/ +cd Mixtral-8x7B-v0.1/ && \ + git lfs pull --include "consolidated.*.pt" && \ + git lfs pull --include "tokenizer.model" ``` Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so MLX can read them: ``` -python convert.py --model_path mixtral-8x7b-32kseqlen/ +python convert.py --model_path Mixtral-8x7B-v0.1/ ``` The conversion script will save the converted weights in the same location. -After that's done, if you want to clean some stuff up: - -``` -rm mixtral-8x7b-32kseqlen/*.pth* -``` - ### Generate As easy as: ``` -python mixtral.py --model_path mixtral-8x7b-32kseqlen/ +python mixtral.py --model_path Mixtral-8x7B-v0.1/ ``` -[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. +[^mixtral]: Refer to Mistral's [blog + post](https://mistral.ai/news/mixtral-of-experts/) for more details. diff --git a/mixtral/convert.py b/mixtral/convert.py index e67f4453..d6ba8030 100644 --- a/mixtral/convert.py +++ b/mixtral/convert.py @@ -1,23 +1,55 @@ # Copyright © 2023 Apple Inc. import argparse +import glob +import json import numpy as np from pathlib import Path import torch +def convert(k, v, config): + v = v.to(torch.float16).numpy() + if "block_sparse_moe" not in k: + return [(k, v)] + if "gate" in k: + return [(k.replace("block_sparse_moe", "feed_forward"), v)] + + # From: layers.N.block_sparse_moe.w + # To: layers.N.experts.M.w + num_experts = args["moe"]["num_experts"] + key_path = k.split(".") + v = np.split(v, num_experts, axis=0) + if key_path[-1] == "w2": + v = [u.T for u in v] + + w_name = key_path.pop() + key_path[-1] = "feed_forward.experts" + return [ + (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v) + ] + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.") parser.add_argument( "--model_path", type=str, - default="mixtral-8x7b-32kseqlen/", + default="Mixtral-8x7B-v0.1/", help="The path to the Mixtral model. The MLX model weights will also be saved there.", ) args = parser.parse_args() model_path = Path(args.model_path) - state = torch.load(str(model_path / "consolidated.00.pth")) - np.savez( - str(model_path / "weights.npz"), - **{k: v.to(torch.float16).numpy() for k, v in state.items()}, - ) + + with open("params.json") as fid: + args = json.load(fid) + + torch_files = glob.glob(str(model_path / "consolidated.*.pt")) + torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2])) + for e, tf in enumerate(torch_files): + print(f"[INFO] Converting file {e + 1}/{len(torch_files)}") + state = torch.load(tf) + new_state = {} + for k, v in state.items(): + new_state.update(convert(k, v, args)) + np.savez(str(model_path / f"weights.{e}.npz"), **new_state) diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py index 1a9be600..59848219 100644 --- a/mixtral/mixtral.py +++ b/mixtral/mixtral.py @@ -2,6 +2,7 @@ import argparse from dataclasses import dataclass +import glob import json import numpy as np from pathlib import Path @@ -222,10 +223,13 @@ class Tokenizer: def load_model(folder: str, dtype=mx.float16): model_path = Path(folder) tokenizer = Tokenizer(str(model_path / "tokenizer.model")) - with open(model_path / "params.json", "r") as f: + with open("params.json", "r") as f: config = json.loads(f.read()) model_args = ModelArgs(**config) - weights = mx.load(str(model_path / "weights.npz")) + weight_files = glob.glob(str(model_path / "weights.*.npz")) + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) weights = tree_unflatten(list(weights.items())) weights = tree_map(lambda p: p.astype(dtype), weights) model = Mixtral(model_args) @@ -255,7 +259,7 @@ if __name__ == "__main__": parser.add_argument( "--model_path", type=str, - default="mixtral-8x7b-32kseqlen", + default="Mixtral-8x7B-v0.1", help="The path to the model weights, tokenizer, and config", ) parser.add_argument( diff --git a/mixtral/params.json b/mixtral/params.json new file mode 100644 index 00000000..f1016aa8 --- /dev/null +++ b/mixtral/params.json @@ -0,0 +1 @@ +{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}} From e434e7e5c2877535aea0aa6384fbe1f0f91f5646 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 15:40:38 -0800 Subject: [PATCH 28/33] incude instruct option --- mixtral/README.md | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/mixtral/README.md b/mixtral/README.md index a90f7abf..3b0c50d0 100644 --- a/mixtral/README.md +++ b/mixtral/README.md @@ -2,6 +2,8 @@ Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon. +This example also supports the instruction fine-tuned Mixtral model.[^instruct] + Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run. ### Setup @@ -16,9 +18,23 @@ brew install git-lfs Download the models from Hugging Face: +For the base model use: + ``` -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/ -cd Mixtral-8x7B-v0.1/ && \ +export MIXTRAL_MODEL=Mixtral-8x7B-v0.1 +``` + +For the instruction fine-tuned model use: + +``` +export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1 +``` + +Then run: + +``` +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/ +cd $MIXTRAL_MODEL/ && \ git lfs pull --include "consolidated.*.pt" && \ git lfs pull --include "tokenizer.model" ``` @@ -27,7 +43,7 @@ Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so MLX can read them: ``` -python convert.py --model_path Mixtral-8x7B-v0.1/ +python convert.py --model_path $MIXTRAL_MODEL/ ``` The conversion script will save the converted weights in the same location. @@ -37,8 +53,15 @@ The conversion script will save the converted weights in the same location. As easy as: ``` -python mixtral.py --model_path Mixtral-8x7B-v0.1/ +python mixtral.py --model_path $MIXTRAL_MODEL/ ``` -[^mixtral]: Refer to Mistral's [blog - post](https://mistral.ai/news/mixtral-of-experts/) for more details. +For more options including how to prompt the model, run: + +``` +python mixtral.py --help +``` + +[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. +[^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more +details From 4549dcbbd03f48c99b14fccd82908c76c48adcd8 Mon Sep 17 00:00:00 2001 From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com> Date: Thu, 14 Dec 2023 15:50:59 -0800 Subject: [PATCH 29/33] Stop generating at eos token --- phi2/phi2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/phi2/phi2.py b/phi2/phi2.py index 7973c33d..ede79ea2 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -202,7 +202,11 @@ if __name__ == "__main__": tokens = [] for token, _ in zip(generate(prompt, model), range(args.max_tokens)): - tokens.append(token) + + if token == tokenizer.eos_token_id: + break + else: + tokens.append(token) if (len(tokens) % 10) == 0: mx.eval(tokens) From d7d7aabded3b52d36c5f3a3675553d5651639b6f Mon Sep 17 00:00:00 2001 From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com> Date: Thu, 14 Dec 2023 15:52:22 -0800 Subject: [PATCH 30/33] Remove unnecessary return --- phi2/phi2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phi2/phi2.py b/phi2/phi2.py index ede79ea2..2d3f792a 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -202,7 +202,6 @@ if __name__ == "__main__": tokens = [] for token, _ in zip(generate(prompt, model), range(args.max_tokens)): - if token == tokenizer.eos_token_id: break else: From b863e7cca0405461c5239f503748ab2f62cef241 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 16:56:50 -0800 Subject: [PATCH 31/33] format --- llama/llama.py | 2 +- mixtral/README.md | 8 +++++++- phi2/convert.py | 1 + whisper/test.py | 1 - whisper/whisper/decoding.py | 2 +- whisper/whisper/load_models.py | 7 ++++--- whisper/whisper/transcribe.py | 2 +- whisper/whisper/whisper.py | 21 ++++++++++++++++++--- 8 files changed, 33 insertions(+), 11 deletions(-) diff --git a/llama/llama.py b/llama/llama.py index 9b8157b7..73eb39c5 100644 --- a/llama/llama.py +++ b/llama/llama.py @@ -315,7 +315,7 @@ def load_model(model_path): config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] if config.get("vocab_size", -1) < 0: config["vocab_size"] = weights["output.weight"].shape[-1] - unused = ["multiple_of", "ffn_dim_multiplier", 'rope_theta'] + unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"] for k in unused: if k in config: config.pop(k) diff --git a/mixtral/README.md b/mixtral/README.md index 3b0c50d0..9194979e 100644 --- a/mixtral/README.md +++ b/mixtral/README.md @@ -62,6 +62,12 @@ For more options including how to prompt the model, run: python mixtral.py --help ``` -[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. +For the Instruction model, make sure to follow the prompt format: + +``` +[INST] Instruction prompt [/INST] +``` + +[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details. [^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more details diff --git a/phi2/convert.py b/phi2/convert.py index 4c625a6e..5aa07dce 100644 --- a/phi2/convert.py +++ b/phi2/convert.py @@ -1,6 +1,7 @@ from transformers import AutoModelForCausalLM import numpy as np + def replace_key(key: str) -> str: if "wte.weight" in key: key = "wte.weight" diff --git a/whisper/test.py b/whisper/test.py index 79f233ba..3e7630a9 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase): logits = mlx_model(mels, tokens) self.assertEqual(logits.dtype, mx.float16) - def test_decode_lang(self): options = decoding.DecodingOptions(task="lang_id", fp16=False) result = decoding.decode(self.model, self.mels, options) diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index 7c7c4a93..d5025444 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -112,7 +112,7 @@ class DecodingOptions: max_initial_timestamp: Optional[float] = 1.0 # implementation details - fp16: bool = True # use fp16 for most of the calculation + fp16: bool = True # use fp16 for most of the calculation @dataclass(frozen=True) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 58cef9ac..ffdccf44 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = { "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", - "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00" + "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", } @@ -166,7 +166,8 @@ def convert(model, rules=None): def torch_to_mlx( - torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16, + torch_model: torch_whisper.Whisper, + dtype: mx.Dtype = mx.float16, ) -> whisper.Whisper: def convert_rblock(model, rules): children = dict(model.named_children()) @@ -194,6 +195,6 @@ def torch_to_mlx( def load_model( name: str, download_root: str = None, - dtype : mx.Dtype = mx.float32, + dtype: mx.Dtype = mx.float32, ) -> whisper.Whisper: return torch_to_mlx(load_torch_model(name, download_root), dtype) diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index 3172bdb3..06f3c9ea 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -43,7 +43,7 @@ class ModelHolder: model_name = None @classmethod - def get_model(cls, model: str, dtype : mx.Dtype): + def get_model(cls, model: str, dtype: mx.Dtype): if cls.model is None or model != cls.model_name: cls.model = load_model(model, dtype=dtype) cls.model_name = model diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index bca69946..8ee6d7d9 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000): scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :] return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1) + class LayerNorm(nn.LayerNorm): def __call__(self, x: mx.array) -> mx.array: return super().__call__(x.astype(mx.float32)).astype(x.dtype) @@ -123,7 +124,13 @@ class ResidualAttentionBlock(nn.Module): class AudioEncoder(nn.Module): def __init__( - self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, + self, + n_mels: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, ): super().__init__() self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) @@ -148,7 +155,13 @@ class AudioEncoder(nn.Module): class TextDecoder(nn.Module): def __init__( - self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, + self, + n_vocab: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, ): super().__init__() @@ -160,7 +173,9 @@ class TextDecoder(nn.Module): for _ in range(n_layer) ] self.ln = LayerNorm(n_state) - self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype) + self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype( + dtype + ) def __call__(self, x, xa, kv_cache=None): """ From d74d9453ddcace552828816a4114b0234febf837 Mon Sep 17 00:00:00 2001 From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com> Date: Thu, 14 Dec 2023 21:11:23 -0800 Subject: [PATCH 32/33] Refactor EOS check --- phi2/phi2.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/phi2/phi2.py b/phi2/phi2.py index 2d3f792a..4a9ed30e 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -202,16 +202,20 @@ if __name__ == "__main__": tokens = [] for token, _ in zip(generate(prompt, model), range(args.max_tokens)): - if token == tokenizer.eos_token_id: - break - else: - tokens.append(token) + tokens.append(token) if (len(tokens) % 10) == 0: mx.eval(tokens) + eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None) + + if eos_index is not None: + tokens = tokens[:eos_index] + s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + if eos_index is not None: + break mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) From ec1176352746e66152daa8558779e2c59ab7a51e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 21:45:25 -0800 Subject: [PATCH 33/33] fix RoPE bug + minor updates --- mixtral/mixtral.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py index 59848219..16a9eec8 100644 --- a/mixtral/mixtral.py +++ b/mixtral/mixtral.py @@ -41,6 +41,26 @@ class RMSNorm(nn.Module): return self.weight * output +class RoPE(nn.RoPE): + def __init__(self, dims: int, traditional: bool = False): + super().__init__(dims, traditional) + + def __call__(self, x, offset: int = 0): + shape = x.shape + x = mx.reshape(x, (-1, shape[-2], shape[-1])) + N = x.shape[1] + offset + costheta, sintheta = RoPE.create_cos_sin_theta( + N, self.dims, offset=offset, base=1000000, dtype=x.dtype + ) + + rope = ( + self._compute_traditional_rope if self.traditional else self._compute_rope + ) + rx = rope(costheta, sintheta, x) + + return mx.reshape(rx, shape) + + class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -57,7 +77,7 @@ class Attention(nn.Module): self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) - self.rope = nn.RoPE(args.head_dim, traditional=True) + self.rope = RoPE(args.head_dim, traditional=True) def __call__( self, @@ -126,7 +146,10 @@ class MOEFeedForward(nn.Module): gates = self.gate(x) inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] - scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1) + scores = mx.softmax( + mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), + axis=-1, + ).astype(gates.dtype) y = [] for xt, st, it in zip(x, scores, inds.tolist()): @@ -182,8 +205,9 @@ class Mixtral(nn.Module): h = self.tok_embeddings(inputs) mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + T = h.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) mask = mask.astype(h.dtype) if cache is None: @@ -192,7 +216,7 @@ class Mixtral(nn.Module): for e, layer in enumerate(self.layers): h, cache[e] = layer(h, mask, cache[e]) - return self.output(self.norm(h)), cache + return self.output(self.norm(h[:, T - 1 : T, :])), cache class Tokenizer: @@ -278,7 +302,7 @@ if __name__ == "__main__": "--temp", help="The sampling temperature.", type=float, - default=1.0, + default=0.0, ) parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")