Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
Add grad checkpointing and PE in the transformer example (#387)
* Add grad checkpointing and PE in the transformer example
* Remove other frameworks from LM example
* Remove the other frameworks from MNIST example
* Improve the transformer LM example
* Fix black and change LR
This commit is contained in:
parent ec14583c2a
commit e9b32747b4
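The substance of the commit is the change to the MLX transformer LM example: the model gains sinusoidal positional encodings and optional gradient checkpointing, and training switches from plain SGD to AdamW with weight decay and a linear learning-rate warmup. The sketch below condenses those pieces from the diff; it only assumes the APIs the diff itself uses (`nn.SinusoidalPositionalEncoding`, the `checkpoint` flag of `nn.TransformerEncoder`, `nn.MultiHeadAttention.create_additive_causal_mask`, and `optim.AdamW`), and the tiny hyperparameters are illustrative rather than the example's defaults.

```python
# Condensed sketch of the commit's key changes (not the full example).
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim


class TransformerLM(nn.Module):
    def __init__(self, vocab_size, num_layers, dims, num_heads, checkpoint):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        # New: sinusoidal positional encoding added to the token embeddings.
        self.pe = nn.SinusoidalPositionalEncoding(dims)
        # New: norm-first encoder with optional gradient checkpointing.
        self.transformer = nn.TransformerEncoder(
            num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint
        )
        self.out_proj = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        L = x.shape[1]
        mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
        x = self.embedding(x) + self.pe(mx.arange(L))
        return self.out_proj(self.transformer(x, mask))


# Illustrative sizes; the example's defaults are larger.
model = TransformerLM(
    vocab_size=1000, num_layers=2, dims=128, num_heads=4, checkpoint=True
)
mx.eval(model.parameters())

# New: AdamW with weight decay replaces SGD; the learning rate is warmed up
# linearly inside the training loop, e.g.:
#   optimizer.learning_rate = min(1, it / lr_warmup) * learning_rate
optimizer = optim.AdamW(learning_rate=3e-4, weight_decay=1e-5)

logits = model(mx.array([[1, 2, 3, 4]]))  # (batch=1, sequence=4, vocab) logits
```

Gradient checkpointing recomputes each block's activations during the backward pass instead of storing them, trading extra compute for a smaller memory footprint.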
@@ -1,83 +0,0 @@
# Copyright © 2023 Apple Inc.

import functools
import time

import jax
import jax.numpy as jnp

import mnist


def init_model(key, num_layers, input_dim, hidden_dim, output_dim):
    params = []
    layer_sizes = [hidden_dim] * num_layers
    for idim, odim in zip([input_dim] + layer_sizes, layer_sizes + [output_dim]):
        key, wk = jax.random.split(key, 2)
        W = 1e-2 * jax.random.normal(wk, (idim, odim))
        b = jnp.zeros((odim,))
        params.append((W, b))
    return params


def feed_forward(params, X):
    for W, b in params[:-1]:
        X = jnp.maximum(X @ W + b, 0)
    W, b = params[-1]
    return X @ W + b


def loss_fn(params, X, y):
    logits = feed_forward(params, X)
    logits = jax.nn.log_softmax(logits, 1)
    return -jnp.mean(logits[jnp.arange(y.size), y])


@jax.jit
def eval_fn(params, X, y):
    logits = feed_forward(params, X)
    return jnp.mean(jnp.argmax(logits, axis=1) == y)


def batch_iterate(key, batch_size, X, y):
    perm = jax.random.permutation(key, y.size)
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]


if __name__ == "__main__":
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1
    dataset = "mnist"

    # Load the data
    train_images, train_labels, test_images, test_labels = getattr(mnist, dataset)()
    # Load the model
    key, subkey = jax.random.split(jax.random.PRNGKey(seed))
    params = init_model(
        subkey, num_layers, train_images.shape[-1], hidden_dim, num_classes
    )

    loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
    update_fn = jax.jit(
        functools.partial(jax.tree_map, lambda p, g: p - learning_rate * g)
    )

    for e in range(num_epochs):
        tic = time.perf_counter()
        key, subkey = jax.random.split(key)
        for X, y in batch_iterate(subkey, batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(params, X, y)
            params = update_fn(params, grads)
        accuracy = eval_fn(params, test_images, test_labels)
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )
@@ -1,100 +0,0 @@
# Copyright © 2023 Apple Inc.

import argparse
import time

import torch

import mnist


class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        super().__init__()
        layer_sizes = [hidden_dim] * num_layers
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(idim, odim)
                for idim, odim in zip(
                    [input_dim] + layer_sizes, layer_sizes + [output_dim]
                )
            ]
        )

    def forward(self, x):
        x = self.layers[0](x)
        for l in self.layers[1:]:
            x = l(x.relu())
        return x


def loss_fn(model, X, y):
    logits = model(X)
    return torch.nn.functional.cross_entropy(logits, y)


@torch.no_grad()
def eval_fn(model, X, y):
    logits = model(X)
    return torch.mean((logits.argmax(-1) == y).float())


def batch_iterate(batch_size, X, y, device):
    perm = torch.randperm(len(y), device=device)
    for s in range(0, len(y), batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=["mnist", "fashion_mnist"],
        help="The dataset to use.",
    )
    args = parser.parse_args()

    if not args.gpu:
        torch.set_num_threads(1)
        device = "cpu"
    else:
        device = "mps"
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    # Load the data
    def to_tensor(x):
        if x.dtype != "uint32":
            return torch.from_numpy(x).to(device)
        else:
            return torch.from_numpy(x.astype(int)).to(device)

    train_images, train_labels, test_images, test_labels = map(
        to_tensor, getattr(mnist, args.dataset)()
    )

    # Load the model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)

    for e in range(num_epochs):
        tic = time.perf_counter()
        for X, y in batch_iterate(batch_size, train_images, train_labels, device):
            opt.zero_grad()
            loss_fn(model, X, y).backward()
            opt.step()
        accuracy = eval_fn(model, test_images, test_labels)
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )
@@ -10,5 +10,3 @@ python main.py --gpu
 ```

 By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option.
-
-To run the PyTorch, Jax or TensorFlow examples install the respective framework.
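With the options this commit adds to the example's argument parser (see the argparse hunks below), a run that exercises the new pieces would presumably look like `python main.py --gpu --checkpoint --learning_rate 3e-4 --weight_decay 1e-5 --lr_warmup 200`, where the numeric values are simply the new defaults spelled out.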
@@ -1,305 +0,0 @@
# Copyright © 2023 Apple Inc.

import functools
import math
import time
from collections import namedtuple

import datasets
import jax
import jax.numpy as jnp
import numpy as np
from tree_utils import tree_flatten

"""
Some TODOs for this model:
- Positional encodings
- Dropout
- Adam optimizer
- Option for bigger datasets (wikitext / librispeech text < c4 < ...)
"""

RuntimeConfig = namedtuple("RuntimeConfig", "num_heads")


def embedding_init(key, num_embeddings, embed_dim):
    return jax.random.uniform(
        key, (num_embeddings, embed_dim), minval=-1e-1, maxval=1e-1
    )


def embedding_apply(params, X):
    return params.take(X, axis=0)


def dense_init(key, in_dim, out_dim, bias=True):
    k1, k2 = jax.random.split(key)
    scale = math.sqrt(1 / in_dim)
    params = [jax.random.uniform(k1, (in_dim, out_dim), minval=-scale, maxval=scale)]
    if bias:
        params.append(jax.random.uniform(k2, (out_dim,), minval=-scale, maxval=scale))
    return params


def dense_apply(params, X):
    X = X @ params[0]
    if len(params) == 2:
        X = X + params[1]
    return X


def layernorm_init(key, dim):
    return [jnp.zeros((dim,)), jnp.ones((dim,))]


def layernorm_apply(params, X, epsilon=1e-6):
    means = jnp.mean(X, axis=-1, keepdims=True)
    var = jnp.var(X, axis=-1, keepdims=True)
    X = (X - means) / jnp.sqrt(var + epsilon)
    beta, gamma = params
    return gamma * X + beta


def mlpblock_init(key, dim):
    k1, k2 = jax.random.split(key)
    return {
        "dense1": dense_init(k1, dim, 4 * dim),
        "dense2": dense_init(k2, 4 * dim, dim),
    }


def mlpblock_apply(params, X):
    X = dense_apply(params["dense1"], X)
    X = jnp.maximum(X, 0)
    # TODO dropout option here
    return dense_apply(params["dense2"], X)


def selfattention_init(key, dim):
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        "Q": dense_init(k1, dim, dim, bias=False),
        "K": dense_init(k2, dim, dim, bias=False),
        "V": dense_init(k3, dim, dim, bias=False),
        "out": dense_init(k4, dim, dim, bias=False),
    }


def selfattention_apply(params, num_heads, X, mask):
    queries = dense_apply(params["Q"], X)
    keys = dense_apply(params["K"], X)
    values = dense_apply(params["V"], X)

    B, L, D = queries.shape
    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
    keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
    values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

    # Dimensions are [batch x num heads x sequence x hidden dim]
    scale = math.sqrt(1 / queries.shape[-1])
    scores = (queries * scale) @ jnp.transpose(keys, (0, 1, 3, 2))
    scores = jax.nn.softmax(scores + mask, axis=-1)
    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

    return dense_apply(params["out"], values_hat)


def transformer_init(key, token_set_size, num_blocks, dim):
    key, ek = jax.random.split(key)
    params = {"embedding": embedding_init(ek, token_set_size, dim)}
    transformer_blocks = []
    for b in range(num_blocks):
        key, k1, k2, k3, k4 = jax.random.split(key, 5)
        transformer_blocks.append(
            {
                "ln1": layernorm_init(k1, dim),
                "ln2": layernorm_init(k2, dim),
                "selfattention": selfattention_init(k3, dim),
                "mlpblock": mlpblock_init(k4, dim),
            }
        )
    params["transformer_blocks"] = transformer_blocks
    params["output"] = dense_init(key, dim, token_set_size)
    return params


def create_additive_causal_mask(N):
    indices = jnp.arange(N)
    mask = jnp.reshape(indices, (-1, 1)) < jnp.reshape(indices, (1, -1))
    # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
    mask = mask.astype(jnp.float32) * -1e9
    return mask


def transformer_apply(params, static_params, inputs):
    mask = create_additive_causal_mask(inputs.shape[1])
    X = embedding_apply(params["embedding"], inputs)
    for block in params["transformer_blocks"]:
        out = layernorm_apply(block["ln1"], X)
        out = selfattention_apply(
            block["selfattention"], static_params.num_heads, out, mask
        )
        X = X + out
        out = layernorm_apply(block["ln2"], X)
        out = mlpblock_apply(block["mlpblock"], out)
        X = X + out
    return dense_apply(params["output"], X)


@functools.partial(jax.jit, static_argnames=["static_params", "reduce"])
def loss_fn(params, static_params, inputs, targets, reduce=True):
    logits = transformer_apply(params, static_params, inputs)
    logits = jax.nn.log_softmax(logits, axis=-1)
    sample_indices = jnp.arange(targets.shape[0])[:, None]
    token_indices = jnp.arange(targets.shape[1])[None, :]
    losses = -logits[sample_indices, token_indices, targets]
    return jnp.mean(losses) if reduce else losses.mean(-1)


def to_samples(context_size, dataset):
    tokens = dataset.size
    window_size = context_size + 1  # include target
    samples = tokens - window_size + 1
    X = np.lib.stride_tricks.as_strided(
        dataset,
        shape=(samples, window_size),
        strides=(dataset.itemsize, dataset.itemsize),
    )
    return X[:, :-1], X[:, 1:]


def iterate_batches(key, batch_size, context_size, dataset):
    inputs, targets = to_samples(context_size, dataset)
    s = 0
    while True:
        if s == 0:
            # Reset permutation:
            key, subkey = jax.random.split(key)
            perm = jax.random.permutation(subkey, inputs.shape[0])
        ids = perm[s : s + batch_size]
        yield inputs[ids], targets[ids]
        s += batch_size
        if s >= inputs.shape[0]:
            s = 0


def main(args):
    batch_size = args.batch_size
    context_size = args.context_size
    steps_per_eval = args.steps_per_eval
    steps_per_report = args.steps_per_report
    config = RuntimeConfig(args.num_heads)

    # Load vocab and dataset:
    vocab, train, valid, test = datasets.ptb()

    # Initialize model:
    key, subkey = jax.random.split(jax.random.PRNGKey(args.seed))
    params = transformer_init(subkey, len(vocab), args.num_blocks, args.dim)
    nparams = sum(x.size for k, x in tree_flatten(params) if "embedding" not in k)
    print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

    loss_and_grad_fn = jax.jit(
        jax.value_and_grad(loss_fn), static_argnames=["static_params"]
    )
    update_fn = jax.jit(
        functools.partial(jax.tree_map, lambda p, g: p - args.learning_rate * g)
    )

    def eval_fn(params, dataset):
        inputs, targets = to_samples(context_size, dataset)
        loss = 0
        for s in range(0, targets.shape[0], batch_size):
            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
            losses = loss_fn(params, config, bx, by, reduce=False)
            loss += jnp.sum(losses)
        return loss / len(targets)

    train_iterator = iterate_batches(subkey, batch_size, context_size, train)
    losses = []
    tic = time.perf_counter()
    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
        loss, grads = loss_and_grad_fn(params, config, inputs, targets)
        losses.append(loss.item())
        params = update_fn(params, grads)
        if (it + 1) % steps_per_report == 0:
            train_loss = np.mean(losses)
            toc = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {steps_per_report / (toc - tic):.3f}"
            )
            losses = []
            tic = time.perf_counter()

        if (it + 1) % steps_per_eval == 0:
            val_loss = eval_fn(params, valid)
            toc = time.perf_counter()
            print(
                f"Iter {it + 1}: "
                f"Val loss {val_loss:.3f}, "
                f"Val ppl {math.exp(val_loss):.3f}, "
                f"Val took {(toc - tic):.3f}s, "
            )
            tic = time.perf_counter()

    if args.eval_test:
        test_loss = eval_fn(params, test)
        test_ppl = math.exp(test_loss)
        print(f"Test loss {test_loss.item():.3f}, Test ppl {test_ppl:.3f}.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with Jax.")
    parser.add_argument(
        "--seed", type=int, default=0, help="Seed for numpy and Jax RNGs."
    )
    parser.add_argument(
        "--context_size",
        type=int,
        default=1024,
        help="Context size in tokens of the model.",
    )
    parser.add_argument(
        "--num_blocks", type=int, default=12, help="Number of Transformer blocks."
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=1024,
        help="Dimensionality of embeddings and hidden layers.",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=16,
        help="Number of heads used for multi-head attention",
    )
    parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
    parser.add_argument(
        "--num_iters", type=int, default=100000, help="Iterations to train for."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
    )
    parser.add_argument(
        "--steps_per_report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps_per_eval",
        type=int,
        default=1000,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--eval_test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    args = parser.parse_args()
    main(args)
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math
 import time
@@ -12,16 +12,28 @@ from mlx.utils import tree_flatten


 class TransformerLM(nn.Module):
-    def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        checkpoint: bool,
+    ):
         super().__init__()

         self.embedding = nn.Embedding(vocab_size, dims)
-        self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
+        self.pe = nn.SinusoidalPositionalEncoding(dims)
+        self.transformer = nn.TransformerEncoder(
+            num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint
+        )
         self.out_proj = nn.Linear(dims, vocab_size)

     def __call__(self, x):
-        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+        L = x.shape[1]
+        mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
         x = self.embedding(x)
+        x = x + self.pe(mx.arange(L))
         x = self.transformer(x, mask)
         return self.out_proj(x)

@@ -67,14 +79,18 @@ def main(args):
     vocab, train, valid, test = datasets.load_dataset(args.dataset)

     # Initialize model:
-    model = TransformerLM(len(vocab), args.num_blocks, args.dim, args.num_heads)
+    model = TransformerLM(
+        len(vocab), args.num_blocks, args.dim, args.num_heads, args.checkpoint
+    )
     mx.eval(model.parameters())
     nparams = sum(
         x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

-    optimizer = optim.SGD(learning_rate=args.learning_rate)
+    optimizer = optim.AdamW(
+        learning_rate=args.learning_rate, weight_decay=args.weight_decay
+    )
     loss_and_grad_fn = nn.value_and_grad(model, model.loss)

     def eval_fn(model, dataset):
@@ -93,7 +109,9 @@ def main(args):
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
         loss, grads = loss_and_grad_fn(inputs, targets)
-        model.update(optimizer.apply_gradients(grads, model))
+        optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
+        optimizer.update(model, grads)
+        del grads
         mx.eval(loss, model.parameters())
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
@@ -156,12 +174,21 @@ if __name__ == "__main__":
         default=16,
         help="Number of heads used for multi-head attention",
     )
+    parser.add_argument(
+        "--checkpoint", action="store_true", help="Perform gradient checkpointing"
+    )
     parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
     parser.add_argument(
         "--num_iters", type=int, default=100000, help="Iterations to train for."
     )
     parser.add_argument(
-        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
+        "--learning_rate", type=float, default=3e-4, help="SGD learning rate."
+    )
+    parser.add_argument(
+        "--weight_decay", type=float, default=1e-5, help="Set the weight decay"
+    )
+    parser.add_argument(
+        "--lr_warmup", type=int, default=200, help="LR linear warmup iterations"
     )
     parser.add_argument(
         "--steps_per_report",

transformer_lm/requirements.txt (new file)
@@ -0,0 +1 @@
+mlx >= 0.12
@@ -1,250 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import datasets
import numpy as np
import tensorflow as tf


def to_samples(context_size, dataset):
    tokens = dataset.size
    window_size = context_size + 1  # include target
    samples = tokens - window_size + 1
    X = np.lib.stride_tricks.as_strided(
        dataset,
        shape=(samples, window_size),
        strides=(dataset.itemsize, dataset.itemsize),
    )
    return X[:, :-1], X[:, 1:]


def iterate_batches(batch_size, context_size, dataset):
    inputs, targets = to_samples(context_size, dataset)
    s = 0
    while True:
        if s == 0:
            # Reset permutation:
            perm = np.random.permutation(inputs.shape[0])
        ids = perm[s : s + batch_size]
        yield inputs[ids], targets[ids]
        s += batch_size
        if s + batch_size >= inputs.shape[0]:
            s = 0


def create_additive_causal_mask(N):
    indices = tf.range(N)
    mask = tf.reshape(indices, (-1, 1)) < tf.reshape(indices, (1, -1))
    return tf.cast(mask, tf.dtypes.float32) * -1e9


class SelfAttention(tf.keras.layers.Layer):
    def __init__(self, num_heads, model_dims, context_size):
        super().__init__()
        self.Wq = tf.keras.layers.Dense(model_dims, use_bias=False)
        self.Wk = tf.keras.layers.Dense(model_dims, use_bias=False)
        self.Wv = tf.keras.layers.Dense(model_dims, use_bias=False)
        self.Wo = tf.keras.layers.Dense(model_dims, use_bias=False)
        self.causal_mask = create_additive_causal_mask(context_size)
        self.num_heads = num_heads
        self.head_dim = model_dims // num_heads
        self.scale = tf.constant(1.0 / math.sqrt(self.head_dim))

    def call(self, x):
        queries = self.Wq(x)
        keys = self.Wk(x)
        values = self.Wv(x)

        B, L, D = x.shape
        queries = tf.transpose(
            tf.reshape(queries, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
        )
        keys = tf.transpose(
            tf.reshape(keys, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
        )
        values = tf.transpose(
            tf.reshape(values, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
        )

        scores = (self.scale * queries) @ tf.transpose(keys, (0, 1, 3, 2))
        scores = tf.nn.softmax(scores + self.causal_mask, axis=-1)
        values = tf.matmul(scores, values)
        values_hat = tf.reshape(tf.transpose(values, perm=(0, 2, 1, 3)), (B, L, -1))

        return self.Wo(values_hat)


class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, num_heads, model_dims, context_size):
        super().__init__()
        self._ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)

        self._self_attn = SelfAttention(num_heads, model_dims, context_size)

        self._ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)

        self._mlp = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(4 * model_dims, activation="relu"),
                tf.keras.layers.Dense(model_dims),
            ]
        )

    def call(self, x):
        x = x + self._self_attn(self._ln1(x))
        x = x + self._mlp(self._ln2(x))
        return x


class TransformerLM(tf.keras.Model):
    def __init__(self, vocab_size, num_layers, num_heads, model_dims, context_size):
        super().__init__()

        self.embedding = tf.keras.layers.Embedding(vocab_size, model_dims)
        self.transformer = tf.keras.Sequential(
            [
                EncoderLayer(num_heads, model_dims, context_size)
                for _ in range(num_layers)
            ]
        )
        self.projection = tf.keras.layers.Dense(vocab_size)

    def call(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        x = self.projection(x)
        return x


def main(args, device):
    with tf.device(device):
        batch_size = args.batch_size
        context_size = args.context_size
        steps_per_eval = args.steps_per_eval
        steps_per_report = args.steps_per_report

        # Load vocab and dataset:
        vocab, train, valid, test = datasets.ptb()

        # Initialize model:
        transformer = TransformerLM(
            len(vocab), args.num_blocks, args.num_heads, args.dim, context_size
        )
        transformer.compile(
            optimizer=tf.keras.optimizers.legacy.SGD(learning_rate=args.learning_rate),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        )
        transformer.build((batch_size, context_size))
        nparams = sum(
            np.prod(p.shape) for p in transformer.trainable_weights[1:]
        )  # [1:] to skip the embedding
        print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

        def eval_fn(dataset):
            inputs, targets = to_samples(context_size, dataset)
            loss = 0
            n_batches = 0
            for s in range(0, targets.shape[0], batch_size):
                if s + batch_size >= targets.shape[0]:
                    s = targets.shape[0] - 1 - batch_size
                bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
                bx, by = map(
                    lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
                    [bx, by],
                )
                loss += transformer.test_on_batch(bx, by)
                n_batches += 1
            return loss / n_batches

        train_iterator = iterate_batches(batch_size, context_size, train)
        losses = []
        tic = time.perf_counter()
        for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
            inputs, targets = map(
                lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
                [inputs, targets],
            )
            loss = transformer.train_on_batch(inputs, targets)
            losses.append(loss)
            if (it + 1) % steps_per_report == 0:
                train_loss = np.mean(losses)
                toc = time.perf_counter()
                print(
                    f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                    f"It/sec {steps_per_report / (toc - tic):.3f}"
                )
                losses = []
                tic = time.perf_counter()

            if (it + 1) % steps_per_eval == 0:
                val_loss = eval_fn(valid)
                toc = time.perf_counter()
                print(
                    f"Iter {it + 1}: "
                    f"Val loss {val_loss:.3f}, "
                    f"Val ppl {math.exp(val_loss):.3f}, "
                    f"Val took {(toc - tic):.3f}s, "
                )
                tic = time.perf_counter()

        if args.eval_test:
            test_loss = eval_fn(test)
            test_ppl = math.exp(test_loss)
            print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
    parser.add_argument(
        "--context_size",
        type=int,
        default=1024,
        help="Context size in tokens of the model.",
    )
    parser.add_argument(
        "--num_blocks", type=int, default=12, help="Number of Transformer blocks."
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=1024,
        help="Dimensionality of embeddings and hidden layers.",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=16,
        help="Number of heads used for multi-head attention",
    )
    parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
    parser.add_argument(
        "--num_iters", type=int, default=100000, help="Iterations to train for."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
    )
    parser.add_argument(
        "--steps_per_report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps_per_eval",
        type=int,
        default=1000,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--eval_test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    args = parser.parse_args()
    main(args, device="/GPU:0" if args.gpu else "/CPU:0")
@@ -1,198 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import datasets
import numpy as np
import torch


def to_samples(context_size, dataset):
    tokens = dataset.size
    window_size = context_size + 1  # include target
    samples = tokens - window_size + 1
    X = np.lib.stride_tricks.as_strided(
        dataset,
        shape=(samples, window_size),
        strides=(dataset.itemsize, dataset.itemsize),
    )
    return X[:, :-1], X[:, 1:]


def iterate_batches(batch_size, context_size, dataset):
    inputs, targets = to_samples(context_size, dataset)
    s = 0
    while True:
        if s == 0:
            # Reset permutation:
            perm = np.random.permutation(inputs.shape[0])
        ids = perm[s : s + batch_size]
        yield inputs[ids], targets[ids]
        s += batch_size
        if s >= inputs.shape[0]:
            s = 0


def create_additive_causal_mask(N, device):
    # torch.nn.Transformer.generate_square_subsequent_mask
    # gives NaNs with `device="mps"`
    indices = torch.arange(N, device=device)
    mask = indices.reshape((-1, 1)) < indices.reshape((1, -1))
    return mask.to(torch.float32) * -1e9


class TransformerLM(torch.nn.Module):
    def __init__(self, vocab_size, num_layers, num_heads, model_dims):
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, model_dims)
        self.transformer = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                model_dims,
                num_heads,
                4 * model_dims,
                dropout=0.0,
                norm_first=True,
                batch_first=True,
            ),
            num_layers,
        )
        self.projection = torch.nn.Linear(model_dims, vocab_size)

    def forward(self, x):
        mask = create_additive_causal_mask(x.shape[1], device=x.device)
        x = self.embedding(x)
        x = self.transformer(x, mask=mask)
        x = self.projection(x)
        return x


def main(args, device):
    batch_size = args.batch_size
    context_size = args.context_size
    steps_per_eval = args.steps_per_eval
    steps_per_report = args.steps_per_report

    # Load vocab and dataset:
    vocab, train, valid, test = datasets.ptb()

    # Initialize model:
    transformer = TransformerLM(len(vocab), args.num_blocks, args.num_heads, args.dim)
    transformer = transformer.to(device)
    optim = torch.optim.SGD(transformer.parameters(), lr=args.learning_rate, momentum=0)
    nparams = sum(
        p.numel() for n, p in transformer.named_parameters() if "embedding" not in n
    )
    print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

    @torch.no_grad()
    def eval_fn(dataset):
        inputs, targets = to_samples(context_size, dataset)
        loss = 0
        for s in range(0, targets.shape[0], batch_size):
            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
            bx, by = map(lambda x: torch.from_numpy(x.astype(int)).to(device), [bx, by])
            logits = transformer(bx)
            losses = torch.nn.functional.cross_entropy(
                logits.flatten(0, 1), by.flatten(), reduction="none"
            )
            losses = losses.view(-1, by.shape[-1]).mean(-1)
            loss += losses.sum().item()
        return loss / len(targets)

    train_iterator = iterate_batches(batch_size, context_size, train)
    losses = []
    tic = time.perf_counter()
    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
        inputs, targets = map(
            lambda x: torch.from_numpy(x.astype(int)).to(device), [inputs, targets]
        )
        optim.zero_grad()
        logits = transformer(inputs)
        loss = torch.nn.functional.cross_entropy(
            logits.flatten(0, 1), targets.flatten()
        )
        loss.backward()
        optim.step()
        losses.append(loss.item())
        if (it + 1) % steps_per_report == 0:
            train_loss = np.mean(losses)
            toc = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {steps_per_report / (toc - tic):.3f}"
            )
            losses = []
            tic = time.perf_counter()

        if (it + 1) % steps_per_eval == 0:
            val_loss = eval_fn(valid)
            toc = time.perf_counter()
            print(
                f"Iter {it + 1}: "
                f"Val loss {val_loss:.3f}, "
                f"Val ppl {math.exp(val_loss):.3f}, "
                f"Val took {(toc - tic):.3f}s, "
            )
            tic = time.perf_counter()

    if args.eval_test:
        test_loss = eval_fn(test)
        test_ppl = math.exp(test_loss)
        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
    parser.add_argument(
        "--context_size",
        type=int,
        default=1024,
        help="Context size in tokens of the model.",
    )
    parser.add_argument(
        "--num_blocks", type=int, default=12, help="Number of Transformer blocks."
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=1024,
        help="Dimensionality of embeddings and hidden layers.",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=16,
        help="Number of heads used for multi-head attention",
    )
    parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
    parser.add_argument(
        "--num_iters", type=int, default=100000, help="Iterations to train for."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
    )
    parser.add_argument(
        "--steps_per_report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps_per_eval",
        type=int,
        default=1000,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--eval_test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    args = parser.parse_args()
    main(args, device="mps" if args.gpu else "cpu")