diff --git a/README.md b/README.md index 37c977ed..7988e37a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ The [MNIST](mnist) example is a good starting point to learn how to use MLX. Some more useful examples include: - [Transformer language model](transformer_lm) training. -- Large scale text generation with [LLaMA](llama) or [Mistral](mistral). +- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2). - Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral) - Parameter efficient fine-tuning with [LoRA](lora). - Generating images with [Stable Diffusion](stable_diffusion). diff --git a/bert/README.md b/bert/README.md index cea738df..70bc39a9 100644 --- a/bert/README.md +++ b/bert/README.md @@ -8,7 +8,7 @@ The `convert.py` script relies on `transformers` to download the weights, and ex ``` python convert.py \ - --bert-model bert-base-uncased + --bert-model bert-base-uncased \ --mlx-model weights/bert-base-uncased.npz ``` diff --git a/bert/model.py b/bert/model.py index 794254f6..ff73dea2 100644 --- a/bert/model.py +++ b/bert/model.py @@ -8,7 +8,6 @@ import mlx.core as mx import mlx.nn as nn import argparse import numpy -import math @dataclass @@ -35,74 +34,6 @@ model_configs = { } -class MultiHeadAttention(nn.Module): - """ - Minor update to the MultiHeadAttention module to ensure that the - projections use bias. - """ - - def __init__( - self, - dims: int, - num_heads: int, - query_input_dims: Optional[int] = None, - key_input_dims: Optional[int] = None, - value_input_dims: Optional[int] = None, - value_dims: Optional[int] = None, - value_output_dims: Optional[int] = None, - ): - super().__init__() - - if (dims % num_heads) != 0: - raise ValueError( - f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0" - ) - - query_input_dims = query_input_dims or dims - key_input_dims = key_input_dims or dims - value_input_dims = value_input_dims or key_input_dims - value_dims = value_dims or dims - value_output_dims = value_output_dims or dims - - self.num_heads = num_heads - self.query_proj = nn.Linear(query_input_dims, dims, True) - self.key_proj = nn.Linear(key_input_dims, dims, True) - self.value_proj = nn.Linear(value_input_dims, value_dims, True) - self.out_proj = nn.Linear(value_dims, value_output_dims, True) - - def __call__(self, queries, keys, values, mask=None): - queries = self.query_proj(queries) - keys = self.key_proj(keys) - values = self.value_proj(values) - - num_heads = self.num_heads - B, L, D = queries.shape - _, S, _ = keys.shape - queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) - values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) - - # Dimensions are [batch x num heads x sequence x hidden dim] - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys - if mask is not None: - mask = self.convert_mask_to_additive_causal_mask(mask) - mask = mx.expand_dims(mask, (1, 2)) - mask = mx.broadcast_to(mask, scores.shape) - scores = scores + mask.astype(scores.dtype) - scores = mx.softmax(scores, axis=-1) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.out_proj(values_hat) - - def convert_mask_to_additive_causal_mask( - self, mask: mx.array, dtype: mx.Dtype = mx.float32 - ) -> mx.array: - mask = mask == 0 - mask = mask.astype(dtype) * -1e9 - return mask - - class TransformerEncoderLayer(nn.Module): """ A transformer encoder layer with (the original BERT) post-normalization. @@ -117,7 +48,7 @@ class TransformerEncoderLayer(nn.Module): ): super().__init__() mlp_dims = mlp_dims or dims * 4 - self.attention = MultiHeadAttention(dims, num_heads) + self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True) self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps) self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps) self.linear1 = nn.Linear(dims, mlp_dims) @@ -187,9 +118,15 @@ class Bert(nn.Module): self, input_ids: mx.array, token_type_ids: mx.array, - attention_mask: Optional[mx.array] = None, + attention_mask: mx.array = None, ) -> tuple[mx.array, mx.array]: x = self.embeddings(input_ids, token_type_ids) + + if attention_mask is not None: + # convert 0's to -infs, 1's to 0's, and make it broadcastable + attention_mask = mx.log(attention_mask) + attention_mask = mx.expand_dims(attention_mask, (1, 2)) + y = self.encoder(x, attention_mask) return y, mx.tanh(self.pooler(y[:, 0])) diff --git a/bert/requirements.txt b/bert/requirements.txt index 24266334..a6b564c5 100644 --- a/bert/requirements.txt +++ b/bert/requirements.txt @@ -1,3 +1,3 @@ -mlx +mlx>=0.0.5 transformers -numpy \ No newline at end of file +numpy diff --git a/cifar/README.md b/cifar/README.md new file mode 100644 index 00000000..d6bdaf9a --- /dev/null +++ b/cifar/README.md @@ -0,0 +1,51 @@ +# CIFAR and ResNets + +An example of training a ResNet on CIFAR-10 with MLX. Several ResNet +configurations in accordance with the original +[paper](https://arxiv.org/abs/1512.03385) are available. The example also +illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to +load the dataset. + +## Pre-requisites + +Install the dependencies: + +``` +pip install -r requirements.txt +``` + +## Running the example + +Run the example with: + +``` +python main.py +``` + +By default the example runs on the GPU. To run on the CPU, use: + +``` +python main.py --cpu +``` + +For all available options, run: + +``` +python main.py --help +``` + +## Results + +After training with the default `resnet20` architecture for 100 epochs, you +should see the following results: + +``` +Epoch: 99 | avg. Train loss 0.320 | avg. Train acc 0.888 | Throughput: 416.77 images/sec +Epoch: 99 | Test acc 0.807 +``` + +Note this was run on an M1 Macbook Pro with 16GB RAM. + +At the time of writing, `mlx` doesn't have built-in learning rate schedules, +or a `BatchNorm` layer. We intend to update this example once these features +are added. diff --git a/cifar/dataset.py b/cifar/dataset.py new file mode 100644 index 00000000..89b10136 --- /dev/null +++ b/cifar/dataset.py @@ -0,0 +1,30 @@ +import mlx.core as mx +from mlx.data.datasets import load_cifar10 +import math + + +def get_cifar10(batch_size, root=None): + tr = load_cifar10(root=root) + + mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) + std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) + + def normalize(x): + x = x.astype("float32") / 255.0 + return (x - mean) / std + + tr_iter = ( + tr.shuffle() + .to_stream() + .image_random_h_flip("image", prob=0.5) + .pad("image", 0, 4, 4, 0.0) + .pad("image", 1, 4, 4, 0.0) + .image_random_crop("image", 32, 32) + .key_transform("image", normalize) + .batch(batch_size) + ) + + test = load_cifar10(root=root, train=False) + test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size) + + return tr_iter, test_iter diff --git a/cifar/main.py b/cifar/main.py new file mode 100644 index 00000000..829417b1 --- /dev/null +++ b/cifar/main.py @@ -0,0 +1,120 @@ +import argparse +import time +import resnet +import mlx.nn as nn +import mlx.core as mx +import mlx.optimizers as optim +from dataset import get_cifar10 + + +parser = argparse.ArgumentParser(add_help=True) +parser.add_argument( + "--arch", + type=str, + default="resnet20", + choices=[f"resnet{d}" for d in [20, 32, 44, 56, 110, 1202]], + help="model architecture", +) +parser.add_argument("--batch_size", type=int, default=256, help="batch size") +parser.add_argument("--epochs", type=int, default=100, help="number of epochs") +parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") +parser.add_argument("--seed", type=int, default=0, help="random seed") +parser.add_argument("--cpu", action="store_true", help="use cpu only") + + +def eval_fn(model, inp, tgt): + return mx.mean(mx.argmax(model(inp), axis=1) == tgt) + + +def train_epoch(model, train_iter, optimizer, epoch): + def train_step(model, inp, tgt): + output = model(inp) + loss = mx.mean(nn.losses.cross_entropy(output, tgt)) + acc = mx.mean(mx.argmax(output, axis=1) == tgt) + return loss, acc + + train_step_fn = nn.value_and_grad(model, train_step) + + losses = [] + accs = [] + samples_per_sec = [] + + for batch_counter, batch in enumerate(train_iter): + x = mx.array(batch["image"]) + y = mx.array(batch["label"]) + tic = time.perf_counter() + (loss, acc), grads = train_step_fn(model, x, y) + optimizer.update(model, grads) + mx.eval(model.parameters(), optimizer.state) + toc = time.perf_counter() + loss = loss.item() + acc = acc.item() + losses.append(loss) + accs.append(acc) + throughput = x.shape[0] / (toc - tic) + samples_per_sec.append(throughput) + if batch_counter % 10 == 0: + print( + " | ".join( + ( + f"Epoch {epoch:02d} [{batch_counter:03d}]", + f"Train loss {loss:.3f}", + f"Train acc {acc:.3f}", + f"Throughput: {throughput:.2f} images/second", + ) + ) + ) + + mean_tr_loss = mx.mean(mx.array(losses)) + mean_tr_acc = mx.mean(mx.array(accs)) + samples_per_sec = mx.mean(mx.array(samples_per_sec)) + return mean_tr_loss, mean_tr_acc, samples_per_sec + + +def test_epoch(model, test_iter, epoch): + accs = [] + for batch_counter, batch in enumerate(test_iter): + x = mx.array(batch["image"]) + y = mx.array(batch["label"]) + acc = eval_fn(model, x, y) + acc_value = acc.item() + accs.append(acc_value) + mean_acc = mx.mean(mx.array(accs)) + return mean_acc + + +def main(args): + mx.random.seed(args.seed) + + model = getattr(resnet, args.arch)() + + print("Number of params: {:0.04f} M".format(model.num_params() / 1e6)) + + optimizer = optim.Adam(learning_rate=args.lr) + + train_data, test_data = get_cifar10(args.batch_size) + for epoch in range(args.epochs): + tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch) + print( + " | ".join( + ( + f"Epoch: {epoch}", + f"avg. Train loss {tr_loss.item():.3f}", + f"avg. Train acc {tr_acc.item():.3f}", + f"Throughput: {throughput.item():.2f} images/sec", + ) + ) + ) + + test_acc = test_epoch(model, test_data, epoch) + print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}") + + train_data.reset() + test_data.reset() + + +if __name__ == "__main__": + args = parser.parse_args() + if args.cpu: + mx.set_default_device(mx.cpu) + main(args) diff --git a/cifar/requirements.txt b/cifar/requirements.txt new file mode 100644 index 00000000..6ff78a64 --- /dev/null +++ b/cifar/requirements.txt @@ -0,0 +1,2 @@ +mlx +mlx-data \ No newline at end of file diff --git a/cifar/resnet.py b/cifar/resnet.py new file mode 100644 index 00000000..758ee3de --- /dev/null +++ b/cifar/resnet.py @@ -0,0 +1,131 @@ +""" +Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385]. +Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202. + +There's no BatchNorm is mlx==0.0.4, using LayerNorm instead. +""" + +from typing import Any +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten + + +__all__ = [ + "ResNet", + "resnet20", + "resnet32", + "resnet44", + "resnet56", + "resnet110", + "resnet1202", +] + + +class ShortcutA(nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def __call__(self, x): + return mx.pad( + x[:, ::2, ::2, :], + pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)], + ) + + +class Block(nn.Module): + """ + Implements a ResNet block with two convolutional layers and a skip connection. + As per the paper, CIFAR-10 uses Shortcut type-A skip connections. (See paper for details) + """ + + def __init__(self, in_dims, dims, stride=1): + super().__init__() + + self.conv1 = nn.Conv2d( + in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn1 = nn.LayerNorm(dims) + + self.conv2 = nn.Conv2d( + dims, dims, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.LayerNorm(dims) + + if stride != 1: + self.shortcut = ShortcutA(dims) + else: + self.shortcut = None + + def __call__(self, x): + out = nn.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + if self.shortcut is None: + out += x + else: + out += self.shortcut(x) + out = nn.relu(out) + return out + + +class ResNet(nn.Module): + """ + Creates a ResNet model for CIFAR-10, as specified in the original paper. + """ + + def __init__(self, block, num_blocks, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.LayerNorm(16) + + self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2) + + self.linear = nn.Linear(64, num_classes) + + def _make_layer(self, block, in_dims, dims, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(in_dims, dims, stride)) + in_dims = dims + return nn.Sequential(*layers) + + def num_params(self): + nparams = sum(x.size for k, x in tree_flatten(self.parameters())) + return nparams + + def __call__(self, x): + x = nn.relu(self.bn1(self.conv1(x))) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1) + x = self.linear(x) + return x + + +def resnet20(**kwargs): + return ResNet(Block, [3, 3, 3], **kwargs) + + +def resnet32(**kwargs): + return ResNet(Block, [5, 5, 5], **kwargs) + + +def resnet44(**kwargs): + return ResNet(Block, [7, 7, 7], **kwargs) + + +def resnet56(**kwargs): + return ResNet(Block, [9, 9, 9], **kwargs) + + +def resnet110(**kwargs): + return ResNet(Block, [18, 18, 18], **kwargs) + + +def resnet1202(**kwargs): + return ResNet(Block, [200, 200, 200], **kwargs) diff --git a/llama/llama.py b/llama/llama.py index db9c8db3..73eb39c5 100644 --- a/llama/llama.py +++ b/llama/llama.py @@ -315,7 +315,7 @@ def load_model(model_path): config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] if config.get("vocab_size", -1) < 0: config["vocab_size"] = weights["output.weight"].shape[-1] - unused = ["multiple_of", "ffn_dim_multiplie"] + unused = ["multiple_of", "ffn_dim_multiplier", "rope_theta"] for k in unused: if k in config: config.pop(k) diff --git a/lora/README.md b/lora/README.md index 09911bd3..a6819950 100644 --- a/lora/README.md +++ b/lora/README.md @@ -24,7 +24,7 @@ tar -xf mistral-7B-v0.1.tar ``` If you do not have access to the Llama weights you will need to [request -access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) +access](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) from Meta. Convert the model with: diff --git a/mixtral/README.md b/mixtral/README.md index 417759e1..9194979e 100644 --- a/mixtral/README.md +++ b/mixtral/README.md @@ -2,6 +2,8 @@ Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon. +This example also supports the instruction fine-tuned Mixtral model.[^instruct] + Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run. ### Setup @@ -16,37 +18,56 @@ brew install git-lfs Download the models from Hugging Face: +For the base model use: + ``` -git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen +export MIXTRAL_MODEL=Mixtral-8x7B-v0.1 ``` -After that's done, combine the files: +For the instruction fine-tuned model use: + ``` -cd mixtral-8x7b-32kseqlen/ -cat consolidated.00.pth-split0 consolidated.00.pth-split1 consolidated.00.pth-split2 consolidated.00.pth-split3 consolidated.00.pth-split4 consolidated.00.pth-split5 consolidated.00.pth-split6 consolidated.00.pth-split7 consolidated.00.pth-split8 consolidated.00.pth-split9 consolidated.00.pth-split10 > consolidated.00.pth +export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1 +``` + +Then run: + +``` +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/ +cd $MIXTRAL_MODEL/ && \ + git lfs pull --include "consolidated.*.pt" && \ + git lfs pull --include "tokenizer.model" ``` Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so MLX can read them: ``` -python convert.py --model_path mixtral-8x7b-32kseqlen/ +python convert.py --model_path $MIXTRAL_MODEL/ ``` The conversion script will save the converted weights in the same location. -After that's done, if you want to clean some stuff up: - -``` -rm mixtral-8x7b-32kseqlen/*.pth* -``` - ### Generate As easy as: ``` -python mixtral.py --model_path mixtral-8x7b-32kseqlen/ +python mixtral.py --model_path $MIXTRAL_MODEL/ ``` -[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details. +For more options including how to prompt the model, run: + +``` +python mixtral.py --help +``` + +For the Instruction model, make sure to follow the prompt format: + +``` +[INST] Instruction prompt [/INST] +``` + +[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details. +[^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more +details diff --git a/mixtral/convert.py b/mixtral/convert.py index e67f4453..d6ba8030 100644 --- a/mixtral/convert.py +++ b/mixtral/convert.py @@ -1,23 +1,55 @@ # Copyright © 2023 Apple Inc. import argparse +import glob +import json import numpy as np from pathlib import Path import torch +def convert(k, v, config): + v = v.to(torch.float16).numpy() + if "block_sparse_moe" not in k: + return [(k, v)] + if "gate" in k: + return [(k.replace("block_sparse_moe", "feed_forward"), v)] + + # From: layers.N.block_sparse_moe.w + # To: layers.N.experts.M.w + num_experts = args["moe"]["num_experts"] + key_path = k.split(".") + v = np.split(v, num_experts, axis=0) + if key_path[-1] == "w2": + v = [u.T for u in v] + + w_name = key_path.pop() + key_path[-1] = "feed_forward.experts" + return [ + (".".join(key_path + [str(e), w_name, "weight"]), u) for e, u in enumerate(v) + ] + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.") parser.add_argument( "--model_path", type=str, - default="mixtral-8x7b-32kseqlen/", + default="Mixtral-8x7B-v0.1/", help="The path to the Mixtral model. The MLX model weights will also be saved there.", ) args = parser.parse_args() model_path = Path(args.model_path) - state = torch.load(str(model_path / "consolidated.00.pth")) - np.savez( - str(model_path / "weights.npz"), - **{k: v.to(torch.float16).numpy() for k, v in state.items()}, - ) + + with open("params.json") as fid: + args = json.load(fid) + + torch_files = glob.glob(str(model_path / "consolidated.*.pt")) + torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2])) + for e, tf in enumerate(torch_files): + print(f"[INFO] Converting file {e + 1}/{len(torch_files)}") + state = torch.load(tf) + new_state = {} + for k, v in state.items(): + new_state.update(convert(k, v, args)) + np.savez(str(model_path / f"weights.{e}.npz"), **new_state) diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py index 1a9be600..16a9eec8 100644 --- a/mixtral/mixtral.py +++ b/mixtral/mixtral.py @@ -2,6 +2,7 @@ import argparse from dataclasses import dataclass +import glob import json import numpy as np from pathlib import Path @@ -40,6 +41,26 @@ class RMSNorm(nn.Module): return self.weight * output +class RoPE(nn.RoPE): + def __init__(self, dims: int, traditional: bool = False): + super().__init__(dims, traditional) + + def __call__(self, x, offset: int = 0): + shape = x.shape + x = mx.reshape(x, (-1, shape[-2], shape[-1])) + N = x.shape[1] + offset + costheta, sintheta = RoPE.create_cos_sin_theta( + N, self.dims, offset=offset, base=1000000, dtype=x.dtype + ) + + rope = ( + self._compute_traditional_rope if self.traditional else self._compute_rope + ) + rx = rope(costheta, sintheta, x) + + return mx.reshape(rx, shape) + + class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -56,7 +77,7 @@ class Attention(nn.Module): self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) - self.rope = nn.RoPE(args.head_dim, traditional=True) + self.rope = RoPE(args.head_dim, traditional=True) def __call__( self, @@ -125,7 +146,10 @@ class MOEFeedForward(nn.Module): gates = self.gate(x) inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] - scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1) + scores = mx.softmax( + mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), + axis=-1, + ).astype(gates.dtype) y = [] for xt, st, it in zip(x, scores, inds.tolist()): @@ -181,8 +205,9 @@ class Mixtral(nn.Module): h = self.tok_embeddings(inputs) mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + T = h.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) mask = mask.astype(h.dtype) if cache is None: @@ -191,7 +216,7 @@ class Mixtral(nn.Module): for e, layer in enumerate(self.layers): h, cache[e] = layer(h, mask, cache[e]) - return self.output(self.norm(h)), cache + return self.output(self.norm(h[:, T - 1 : T, :])), cache class Tokenizer: @@ -222,10 +247,13 @@ class Tokenizer: def load_model(folder: str, dtype=mx.float16): model_path = Path(folder) tokenizer = Tokenizer(str(model_path / "tokenizer.model")) - with open(model_path / "params.json", "r") as f: + with open("params.json", "r") as f: config = json.loads(f.read()) model_args = ModelArgs(**config) - weights = mx.load(str(model_path / "weights.npz")) + weight_files = glob.glob(str(model_path / "weights.*.npz")) + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) weights = tree_unflatten(list(weights.items())) weights = tree_map(lambda p: p.astype(dtype), weights) model = Mixtral(model_args) @@ -255,7 +283,7 @@ if __name__ == "__main__": parser.add_argument( "--model_path", type=str, - default="mixtral-8x7b-32kseqlen", + default="Mixtral-8x7B-v0.1", help="The path to the model weights, tokenizer, and config", ) parser.add_argument( @@ -274,7 +302,7 @@ if __name__ == "__main__": "--temp", help="The sampling temperature.", type=float, - default=1.0, + default=0.0, ) parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") diff --git a/mixtral/params.json b/mixtral/params.json new file mode 100644 index 00000000..f1016aa8 --- /dev/null +++ b/mixtral/params.json @@ -0,0 +1 @@ +{"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}} diff --git a/phi2/.gitignore b/phi2/.gitignore new file mode 100644 index 00000000..258ec872 --- /dev/null +++ b/phi2/.gitignore @@ -0,0 +1 @@ +weights.npz diff --git a/phi2/README.md b/phi2/README.md new file mode 100644 index 00000000..f5d80696 --- /dev/null +++ b/phi2/README.md @@ -0,0 +1,57 @@ +# Phi-2 + +Phi-2 is a 2.7B parameter language model released by Microsoft with +performance that rivals much larger models.[^1] It was trained on a mixture of +GPT-4 outputs and clean web text. + +Phi-2 efficiently runs on Apple silicon devices with 8GB of memory in 16-bit +precision. + +## Setup + +Download and convert the model: + +```sh +python convert.py +``` + +This will make the `weights.npz` file which MLX can read. + +## Generate + +To generate text with the default prompt: + +```sh +python phi2.py +``` + +Should give the output: + +``` +Answer: Mathematics is like a lighthouse that guides us through the darkness of +uncertainty. Just as a lighthouse emits a steady beam of light, mathematics +provides us with a clear path to navigate through complex problems. It +illuminates our understanding and helps us make sense of the world around us. + +Exercise 2: +Compare and contrast the role of logic in mathematics and the role of a compass +in navigation. + +Answer: Logic in mathematics is like a compass in navigation. It helps +``` + +To use your own prompt: + +```sh +python phi2.py --prompt --max_tokens +``` + +To see a list of options run: + +```sh +python phi2.py --help +``` + +[^1]: For more details on the model see the [blog post]( +https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) +and the [Hugging Face repo](https://huggingface.co/microsoft/phi-2) diff --git a/phi2/convert.py b/phi2/convert.py new file mode 100644 index 00000000..5aa07dce --- /dev/null +++ b/phi2/convert.py @@ -0,0 +1,24 @@ +from transformers import AutoModelForCausalLM +import numpy as np + + +def replace_key(key: str) -> str: + if "wte.weight" in key: + key = "wte.weight" + + if ".mlp" in key: + key = key.replace(".mlp", "") + return key + + +def convert(): + model = AutoModelForCausalLM.from_pretrained( + "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True + ) + state_dict = model.state_dict() + weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} + np.savez("weights.npz", **weights) + + +if __name__ == "__main__": + convert() diff --git a/phi2/phi2.py b/phi2/phi2.py new file mode 100644 index 00000000..4a9ed30e --- /dev/null +++ b/phi2/phi2.py @@ -0,0 +1,222 @@ +import argparse +from typing import Optional +from dataclasses import dataclass +from mlx.utils import tree_unflatten +from transformers import AutoTokenizer + +import mlx.core as mx +import mlx.nn as nn +import math + + +@dataclass +class ModelArgs: + max_sequence_length: int = 2048 + num_vocab: int = 51200 + model_dim: int = 2560 + num_heads: int = 32 + num_layers: int = 32 + rotary_dim: int = 32 + + +class LayerNorm(nn.LayerNorm): + def __call__(self, x: mx.array) -> mx.array: + return super().__call__(x.astype(mx.float32)).astype(x.dtype) + + +class RoPEAttention(nn.Module): + def __init__(self, dims: int, num_heads: int, rotary_dim: int): + super().__init__() + + self.num_heads = num_heads + + self.rope = nn.RoPE(rotary_dim, traditional=False) + self.Wqkv = nn.Linear(dims, 3 * dims) + self.out_proj = nn.Linear(dims, dims) + + def __call__(self, x, mask=None, cache=None): + qkv = self.Wqkv(x) + queries, keys, values = mx.split(qkv, 3, axis=-1) + + # Extract some shapes + num_heads = self.num_heads + B, L, D = queries.shape + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + + # Add RoPE to the queries and keys and combine them with the cache + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + queries = queries.astype(mx.float32) + keys = keys.astype(mx.float32) + + # Finally perform the attention computation + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores = scores + mask + + scores = mx.softmax(scores, axis=-1).astype(values.dtype) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat), (keys, values) + + +class ParallelBlock(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + dims = config.model_dim + mlp_dims = dims * 4 + self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) + self.ln = LayerNorm(dims) + self.fc1 = nn.Linear(dims, mlp_dims) + self.fc2 = nn.Linear(mlp_dims, dims) + self.act = nn.GELU(approx="precise") + + def __call__(self, x, mask, cache): + h = self.ln(x) + attn_h, cache = self.mixer(h, mask, cache) + ff_h = self.fc2(self.act(self.fc1(h))) + return attn_h + ff_h + x, cache + + +class TransformerDecoder(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.h = [ParallelBlock(config) for i in range(config.num_layers)] + + def __call__(self, x, mask, cache): + if cache is None: + cache = [None] * len(self.h) + + for e, layer in enumerate(self.h): + x, cache[e] = layer(x, mask, cache[e]) + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: ModelArgs) -> None: + self.ln = LayerNorm(config.model_dim) + self.linear = nn.Linear(config.model_dim, config.num_vocab) + + def __call__(self, inputs): + return self.linear(self.ln(inputs)) + + +class Phi2(nn.Module): + def __init__(self, config: ModelArgs): + self.wte = nn.Embedding(config.num_vocab, config.model_dim) + self.transformer = TransformerDecoder(config) + self.lm_head = OutputHead(config) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache: mx.array = None, + ) -> tuple[mx.array, mx.array]: + x = self.wte(inputs) + + mask = None + if x.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(x.dtype) + + y, cache = self.transformer(x, mask, cache) + return self.lm_head(y), cache + + +def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + logits, cache = model(prompt) + y = sample(logits[:, -1, :]) + yield y + + while True: + logits, cache = model(y[:, None], cache=cache) + y = sample(logits.squeeze(1)) + yield y + + +def load_model(): + model = Phi2(ModelArgs()) + weights = mx.load("weights.npz") + model.update(tree_unflatten(list(weights.items()))) + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Phi-2 inference script") + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="Write a detailed analogy between mathematics and a lighthouse.", + ) + parser.add_argument( + "--max_tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model() + + prompt = tokenizer( + args.prompt, + return_tensors="np", + return_attention_mask=False, + )["input_ids"] + + prompt = mx.array(prompt) + + print("[INFO] Generating with Phi-2...", flush=True) + print(args.prompt, end="", flush=True) + + tokens = [] + for token, _ in zip(generate(prompt, model), range(args.max_tokens)): + tokens.append(token) + + if (len(tokens) % 10) == 0: + mx.eval(tokens) + eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None) + + if eos_index is not None: + tokens = tokens[:eos_index] + + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] + if eos_index is not None: + break + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True) diff --git a/phi2/requirements.txt b/phi2/requirements.txt new file mode 100644 index 00000000..2251ee12 --- /dev/null +++ b/phi2/requirements.txt @@ -0,0 +1,5 @@ +einops +mlx +numpy +transformers +torch diff --git a/stable_diffusion/README.md b/stable_diffusion/README.md index 400a50f7..5e44cb1a 100644 --- a/stable_diffusion/README.md +++ b/stable_diffusion/README.md @@ -27,7 +27,7 @@ Usage ------ Although each component in this repository can be used by itself, the fastest -way to get started is by using the `StableDiffusion` class from the `diffusion` +way to get started is by using the `StableDiffusion` class from the `stable_diffusion` module. ```python diff --git a/whisper/README.md b/whisper/README.md index 602941a0..7df1382f 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -1,7 +1,7 @@ -# whisper +# Whisper Speech recognition with Whisper in MLX. Whisper is a set of open source speech -recognition models from Open AI, ranging from 39 million to 1.5 billion +recognition models from OpenAI, ranging from 39 million to 1.5 billion parameters[^1]. ### Setup @@ -15,7 +15,7 @@ pip install -r requirements.txt Install [`ffmpeg`](https://ffmpeg.org/): ``` -# on MacOS using Homebrew (https://brew.sh/) +# on macOS using Homebrew (https://brew.sh/) brew install ffmpeg ``` diff --git a/whisper/test.py b/whisper/test.py index 79f233ba..3e7630a9 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -65,7 +65,6 @@ class TestWhisper(unittest.TestCase): logits = mlx_model(mels, tokens) self.assertEqual(logits.dtype, mx.float16) - def test_decode_lang(self): options = decoding.DecodingOptions(task="lang_id", fp16=False) result = decoding.decode(self.model, self.mels, options) diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index 7c7c4a93..d5025444 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -112,7 +112,7 @@ class DecodingOptions: max_initial_timestamp: Optional[float] = 1.0 # implementation details - fp16: bool = True # use fp16 for most of the calculation + fp16: bool = True # use fp16 for most of the calculation @dataclass(frozen=True) diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index 58cef9ac..ffdccf44 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -44,7 +44,7 @@ _ALIGNMENT_HEADS = { "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", - "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00" + "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", } @@ -166,7 +166,8 @@ def convert(model, rules=None): def torch_to_mlx( - torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16, + torch_model: torch_whisper.Whisper, + dtype: mx.Dtype = mx.float16, ) -> whisper.Whisper: def convert_rblock(model, rules): children = dict(model.named_children()) @@ -194,6 +195,6 @@ def torch_to_mlx( def load_model( name: str, download_root: str = None, - dtype : mx.Dtype = mx.float32, + dtype: mx.Dtype = mx.float32, ) -> whisper.Whisper: return torch_to_mlx(load_torch_model(name, download_root), dtype) diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index 3172bdb3..06f3c9ea 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -43,7 +43,7 @@ class ModelHolder: model_name = None @classmethod - def get_model(cls, model: str, dtype : mx.Dtype): + def get_model(cls, model: str, dtype: mx.Dtype): if cls.model is None or model != cls.model_name: cls.model = load_model(model, dtype=dtype) cls.model_name = model diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index 62e43de3..8ee6d7d9 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -37,6 +37,7 @@ def sinusoids(length, channels, max_timescale=10000): scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :] return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1) + class LayerNorm(nn.LayerNorm): def __call__(self, x: mx.array) -> mx.array: return super().__call__(x.astype(mx.float32)).astype(x.dtype) @@ -117,13 +118,19 @@ class ResidualAttentionBlock(nn.Module): if self.cross_attn: y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) x += y - x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) + x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype)) return x, (kv, cross_kv) class AudioEncoder(nn.Module): def __init__( - self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, + self, + n_mels: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, ): super().__init__() self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) @@ -134,8 +141,8 @@ class AudioEncoder(nn.Module): self.ln_post = LayerNorm(n_state) def __call__(self, x): - x = nn.gelu(self.conv1(x)) - x = nn.gelu(self.conv2(x)) + x = nn.gelu(self.conv1(x)).astype(x.dtype) + x = nn.gelu(self.conv2(x)).astype(x.dtype) assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" x = x + self._positional_embedding @@ -148,7 +155,13 @@ class AudioEncoder(nn.Module): class TextDecoder(nn.Module): def __init__( - self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16, + self, + n_vocab: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, ): super().__init__() @@ -160,7 +173,9 @@ class TextDecoder(nn.Module): for _ in range(n_layer) ] self.ln = LayerNorm(n_state) - self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype) + self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype( + dtype + ) def __call__(self, x, xa, kv_cache=None): """