mlx-examples/transformer_lm/main.py
Awni Hannun f45a1ab83c
Update a few examples to use compile (#420)
* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout

* add a bit of prefetching
2024-02-08 13:00:41 -08:00

221 lines
6.9 KiB
Python

# Copyright © 2023-2024 Apple Inc.
import math
import time
from functools import partial
import datasets
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten
class TransformerLM(nn.Module):
    """Decoder-only Transformer language model.

    Token embeddings plus sinusoidal positional encodings are passed through a
    pre-norm ``nn.TransformerEncoder`` under an additive causal mask. A final
    linear layer projects the hidden states back to vocabulary logits.

    Args:
        vocab_size: Number of distinct tokens (embedding rows / output logits).
        num_layers: Number of encoder layers.
        dims: Model (embedding) dimensionality.
        num_heads: Number of attention heads per layer.
        checkpoint: Forwarded to ``nn.TransformerEncoder`` to enable gradient
            checkpointing.
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        dims: int,
        num_heads: int,
        checkpoint: bool,
    ):
        super().__init__()
        # Submodules are built in this exact order so that random parameter
        # initialization consumes the RNG in the same sequence.
        self.embedding = nn.Embedding(vocab_size, dims)
        self.pe = nn.SinusoidalPositionalEncoding(dims)
        self.transformer = nn.TransformerEncoder(
            num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint
        )
        self.out_proj = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        """Return next-token logits for the token ids ``x``.

        The second axis of ``x`` is treated as the sequence axis — presumably
        ``x`` is (batch, seq_len); confirm against callers.
        """
        seq_len = x.shape[1]
        positions = mx.arange(seq_len)
        causal_mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len)
        hidden = self.embedding(x) + self.pe(positions)
        hidden = self.transformer(hidden, causal_mask)
        return self.out_proj(hidden)
def to_samples(context_size, dataset):
    """Slice a token stream into overlapping (input, target) windows.

    Every run of ``context_size + 1`` consecutive tokens yields one sample.
    The target is the input shifted left by one token.

    Args:
        context_size: Number of tokens in each input sequence.
        dataset: Array of token ids. It is flattened first, so non-contiguous
            or multi-dimensional arrays are handled by value rather than by
            raw memory layout.

    Returns:
        ``(inputs, targets)``, each of shape
        ``(num_tokens - context_size, context_size)``. Both are read-only
        views over the flattened tokens (no copy when ``dataset`` is
        contiguous).

    Raises:
        ValueError: if ``dataset`` holds fewer than ``context_size + 1`` tokens.
    """
    # ravel is a view for contiguous input and a copy otherwise. A hand-built
    # as_strided with itemsize strides would read the wrong memory for
    # non-contiguous input.
    tokens = np.ravel(dataset)
    window_size = context_size + 1  # include target
    if tokens.size < window_size:
        raise ValueError(
            f"dataset has {tokens.size} tokens, need at least {window_size} "
            f"for context_size={context_size}"
        )
    windows = np.lib.stride_tricks.sliding_window_view(tokens, window_size)
    return windows[:, :-1], windows[:, 1:]
def iterate_batches(batch_size, context_size, dataset):
    """Yield shuffled ``(inputs, targets)`` minibatches forever.

    Samples come from ``to_samples(context_size, dataset)``. A fresh
    ``np.random.permutation`` of the sample indices is drawn at the start of
    every pass over the data. The final batch of a pass may contain fewer
    than ``batch_size`` samples.
    """
    inputs, targets = to_samples(context_size, dataset)
    num_samples = inputs.shape[0]
    cursor = 0
    while True:
        if cursor == 0:
            # Beginning of an epoch: reshuffle the sample order.
            order = np.random.permutation(num_samples)
        batch_ids = order[cursor : cursor + batch_size]
        yield inputs[batch_ids], targets[batch_ids]
        cursor += batch_size
        # Wrap around once the epoch is exhausted.
        cursor = 0 if cursor >= num_samples else cursor
def main(args):
    """Train a ``TransformerLM`` on ``args.dataset`` and report progress.

    Prints the mean training loss and throughput every
    ``args.steps_per_report`` iterations. Evaluates on the validation split
    every ``args.steps_per_eval`` iterations. Optionally evaluates on the
    test split once training finishes (``args.eval_test``).

    Args:
        args: Parsed command-line namespace providing the options defined
            by this script's argument parser.
    """
    batch_size = args.batch_size
    context_size = args.context_size
    steps_per_eval = args.steps_per_eval
    steps_per_report = args.steps_per_report

    # Load vocab and dataset:
    vocab, train, valid, test = datasets.load_dataset(args.dataset)

    # Initialize model:
    model = TransformerLM(
        len(vocab), args.num_blocks, args.dim, args.num_heads, args.checkpoint
    )
    mx.eval(model.parameters())
    nparams = sum(
        x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
    )
    print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

    def loss_fn(model, x, y, reduce=True):
        """Cross-entropy of ``model(x)`` against targets ``y``.

        With ``reduce=True`` return the scalar mean over all tokens.
        Otherwise return one loss per sample (averaged over the sequence
        axis), so callers can sum and divide by the number of samples.
        """
        logits = model(x)
        losses = nn.losses.cross_entropy(logits, y)
        # Per-sample mean (axis=-1). Averaging over both axes here would
        # collapse a whole batch to one scalar, and the caller's
        # per-sample normalization would then shrink the loss by ~batch_size.
        return mx.mean(losses) if reduce else mx.mean(losses, axis=-1)

    optimizer = optim.AdamW(
        learning_rate=args.learning_rate, weight_decay=args.weight_decay
    )

    def eval_fn(dataset):
        """Return the mean per-token loss of ``model`` over every window of ``dataset``."""
        # Keep the overlapping windows as NumPy views and convert one batch
        # at a time. Materializing all windows as mx.arrays up front would
        # copy roughly len(dataset) * context_size tokens, twice.
        inputs, targets = to_samples(context_size, dataset)
        total_loss = 0.0
        for s in range(0, targets.shape[0], batch_size):
            bx = mx.array(np.ascontiguousarray(inputs[s : s + batch_size]))
            by = mx.array(np.ascontiguousarray(targets[s : s + batch_size]))
            losses = loss_fn(model, bx, by, reduce=False)
            total_loss += mx.sum(losses).item()
        return total_loss / len(targets)

    # Model and optimizer state are captured implicitly by the compiled step;
    # listing them as inputs/outputs lets mx.compile read and update them.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(inputs, targets):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, inputs, targets)
        optimizer.update(model, grads)
        return loss

    train_iterator = iterate_batches(batch_size, context_size, train)
    losses = []
    tic = time.perf_counter()
    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
        inputs, targets = map(mx.array, (inputs, targets))
        # Linear LR warmup. A non-positive lr_warmup disables warmup instead
        # of dividing by zero.
        warmup = min(1, it / args.lr_warmup) if args.lr_warmup > 0 else 1
        optimizer.learning_rate = warmup * args.learning_rate
        loss = step(inputs, targets)
        mx.eval(state)
        losses.append(loss.item())
        if (it + 1) % steps_per_report == 0:
            train_loss = np.mean(losses)
            toc = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {steps_per_report / (toc - tic):.3f}"
            )
            losses = []
            tic = time.perf_counter()
        if (it + 1) % steps_per_eval == 0:
            # Time the evaluation itself, not any training since the last report.
            val_tic = time.perf_counter()
            val_loss = eval_fn(valid)
            toc = time.perf_counter()
            print(
                f"Iter {it + 1}: "
                f"Val loss {val_loss:.3f}, "
                f"Val ppl {math.exp(val_loss):.3f}, "
                f"Val took {(toc - val_tic):.3f}s, "
            )
            tic = time.perf_counter()

    if args.eval_test:
        test_loss = eval_fn(test)
        test_ppl = math.exp(test_loss)
        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
    import argparse

    # The first positional argument of ArgumentParser is `prog`; pass the
    # text as `description` so it shows up as the help description.
    parser = argparse.ArgumentParser(
        description="Train a decoder-only Transformer LM with MLX."
    )
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="ptb",
        choices=["ptb", "wikitext2", "wikitext103"],
        help="Dataset to train and evaluate on.",
    )
    parser.add_argument(
        "--context_size",
        type=int,
        default=1024,
        help="Context size in tokens of the model.",
    )
    parser.add_argument(
        "--num_blocks", type=int, default=12, help="Number of Transformer blocks."
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=1024,
        help="Dimensionality of embeddings and hidden layers.",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=16,
        help="Number of heads used for multi-head attention",
    )
    parser.add_argument(
        "--checkpoint", action="store_true", help="Perform gradient checkpointing"
    )
    parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
    parser.add_argument(
        "--num_iters", type=int, default=100000, help="Iterations to train for."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=3e-4, help="AdamW learning rate."
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-5, help="Set the weight decay"
    )
    parser.add_argument(
        "--lr_warmup", type=int, default=200, help="LR linear warmup iterations"
    )
    parser.add_argument(
        "--steps_per_report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps_per_eval",
        type=int,
        default=1000,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--eval_test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    args = parser.parse_args()
    if not args.gpu:
        mx.set_default_device(mx.cpu)
    # Apply --seed (previously parsed but ignored). NumPy's global RNG drives
    # batch shuffling. MLX's RNG is used by the mlx.nn layer initializers.
    np.random.seed(args.seed)
    mx.random.seed(args.seed)
    main(args)