diff --git a/lora/README.md b/lora/README.md
index dffdd997..69ca3899 100644
--- a/lora/README.md
+++ b/lora/README.md
@@ -1,7 +1,8 @@
 # LoRA
 
-This is an example of using MLX to fine-tune a Llama 7B[^llama] model with low
-rank adaptation (LoRA)[^lora] for a target task.
+This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
+Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
+task.
 
 In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
 generate SQL queries from natural language. However, the example is intended to
@@ -15,19 +16,27 @@ Install the dependencies:
 pip install -r requirements.txt
 ```
 
-Next, download and convert the model. If you do not have access to the model
-weights you will need to [request
+Next, download and convert the model. The Mistral weights can be downloaded with:
+
+```
+curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
+tar -xf mistral-7B-v0.1.tar
+```
+
+If you do not have access to the Llama weights you will need to [request
 access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
 from Meta.
 
 Convert the weights with:
 
 ```
-python convert.py mlx_llama_7B.npz
+python convert.py
 ```
 
 ## Run
 
+#### Fine-tune
+
 The main script is `lora.py`. To see a full list of options run
 
 ```
@@ -37,28 +46,34 @@ python lora.py --help
 To fine-tune a model use:
 
 ```
-python lora.py --model mlx_llama_7B.npz \
-               --tokenizer tokenizer.model \
+python lora.py --model <path_to_model> \
                --train \
-               --iters 600 \
+               --iters 600
 ```
 
+Note: the model path should contain the MLX weights, the tokenizer, and the
+`params.json` configuration, all of which are produced by the `convert.py` script.
+
 By default, the adapter weights are saved in `adapters.npz`. You can specify
 the output location with `--adapter_file`.
 
+#### Evaluate
+
 To compute test set perplexity use
 
 ```
-python lora.py --model mlx_llama_7B.npz \
-               --tokenizer tokenizer.model \
+python lora.py --model <path_to_model> \
+               --adapter_file <path_to_adapters.npz> \
                --test
 ```
 
+#### Generate
+
 For generation use
 
 ```
-python lora.py --model mlx_llama_7B.npz \
-               --tokenizer tokenizer.model \
+python lora.py --model <path_to_model> \
+               --adapter_file <path_to_adapters.npz> \
                --num-tokens 50 \
                --prompt "table: 1-10015132-16
 columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
@@ -87,4 +102,5 @@ The model trains at around 475 tokens per second on an M2 Ultra.
 
 [^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
 [^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
+[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [GitHub repository](https://github.com/mistralai/mistral-src) for more details.
 [^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
diff --git a/lora/convert.py b/lora/convert.py
index 2ce247a3..2903aae8 100644
--- a/lora/convert.py
+++ b/lora/convert.py
@@ -1,53 +1,61 @@
 # Copyright © 2023 Apple Inc.
 import argparse
-from itertools import starmap
-
+import json
 import numpy as np
+from pathlib import Path
+import shutil
+import os
 import torch
 
 
-def map_torch_to_mlx(key, value):
-    if "tok_embedding" in key:
-        key = "embedding.weight"
-
-    elif "norm" in key:
-        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
-
-    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
-        key = key.replace("wq", "query_proj")
-        key = key.replace("wk", "key_proj")
-        key = key.replace("wv", "value_proj")
-        key = key.replace("wo", "out_proj")
-
-    elif "w1" in key or "w2" in key or "w3" in key:
-        # The FFN is a separate submodule in PyTorch
-        key = key.replace("feed_forward.w1", "linear1")
-        key = key.replace("feed_forward.w3", "linear2")
-        key = key.replace("feed_forward.w2", "linear3")
-
-    elif "output" in key:
-        key = key.replace("output", "out_proj")
-
-    elif "rope" in key:
-        return None, None
-
-    return (
-        key,
-        value.numpy()
-        if value.dtype != torch.bfloat16
-        else value.to(torch.float32).numpy(),
-    )
-
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
-    parser.add_argument("torch_weights")
-    parser.add_argument("output_file")
+    parser = argparse.ArgumentParser(
+        description="Convert Mistral or Llama models to MLX.",
+    )
+    parser.add_argument(
+        "--torch_model",
+        type=str,
+        default="mistral-7B-v0.1/",
+        help="The torch model directory",
+    )
+    parser.add_argument(
+        "--mlx_model",
+        type=str,
+        default="mlx-mistral-7B-v0.1/",
+        help="The directory to store the mlx model",
+    )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_weights)
+    torch_path = Path(args.torch_model)
+    if not os.path.exists(args.mlx_model):
+        os.makedirs(args.mlx_model)
+    mlx_path = Path(args.mlx_model)
+
+    state = torch.load(str(torch_path / "consolidated.00.pth"))
     np.savez(
-        args.output_file,
-        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
+        str(mlx_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
+
+    # Copy the tokenizer
+    shutil.copyfile(
+        str(torch_path / "tokenizer.model"),
+        str(mlx_path / "tokenizer.model"),
+    )
+
+    # Copy the params
+    with open(torch_path / "params.json", "r") as f:
+        config = json.loads(f.read())
+    if "sliding_window" in config:
+        config.pop("sliding_window")
+    if "n_kv_heads" not in config:
+        config["n_kv_heads"] = config["n_heads"]
+    if "head_dim" not in config:
+        config["head_dim"] = config["dim"] // config["n_heads"]
+    if "hidden_dim" not in config:
+        config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
+    with open(mlx_path / "params.json", "w") as outfile:
+        json.dump(config, outfile)
+
+
diff --git a/lora/llama.py b/lora/llama.py
deleted file mode 100644
index aa59f919..00000000
--- a/lora/llama.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# Copyright © 2023 Apple Inc.
- -import math - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten - - -class LoRALinear(nn.Module): - @staticmethod - def from_linear(linear: nn.Linear, rank: int = 8): - input_dims, output_dims = linear.weight.shape - lora_lin = LoRALinear(input_dims, output_dims, rank) - lora_lin.linear = linear - return lora_lin - - def __init__( - self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False - ): - super().__init__() - - # Regular linear layer weights - self.linear = nn.Linear(input_dims, output_dims, bias=bias) - - # Low rank lora weights - scale = 1 / math.sqrt(input_dims) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(input_dims, lora_rank), - ) - self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) - - def __call__(self, x): - y = self.linear(x) - z = (x @ self.lora_a) @ self.lora_b - return y + 2.0 * z - - -class LlamaAttention(nn.Module): - def __init__(self, dims: int, num_heads: int): - super().__init__() - - self.num_heads = num_heads - - self.rope = nn.RoPE(dims // num_heads, traditional=True) - - self.query_proj = nn.Linear(dims, dims, bias=False) - self.key_proj = nn.Linear(dims, dims, bias=False) - self.value_proj = nn.Linear(dims, dims, bias=False) - self.out_proj = nn.Linear(dims, dims, bias=False) - - def __call__(self, queries, keys, values, mask=None, cache=None): - queries = self.query_proj(queries) - keys = self.key_proj(keys) - values = self.value_proj(values) - - # Extract some shapes - num_heads = self.num_heads - B, L, D = queries.shape - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores = scores + mask - scores = mx.softmax(scores, axis=-1) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - # Note that we return the keys and values to possibly be used as a cache - return self.out_proj(values_hat), (keys, values) - - -class LlamaEncoderLayer(nn.Module): - def __init__(self, dims: int, mlp_dims: int, num_heads: int): - super().__init__() - - self.attention = LlamaAttention(dims, num_heads) - - self.norm1 = nn.RMSNorm(dims) - self.norm2 = nn.RMSNorm(dims) - - self.linear1 = nn.Linear(dims, mlp_dims, bias=False) - self.linear2 = nn.Linear(dims, mlp_dims, bias=False) - self.linear3 = nn.Linear(mlp_dims, dims, bias=False) - - def __call__(self, x, mask=None, cache=None): - y = self.norm1(x) - y, cache = self.attention(y, y, y, mask, cache) - x = x + y - - y = self.norm2(x) - a = self.linear1(y) - b = self.linear2(y) - y = a * mx.sigmoid(a) * b - y = self.linear3(y) - x = x + y - - return x, cache - - -class Llama(nn.Module): - def __init__( - self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int - ): - super().__init__() - - 
self.embedding = nn.Embedding(vocab_size, dims) - self.layers = [ - LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers) - ] - self.norm = nn.RMSNorm(dims) - self.out_proj = nn.Linear(dims, vocab_size, bias=False) - - def __call__(self, x): - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - - x = self.embedding(x) - - for l in self.layers: - x, _ = l(x, mask) - - x = self.norm(x) - - return self.out_proj(x) - - def generate(self, x, temp=1.0): - cache = [] - try: - # Make an additive causal mask. We will need that to process the prompt. - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(self.embedding.weight.dtype) - - # First we process the prompt x the same was as in __call__ but - # save the caches in cache - x = self.embedding(x) - for l in self.layers: - x, c = l(x, mask=mask) - # We store the per layer cache in a simple python list - cache.append(c) - x = self.norm(x) - # We only care about the last logits that generate the next token - y = self.out_proj(x[:, -1]) - y = mx.random.categorical(y * (1 / temp)) - - # y now has size [1] - yield y - - # Now we parsed the prompt and generated the first token we - # need to feed it back into the model and loop to generate the - # rest. - while True: - # Unsqueezing the last dimension to add a sequence length - # dimension of 1 - x = y[:, None] - - x = self.embedding(x) - for i in range(len(cache)): - # We are overwriting the arrays in the cache list. When - # the computation will happen, MLX will be discarding the - # old cache the moment it is not needed anymore. - x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) - x = self.norm(x) - y = self.out_proj(x[:, -1]) - y = mx.random.categorical(y * (1 / temp)) - - yield y - - finally: - del cache - - -def load_model(model_path): - weights = mx.load(model_path) - mlp_dims, dims = weights["layers.0.linear1.weight"].shape - num_heads = dims // 128 - num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1 - vocab_size = weights["out_proj.weight"].shape[-1] - model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads) - model.update(tree_unflatten(list(weights.items()))) - mx.eval(model.parameters()) - return model diff --git a/lora/lora.py b/lora/lora.py index f0afc261..a7dcdb30 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -1,28 +1,28 @@ # Copyright © 2023 Apple Inc. import argparse +import json import math import numpy as np +from pathlib import Path from sentencepiece import SentencePieceProcessor import time +from typing import Optional, Tuple, List import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim -from mlx.utils import tree_flatten +from mlx.utils import tree_map, tree_flatten, tree_unflatten -from llama import LoRALinear, load_model +from models import ModelArgs, Model, LoRALinear import wikisql def build_parser(): - parser = argparse.ArgumentParser(description="Llama LoRA finetuning") + parser = argparse.ArgumentParser(description="LoRA finetuning with Llama or Mistral") parser.add_argument( - "--model", required=True, help="The model file containing MLX weights" - ) - parser.add_argument( - "--tokenizer", required=True, help="The sentencepiece tokenizer" + "--model", required=True, help="A path to the model files containing the tokenizer, weights, config." 
) # Generation args parser.add_argument( @@ -73,6 +73,12 @@ def build_parser(): default=200, help="Number of training steps between validations.", ) + parser.add_argument( + "--resume_adapter_file", + type=str, + default=None, + help="Load path to resume training with the given adapter weights.", + ) parser.add_argument( "--adapter_file", type=str, @@ -94,9 +100,30 @@ def build_parser(): return parser +class Tokenizer: + def __init__(self, model_path: str): + assert Path(model_path).exists(), model_path + self._model = SentencePieceProcessor(model_file=model_path) + self._sep = "▁" + assert self._model.vocab_size() == self._model.get_piece_size() + + def encode(self, s: str) -> List[int]: + return [self._model.bos_id(), *self._model.encode(s)] + + def decode(self, t: List[int]) -> str: + out = self._model.decode(t) + if t and self._model.id_to_piece(t[0])[0] == self._sep: + return " " + out + return out + + @property + def vocab_size(self) -> int: + return self._model.vocab_size() + + def loss(model, inputs, targets, lengths): # Run model on inputs - logits = model(inputs) + logits, _ = model(inputs) # Mask padding tokens length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] @@ -117,7 +144,7 @@ def iterate_batches(dset, tokenizer, batch_size, shuffle=True): # Collect batches from dataset for i in range(0, len(indices) - batch_size + 1, batch_size): # Encode batch - batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)]) + batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)] lengths = [len(x) for x in batch] # Pad to the max length @@ -195,40 +222,55 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args): def generate(model, prompt, tokenizer, args): - # Encode prompt - x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)]) + print(args.prompt, end="", flush=True) + prompt = mx.array(tokenizer.encode(args.prompt)) + + def generate_step(): + temp = args.temp + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + logits, cache = model(prompt[None]) + y = sample(logits[:, -1, :]) + yield y + + while True: + logits, cache = model(y[:, None], cache) + y = sample(logits.squeeze(1)) + yield y - skip = 0 - prompt_processing = None tokens = [] - - # Genertation loop - start = time.perf_counter() - for token in model.generate(x, args.temp): + for token, _ in zip(generate_step(), range(args.num_tokens)): tokens.append(token) - if len(tokens) == 1: - # Actually perform the computation to measure the prompt processing time - mx.eval(token) - prompt_processing = time.perf_counter() - start - - if len(tokens) >= args.num_tokens: - break - - if (len(tokens) % args.write_every) == 0: + if (len(tokens) % 10) == 0: mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], end="", flush=True) - skip = len(s) + print(s, end="", flush=True) + tokens = [] mx.eval(tokens) - full_gen = time.perf_counter() - start - s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], end="", flush=True) - print() - print(f"Prompt processing took: {prompt_processing:.3f} s") - print(f"Full generation took: {full_gen:.3f} s") + print(s, flush=True) + + +def load_model(folder: str, dtype=mx.float32): + model_path = Path(folder) + tokenizer = Tokenizer(str(model_path / "tokenizer.model")) + with open(model_path / "params.json", "r") as f: + config = json.loads(f.read()) + model_args = ModelArgs(**config) + if 
config.get("vocab_size", -1) < 0:
+        model_args.vocab_size = tokenizer.vocab_size
+    weights = mx.load(str(model_path / "weights.npz"))
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: p.astype(dtype), weights)
+    model = Model(model_args)
+    model.update(weights)
+    return model, tokenizer
 
 
 if __name__ == "__main__":
@@ -237,17 +279,14 @@ if __name__ == "__main__":
 
     np.random.seed(args.seed)
 
-    print("Loading tokenizer")
-    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
-
     print("Loading pretrained model")
-    model = load_model(args.model)
+    model, tokenizer = load_model(args.model)
 
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.layers[16:32]:
-        l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
-        l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
+        l.attention.wq = LoRALinear.from_linear(l.attention.wq)
+        l.attention.wv = LoRALinear.from_linear(l.attention.wv)
 
     p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     print(f"Total parameters {p:.3f}M")
@@ -257,6 +296,11 @@ if __name__ == "__main__":
     print("Loading datasets")
     train_set, valid_set, test_set = wikisql.load()
 
+    # Resume training the given adapters.
+    if args.resume_adapter_file is not None:
+        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        model.load_weights(args.resume_adapter_file)
+
    if args.train:
         print("Training")
         opt = optim.Adam(learning_rate=args.learning_rate)
@@ -287,5 +331,5 @@ if __name__ == "__main__":
 
     if args.prompt is not None:
         print("Generating")
-        generate(model, args.prompt, tokenizer, args)
+        generate(model, args.prompt, tokenizer, args)
diff --git a/lora/models.py b/lora/models.py
new file mode 100644
index 00000000..52024531
--- /dev/null
+++ b/lora/models.py
@@ -0,0 +1,193 @@
+# Copyright © 2023 Apple Inc.
+ +from dataclasses import dataclass +import math +from typing import Optional, Tuple, List + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_map, tree_unflatten + + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + + +class LoRALinear(nn.Module): + @staticmethod + def from_linear(linear: nn.Linear, rank: int = 8): + output_dims, input_dims = linear.weight.shape + lora_lin = LoRALinear(input_dims, output_dims, rank) + lora_lin.linear = linear + return lora_lin + + def __init__( + self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, lora_rank), + ) + self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) + + def __call__(self, x): + y = self.linear(x) + z = (x @ self.lora_a) @ self.lora_b + return y + 2.0 * z + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + output = self._norm(x.astype(mx.float32)).astype(x.dtype) + return self.weight * output + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) + self.rope = nn.RoPE(args.head_dim, traditional=True) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads, L, -1]) + + if self.repeats > 1: + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.wo(output), (keys, values) + + +class FeedForward(nn.Module): + 
def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = FeedForward(args=args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.attention(self.attention_norm(x), mask, cache) + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + return out, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + assert self.vocab_size > 0 + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.tok_embeddings(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.output(self.norm(h)), cache
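
A quick sanity check for the `LoRALinear` layer introduced in `models.py`: because `lora_b` is initialized to zeros, wrapping an existing `nn.Linear` with `from_linear` leaves its output unchanged until the adapter weights are trained. The sketch below is illustrative only and not part of the diff; it assumes MLX is installed and that it is run from the `lora/` directory so `models.py` is importable, and the layer sizes are arbitrary.

```python
# Illustrative sketch: a LoRA-wrapped linear matches the base layer at init.
import mlx.core as mx
import mlx.nn as nn

from models import LoRALinear  # lora/models.py from this diff

base = nn.Linear(512, 512, bias=False)
lora = LoRALinear.from_linear(base, rank=8)  # reuses `base` as lora.linear

x = mx.random.normal((2, 8, 512))
y_base = base(x)
y_lora = lora(x)  # linear(x) + 2.0 * (x @ lora_a) @ lora_b, with lora_b == 0

# Maximum absolute difference should be ~0.0 before any adapter updates.
print(mx.abs(y_base - y_lora).max())
```

This zero initialization is what makes it safe for `lora.py` to swap `l.attention.wq` and `l.attention.wv` in place on the frozen model: fine-tuning starts from exactly the pretrained behavior, with only the low-rank `lora_a`/`lora_b` matrices left trainable.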