generalize lora finetuning for llama and mistral

Awni Hannun 2023-12-09 14:13:55 -08:00
parent 46c6bbe0a1
commit b8332a1e66
5 changed files with 354 additions and 293 deletions

lora/README.md

@@ -1,7 +1,8 @@
# LoRA
This is an example of using MLX to fine-tune a Llama 7B[^llama] model with low
rank adaptation (LoRA)[^lora] for a target task.
This is an example of using MLX to fine-tune either a Llama 7B[^llama] or a
Mistral 7B[^mistral] model with low rank adaptation (LoRA)[^lora] for a target
task.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
@@ -15,19 +16,27 @@ Install the dependencies:
pip install -r requirements.txt
```
Next, download and convert the model. If you do not have access to the model
weights you will need to [request
Next, download and convert the model. The Mistral weights can be downloaded with:
```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
tar -xf mistral-7B-v0.1.tar
```
If you do not have access to the Llama weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
from Meta.
Convert the weights with:
```
python convert.py <path_to_torch_weights> mlx_llama_7B.npz
python convert.py --torch_model <path_to_torch_model> --mlx_model <path_to_mlx_model>
```
## Run
#### Fine-tune
The main script is `lora.py`. To see a full list of options run
```
@@ -37,28 +46,34 @@ python lora.py --help
To fine-tune a model use:
```
python lora.py --model mlx_llama_7B.npz \
--tokenizer tokenizer.model \
python lora.py --model <path_to_model> \
--train \
--iters 600 \
--iters 600
```
Note: the model path should contain the MLX weights, the tokenizer, and the
`params.json` configuration, all of which are output by the `convert.py` script.
By default, the adapter weights are saved in `adapters.npz`. You can specify
the output location with `--adapter_file`.
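For reference, after running `convert.py` the directory you pass to `--model` should contain something like:
```
<path_to_model>/
├── weights.npz      # MLX model weights
├── tokenizer.model  # SentencePiece tokenizer copied from the torch model
└── params.json      # model configuration
```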
#### Evaluate
To compute test set perplexity use
```
python lora.py --model mlx_llama_7B.npz \
--tokenizer tokenizer.model \
python lora.py --model <path_to_model> \
--adapter_file <path_to_adapters.npz> \
--test
```
#### Generate
For generation use
```
python lora.py --model mlx_llama_7B.npz \
--tokenizer tokenizer.model \
python lora.py --model <path_to_model> \
--adapter_file <path_to_adapters.npz> \
--num-tokens 50 \
--prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
@@ -87,4 +102,5 @@ The model trains at around 475 tokens per second on an M2 Ultra.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^llama]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^mistral]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.

lora/convert.py

@@ -1,53 +1,61 @@
# Copyright © 2023 Apple Inc.
import argparse
from itertools import starmap
import json
import numpy as np
from pathlib import Path
import shutil
import os
import torch
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return (
key,
value.numpy()
if value.dtype != torch.bfloat16
else value.to(torch.float32).numpy(),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
parser = argparse.ArgumentParser(
description="Convert Mistral or Llama models to MLX.",
)
parser.add_argument(
"--torch_model",
type=str,
default="mistral-7B-v0.1/",
help="The torch model directory",
)
parser.add_argument(
"--mlx_model",
type=str,
default="mlx-mistral-7B-v0.1/",
help="The directory to store the mlx model",
)
args = parser.parse_args()
state = torch.load(args.torch_weights)
torch_path = Path(args.torch_model)
if not os.path.exists(args.mlx_model):
os.makedirs(args.mlx_model)
mlx_path = Path(args.mlx_model)
state = torch.load(str(torch_path / "consolidated.00.pth"))
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
str(mlx_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
)
# Copy the tokenizer
shutil.copyfile(
str(torch_path / "tokenizer.model"),
str(mlx_path / "tokenizer.model"),
)
# Copy the params
with open(torch_path / "params.json", "r") as f:
config = json.loads(f.read())
if "sliding_window" in config:
config.pop("sliding_window")
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
with open(mlx_path / "params.json", "w") as outfile:
json.dump(config, outfile)
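# Not part of convert.py: a minimal sanity-check sketch, assuming the default
# --mlx_model directory, showing that the converted archive loads back in MLX.
import mlx.core as mx
converted = mx.load("mlx-mistral-7B-v0.1/weights.npz")
print(len(converted), "arrays converted")
print(converted["tok_embeddings.weight"].shape)  # token embedding: (vocab_size, dim)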

lora/llama.py (deleted)

@@ -1,199 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
input_dims, output_dims = linear.weight.shape
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
def __init__(
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, lora_rank),
)
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x = self.embedding(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
def generate(self, x, temp=1.0):
cache = []
try:
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
# First we process the prompt x the same way as in __call__ but
# save the caches in cache
x = self.embedding(x)
for l in self.layers:
x, c = l(x, mask=mask)
# We store the per layer cache in a simple python list
cache.append(c)
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
# y now has size [1]
yield y
# Now that the prompt is processed and the first token generated, we need
# to feed it back into the model and loop to generate the rest.
while True:
# Unsqueezing the last dimension to add a sequence length
# dimension of 1
x = y[:, None]
x = self.embedding(x)
for i in range(len(cache)):
# We are overwriting the arrays in the cache list. When the computation
# happens, MLX discards the old cache as soon as it is no longer needed.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
yield y
finally:
del cache
def load_model(model_path):
weights = mx.load(model_path)
mlp_dims, dims = weights["layers.0.linear1.weight"].shape
num_heads = dims // 128
num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
vocab_size = weights["out_proj.weight"].shape[-1]
model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
model.update(tree_unflatten(list(weights.items())))
mx.eval(model.parameters())
return model

lora/lora.py

@@ -1,28 +1,28 @@
# Copyright © 2023 Apple Inc.
import argparse
import json
import math
import numpy as np
from pathlib import Path
from sentencepiece import SentencePieceProcessor
import time
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
from mlx.utils import tree_map, tree_flatten, tree_unflatten
from llama import LoRALinear, load_model
from models import ModelArgs, Model, LoRALinear
import wikisql
def build_parser():
parser = argparse.ArgumentParser(description="Llama LoRA finetuning")
parser = argparse.ArgumentParser(description="LoRA finetuning with Llama or Mistral")
parser.add_argument(
"--model", required=True, help="The model file containing MLX weights"
)
parser.add_argument(
"--tokenizer", required=True, help="The sentencepiece tokenizer"
"--model", required=True, help="A path to the model files containing the tokenizer, weights, config."
)
# Generation args
parser.add_argument(
@@ -73,6 +73,12 @@ def build_parser():
default=200,
help="Number of training steps between validations.",
)
parser.add_argument(
"--resume_adapter_file",
type=str,
default=None,
help="Load path to resume training with the given adapter weights.",
)
parser.add_argument(
"--adapter_file",
type=str,
@@ -94,9 +100,30 @@ def build_parser():
return parser
class Tokenizer:
def __init__(self, model_path: str):
assert Path(model_path).exists(), model_path
self._model = SentencePieceProcessor(model_file=model_path)
self._sep = "▁"  # SentencePiece marks word boundaries with this symbol
assert self._model.vocab_size() == self._model.get_piece_size()
def encode(self, s: str) -> List[int]:
return [self._model.bos_id(), *self._model.encode(s)]
def decode(self, t: List[int]) -> str:
out = self._model.decode(t)
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
@property
def vocab_size(self) -> int:
return self._model.vocab_size()
def loss(model, inputs, targets, lengths):
# Run model on inputs
logits = model(inputs)
logits, _ = model(inputs)
# Mask padding tokens
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
@@ -117,7 +144,7 @@ def iterate_batches(dset, tokenizer, batch_size, shuffle=True):
# Collect batches from dataset
for i in range(0, len(indices) - batch_size + 1, batch_size):
# Encode batch
batch = tokenizer.encode([dset[indices[i + j]] for j in range(batch_size)])
batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
lengths = [len(x) for x in batch]
# Pad to the max length
@@ -195,40 +222,55 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
def generate(model, prompt, tokenizer, args):
# Encode prompt
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(prompt)])
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
def generate_step():
temp = args.temp
def sample(logits):
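# temp == 0 amounts to greedy decoding; otherwise sample from the
# distribution defined by the logits scaled by 1 / temp.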
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt[None])
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache)
y = sample(logits.squeeze(1))
yield y
skip = 0
prompt_processing = None
tokens = []
# Generation loop
start = time.perf_counter()
for token in model.generate(x, args.temp):
for token, _ in zip(generate_step(), range(args.num_tokens)):
tokens.append(token)
if len(tokens) == 1:
# Actually perform the computation to measure the prompt processing time
mx.eval(token)
prompt_processing = time.perf_counter() - start
if len(tokens) >= args.num_tokens:
break
if (len(tokens) % args.write_every) == 0:
if (len(tokens) % 10) == 0:
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
skip = len(s)
print(s, end="", flush=True)
tokens = []
mx.eval(tokens)
full_gen = time.perf_counter() - start
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
print()
print(f"Prompt processing took: {prompt_processing:.3f} s")
print(f"Full generation took: {full_gen:.3f} s")
print(s, flush=True)
def load_model(folder: str, dtype=mx.float32):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "params.json", "r") as f:
config = json.loads(f.read())
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = tokenizer.vocab_size
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model = Model(model_args)
model.update(weights)
return model, tokenizer
if __name__ == "__main__":
@@ -237,17 +279,14 @@ if __name__ == "__main__":
np.random.seed(args.seed)
print("Loading tokenizer")
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
print("Loading pretrained model")
model = load_model(args.model)
model, tokenizer = load_model(args.model)
# Freeze all layers other than LORA linears
model.freeze()
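# Add LoRA adapters to the query and value projections of the last 16
# transformer blocks; only these low rank weights remain trainable.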
for l in model.layers[16:32]:
l.attention.query_proj = LoRALinear.from_linear(l.attention.query_proj)
l.attention.value_proj = LoRALinear.from_linear(l.attention.value_proj)
l.attention.wq = LoRALinear.from_linear(l.attention.wq)
l.attention.wv = LoRALinear.from_linear(l.attention.wv)
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
@@ -257,6 +296,11 @@ if __name__ == "__main__":
print("Loading datasets")
train_set, valid_set, test_set = wikisql.load()
# Resume training the given adapters.
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file)
if args.train:
print("Training")
opt = optim.Adam(learning_rate=args.learning_rate)
@@ -287,5 +331,4 @@ if __name__ == "__main__":
if args.prompt is not None:
print("Generating")
generate(model, args.prompt, tokenizer, args)

lora/models.py (new file)

@@ -0,0 +1,193 @@
# Copyright © 2023 Apple Inc.
from dataclasses import dataclass
import math
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
@dataclass
class ModelArgs:
dim: int
n_layers: int
head_dim: int
hidden_dim: int
n_heads: int
n_kv_heads: int
norm_eps: float
vocab_size: int
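# For illustration only (not part of the module): approximate values for
# Mistral-7B-v0.1 as released in its params.json; at runtime the config is
# read from <model_path>/params.json by load_model in lora.py.
example_args = ModelArgs(
    dim=4096,
    n_layers=32,
    head_dim=128,
    hidden_dim=14336,
    n_heads=32,
    n_kv_heads=8,
    norm_eps=1e-5,
    vocab_size=32000,
)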
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
output_dims, input_dims = linear.weight.shape
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
def __init__(
self, input_dims: int, output_dims: int, lora_rank: int = 8, bias: bool = False
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, lora_rank),
)
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z
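# Since lora_b starts at zero, a freshly wrapped layer reproduces the original
# linear layer exactly; after model.freeze() in lora.py only lora_a and lora_b
# receive gradient updates.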
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
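# Normalize in float32 for numerical stability, then cast back to the input dtype.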
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
self.repeats = self.n_heads // self.n_kv_heads
self.scale = self.args.head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
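# Grouped-query attention: when n_kv_heads < n_heads (as in Mistral 7B), tile
# the key/value heads so every query head has a matching key/value head.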
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
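# SwiGLU feed-forward: silu(w1(x)) gates w3(x), then w2 projects back to dim.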
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.attention = Attention(args)
self.feed_forward = FeedForward(args=args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
assert self.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = None
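# A causal mask is only needed when several tokens are processed at once (the
# prompt); single-token decoding steps rely on the key/value cache instead.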
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache