a few examples

Awni Hannun
2023-11-29 08:17:26 -08:00
parent e31d82d3ed
commit b243c1d8f4
32 changed files with 105181 additions and 2 deletions

transformer_lm/README.md Normal file

@@ -0,0 +1,14 @@
# Transformer LM
This is an example of a decoder-only Transformer LM. The only dependency is
MLX.
Run the example on the GPU with:
```
python main.py --gpu
```
By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option.
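For example, to train on WikiText-2:
```
python main.py --gpu --dataset wikitext2
```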
To run the PyTorch, Jax, or TensorFlow examples, install the respective framework.

transformer_lm/datasets.py Normal file

@@ -0,0 +1,90 @@
import io
import itertools
import numpy as np
import os
from urllib import request
import zipfile
def load_dataset(dataname):
if dataname == "ptb":
return ptb()
elif dataname == "wikitext2":
return wikitext(dataset="2")
else:
return wikitext(dataset="103")
def _load(save_dir, filenames):
# *NB* First file is expected to be the training set
with open(os.path.join(save_dir, filenames[0]), "r") as fid:
vocab = set(t for l in fid.readlines() for t in l.strip().split(" "))
eos = "<eos>"
vocab.add(eos)
vocab = {v: i for i, v in enumerate(vocab)}
def to_array(dataset):
with open(os.path.join(save_dir, dataset), "r") as fid:
lines = (l.strip().split(" ") for l in fid.readlines())
return np.array(
[vocab[w] for line in lines for w in itertools.chain(line, [eos])],
dtype=np.uint32,
)
datasets = [to_array(fn) for fn in filenames]
return vocab, *datasets
def wikitext(dataset="2", save_dir="/tmp"):
"""
Load the WikiText-* language modeling dataset:
    https://paperswithcode.com/dataset/wikitext-2
"""
if dataset not in ("2", "103"):
raise ValueError(f'Dataset must be either "2" or "103", got {dataset}')
filenames = ["wiki.train.tokens", "wiki.valid.tokens", "wiki.test.tokens"]
dataname = f"wikitext-{dataset}"
data_dir = os.path.join(save_dir, dataname)
if not os.path.exists(data_dir):
base_url = "https://s3.amazonaws.com/research.metamind.io/wikitext/"
zip_file_url = base_url + dataname + "-v1.zip"
r = request.urlopen(zip_file_url)
with zipfile.ZipFile(io.BytesIO(r.read())) as zf:
zf.extractall(save_dir)
return _load(data_dir, filenames)
def ptb(save_dir="/tmp"):
"""
Load the PTB language modeling dataset:
https://paperswithcode.com/dataset/penn-treebank
"""
filenames = [
"ptb.train.txt",
"ptb.valid.txt",
"ptb.test.txt",
]
def download_and_save(save_dir):
base_url = "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/"
for name in filenames:
out_file = os.path.join(save_dir, name)
if not os.path.exists(out_file):
request.urlretrieve(base_url + name, out_file)
save_dir = os.path.join(save_dir, "ptb")
if not os.path.exists(save_dir):
os.mkdir(save_dir)
download_and_save(save_dir)
return _load(save_dir, filenames)
if __name__ == "__main__":
vocab, train, val, test = ptb()
assert len(vocab) == 10000, "PTB: Wrong vocab size"
vocab, train, val, test = wikitext()
assert len(vocab) == 33279, "WikiText: Wrong vocab size"
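
A minimal usage sketch (standalone, assuming this module is importable as `datasets`): each loader returns the vocabulary mapping followed by the train/valid/test token arrays.
```
import datasets

# Splits are flat uint32 arrays of token ids; vocab maps token -> id.
vocab, train, valid, test = datasets.load_dataset("ptb")
print(len(vocab), train.shape, train.dtype)  # 10000 (N,) uint32
```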

transformer_lm/jax_main.py Normal file

@@ -0,0 +1,303 @@
import functools
import jax
import jax.numpy as jnp
import math
import numpy as np
import time
from collections import namedtuple
import datasets
from tree_utils import tree_flatten
"""
Some TODOs for this model:
- Positional encodings
- Dropout
- Adam optimizer
- Option for bigger datasets (wikitext / librispeech text < c4 < ...)
"""
RuntimeConfig = namedtuple("RuntimeConfig", "num_heads")
def embedding_init(key, num_embeddings, embed_dim):
return jax.random.uniform(
key, (num_embeddings, embed_dim), minval=-1e-1, maxval=1e-1
)
def embedding_apply(params, X):
return params.take(X, axis=0)
def dense_init(key, in_dim, out_dim, bias=True):
k1, k2 = jax.random.split(key)
scale = math.sqrt(1 / in_dim)
params = [jax.random.uniform(k1, (in_dim, out_dim), minval=-scale, maxval=scale)]
if bias:
params.append(jax.random.uniform(k2, (out_dim,), minval=-scale, maxval=scale))
return params
def dense_apply(params, X):
X = X @ params[0]
if len(params) == 2:
X = X + params[1]
return X
def layernorm_init(key, dim):
return [jnp.zeros((dim,)), jnp.ones((dim,))]
def layernorm_apply(params, X, epsilon=1e-6):
means = jnp.mean(X, axis=-1, keepdims=True)
var = jnp.var(X, axis=-1, keepdims=True)
X = (X - means) / jnp.sqrt(var + epsilon)
beta, gamma = params
return gamma * X + beta
def mlpblock_init(key, dim):
k1, k2 = jax.random.split(key)
return {
"dense1": dense_init(k1, dim, 4 * dim),
"dense2": dense_init(k2, 4 * dim, dim),
}
def mlpblock_apply(params, X):
X = dense_apply(params["dense1"], X)
X = jnp.maximum(X, 0)
# TODO dropout option here
return dense_apply(params["dense2"], X)
def selfattention_init(key, dim):
k1, k2, k3, k4 = jax.random.split(key, 4)
return {
"Q": dense_init(k1, dim, dim, bias=False),
"K": dense_init(k2, dim, dim, bias=False),
"V": dense_init(k3, dim, dim, bias=False),
"out": dense_init(k4, dim, dim, bias=False),
}
def selfattention_apply(params, num_heads, X, mask):
queries = dense_apply(params["Q"], X)
keys = dense_apply(params["K"], X)
values = dense_apply(params["V"], X)
B, L, D = queries.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ jnp.transpose(keys, (0, 1, 3, 2))
scores = jax.nn.softmax(scores + mask, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return dense_apply(params["out"], values_hat)
def transformer_init(key, token_set_size, num_blocks, dim):
key, ek = jax.random.split(key)
params = {"embedding": embedding_init(ek, token_set_size, dim)}
transformer_blocks = []
for b in range(num_blocks):
key, k1, k2, k3, k4 = jax.random.split(key, 5)
transformer_blocks.append(
{
"ln1": layernorm_init(k1, dim),
"ln2": layernorm_init(k2, dim),
"selfattention": selfattention_init(k3, dim),
"mlpblock": mlpblock_init(k4, dim),
}
)
params["transformer_blocks"] = transformer_blocks
params["output"] = dense_init(key, dim, token_set_size)
return params
def create_additive_causal_mask(N):
indices = jnp.arange(N)
mask = jnp.reshape(indices, (-1, 1)) < jnp.reshape(indices, (1, -1))
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
mask = mask.astype(jnp.float32) * -1e9
return mask
def transformer_apply(params, static_params, inputs):
mask = create_additive_causal_mask(inputs.shape[1])
X = embedding_apply(params["embedding"], inputs)
for block in params["transformer_blocks"]:
out = layernorm_apply(block["ln1"], X)
out = selfattention_apply(
block["selfattention"], static_params.num_heads, out, mask
)
X = X + out
out = layernorm_apply(block["ln2"], X)
out = mlpblock_apply(block["mlpblock"], out)
X = X + out
return dense_apply(params["output"], X)
@functools.partial(jax.jit, static_argnames=["static_params", "reduce"])
def loss_fn(params, static_params, inputs, targets, reduce=True):
logits = transformer_apply(params, static_params, inputs)
logits = jax.nn.log_softmax(logits, axis=-1)
sample_indices = jnp.arange(targets.shape[0])[:, None]
token_indices = jnp.arange(targets.shape[1])[None, :]
losses = -logits[sample_indices, token_indices, targets]
return jnp.mean(losses) if reduce else losses.mean(-1)
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
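    # as_strided builds overlapping windows over the token stream without
    # copying: row i is a view of dataset[i : i + window_size].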
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(key, batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
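    # Infinite generator: yields random batches forever, reshuffling the
    # sample order each time a full pass over the data completes.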
s = 0
while True:
if s == 0:
# Reset permutation:
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s >= inputs.shape[0]:
s = 0
def main(args):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
config = RuntimeConfig(args.num_heads)
# Load vocab and dataset:
vocab, train, valid, test = datasets.ptb()
# Initialize model:
key, subkey = jax.random.split(jax.random.PRNGKey(args.seed))
params = transformer_init(subkey, len(vocab), args.num_blocks, args.dim)
nparams = sum(x.size for k, x in tree_flatten(params) if "embedding" not in k)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
loss_and_grad_fn = jax.jit(
jax.value_and_grad(loss_fn), static_argnames=["static_params"]
)
update_fn = jax.jit(
functools.partial(jax.tree_map, lambda p, g: p - args.learning_rate * g)
)
def eval_fn(params, dataset):
inputs, targets = to_samples(context_size, dataset)
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
losses = loss_fn(params, config, bx, by, reduce=False)
loss += jnp.sum(losses)
return loss / len(targets)
    train_iterator = iterate_batches(key, batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
loss, grads = loss_and_grad_fn(params, config, inputs, targets)
losses.append(loss.item())
params = update_fn(params, grads)
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(params, valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
test_loss = eval_fn(params, test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss.item():.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with Jax.")
parser.add_argument(
"--seed", type=int, default=0, help="Seed for numpy and Jax RNGs."
)
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
main(args)
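
For reference, a quick standalone check of what `create_additive_causal_mask` produces for a length-3 sequence:
```
import jax.numpy as jnp

indices = jnp.arange(3)
mask = (indices.reshape(-1, 1) < indices.reshape(1, -1)).astype(jnp.float32) * -1e9
# [[ 0.e+00 -1.e+09 -1.e+09]
#  [ 0.e+00  0.e+00 -1.e+09]
#  [ 0.e+00  0.e+00  0.e+00]]
# Added to the attention scores, the -1e9 entries vanish under softmax,
# so each position can only attend to itself and earlier positions.
```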

transformer_lm/main.py Normal file

@@ -0,0 +1,190 @@
import math
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
import datasets
class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
self.out_proj = nn.Linear(dims, vocab_size)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x = self.embedding(x)
x = self.transformer(x, mask)
return self.out_proj(x)
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
mx.simplify(losses)
        return mx.mean(losses) if reduce else mx.mean(losses, axis=-1)
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
perm = np.random.permutation(inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s >= inputs.shape[0]:
s = 0
def main(args):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
    # Seed NumPy (batch shuffling) and MLX (parameter init):
    np.random.seed(args.seed)
    mx.random.seed(args.seed)
# Load vocab and dataset:
vocab, train, valid, test = datasets.load_dataset(args.dataset)
# Initialize model:
model = TransformerLM(len(vocab), args.num_blocks, args.dim, args.num_heads)
mx.eval(model.parameters())
nparams = sum(
x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
optimizer = optim.SGD(learning_rate=args.learning_rate)
loss_and_grad_fn = nn.value_and_grad(model, model.loss)
    def eval_fn(dataset):
inputs, targets = map(mx.array, to_samples(context_size, dataset))
loss = 0
for s in range(0, targets.shape[0], batch_size):
            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
            losses = model.loss(bx, by, reduce=False)
loss += mx.sum(losses).item()
return loss / len(targets)
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(mx.array, (inputs, targets))
loss, grads = loss_and_grad_fn(inputs, targets)
model.update(optimizer.apply_gradients(grads, model))
mx.simplify(loss, model.parameters())
mx.eval(loss, model.parameters())
losses.append(loss.item())
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
            val_loss = eval_fn(valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
        test_loss = eval_fn(test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
parser.add_argument(
"--dataset",
type=str,
default="ptb",
choices=["ptb", "wikitext2", "wikitext103"],
help="Dataset to train and evaluate on.",
)
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)
main(args)
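
The training loop above follows MLX's lazy-evaluation pattern; condensed to the three calls that matter (a sketch using only functions already present in this file):
```
loss, grads = loss_and_grad_fn(inputs, targets)         # build forward + backward graph lazily
model.update(optimizer.apply_gradients(grads, model))   # SGD parameter update
mx.eval(loss, model.parameters())                       # force evaluation of the lazy graph
```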

transformer_lm/tf_main.py Normal file

@@ -0,0 +1,249 @@
import math
import time
import numpy as np
import tensorflow as tf
import datasets
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
perm = np.random.permutation(inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s + batch_size >= inputs.shape[0]:
s = 0
def create_additive_causal_mask(N):
indices = tf.range(N)
mask = tf.reshape(indices, (-1, 1)) < tf.reshape(indices, (1, -1))
return tf.cast(mask, tf.dtypes.float32) * -1e9
class SelfAttention(tf.keras.layers.Layer):
def __init__(self, num_heads, model_dims, context_size):
super().__init__()
self.Wq = tf.keras.layers.Dense(model_dims, use_bias=False)
self.Wk = tf.keras.layers.Dense(model_dims, use_bias=False)
self.Wv = tf.keras.layers.Dense(model_dims, use_bias=False)
self.Wo = tf.keras.layers.Dense(model_dims, use_bias=False)
self.causal_mask = create_additive_causal_mask(context_size)
self.num_heads = num_heads
self.head_dim = model_dims // num_heads
self.scale = tf.constant(1.0 / math.sqrt(self.head_dim))
def call(self, x):
queries = self.Wq(x)
keys = self.Wk(x)
values = self.Wv(x)
B, L, D = x.shape
queries = tf.transpose(
tf.reshape(queries, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
)
keys = tf.transpose(
tf.reshape(keys, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
)
values = tf.transpose(
tf.reshape(values, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
)
scores = (self.scale * queries) @ tf.transpose(keys, (0, 1, 3, 2))
scores = tf.nn.softmax(scores + self.causal_mask, axis=-1)
values = tf.matmul(scores, values)
values_hat = tf.reshape(tf.transpose(values, perm=(0, 2, 1, 3)), (B, L, -1))
return self.Wo(values_hat)
class EncoderLayer(tf.keras.layers.Layer):
def __init__(self, num_heads, model_dims, context_size):
super().__init__()
self._ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
self._self_attn = SelfAttention(num_heads, model_dims, context_size)
self._ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
self._mlp = tf.keras.Sequential(
[
tf.keras.layers.Dense(4 * model_dims, activation="relu"),
tf.keras.layers.Dense(model_dims),
]
)
def call(self, x):
x = x + self._self_attn(self._ln1(x))
x = x + self._mlp(self._ln2(x))
return x
class TransformerLM(tf.keras.Model):
def __init__(self, vocab_size, num_layers, num_heads, model_dims, context_size):
super().__init__()
self.embedding = tf.keras.layers.Embedding(vocab_size, model_dims)
self.transformer = tf.keras.Sequential(
[
EncoderLayer(num_heads, model_dims, context_size)
for _ in range(num_layers)
]
)
self.projection = tf.keras.layers.Dense(vocab_size)
def call(self, x):
x = self.embedding(x)
x = self.transformer(x)
x = self.projection(x)
return x
def main(args, device):
with tf.device(device):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
        # Seed TensorFlow and NumPy (batch shuffling):
        tf.random.set_seed(args.seed)
        np.random.seed(args.seed)
# Load vocab and dataset:
vocab, train, valid, test = datasets.ptb()
# Initialize model:
transformer = TransformerLM(
len(vocab), args.num_blocks, args.num_heads, args.dim, context_size
)
transformer.compile(
optimizer=tf.keras.optimizers.legacy.SGD(learning_rate=args.learning_rate),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
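        # compile wires the optimizer and loss so that train_on_batch and
        # test_on_batch below run single steps without a tf.data pipeline.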
transformer.build((batch_size, context_size))
nparams = sum(
np.prod(p.shape) for p in transformer.trainable_weights[1:]
) # [1:] to skip the embedding
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
def eval_fn(dataset):
inputs, targets = to_samples(context_size, dataset)
loss = 0
n_batches = 0
for s in range(0, targets.shape[0], batch_size):
if s + batch_size >= targets.shape[0]:
s = targets.shape[0] - 1 - batch_size
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(
lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
[bx, by],
)
loss += transformer.test_on_batch(bx, by)
n_batches += 1
return loss / n_batches
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(
lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
[inputs, targets],
)
loss = transformer.train_on_batch(inputs, targets)
losses.append(loss)
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
test_loss = eval_fn(test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
main(args, device="/GPU:0" if args.gpu else "/CPU:0")

transformer_lm/torch_main.py Normal file

@@ -0,0 +1,197 @@
import math
import time
import numpy as np
import torch
import datasets
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
perm = np.random.permutation(inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s >= inputs.shape[0]:
s = 0
def create_additive_causal_mask(N, device):
# torch.nn.Transformer.generate_square_subsequent_mask
# gives NaNs with `device="mps"`
indices = torch.arange(N, device=device)
mask = indices.reshape((-1, 1)) < indices.reshape((1, -1))
return mask.to(torch.float32) * -1e9
class TransformerLM(torch.nn.Module):
def __init__(self, vocab_size, num_layers, num_heads, model_dims):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, model_dims)
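        # Pre-norm (norm_first=True) encoder layers mirror the block
        # structure used in the Jax and TensorFlow examples.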
self.transformer = torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(
model_dims,
num_heads,
4 * model_dims,
dropout=0.0,
norm_first=True,
batch_first=True,
),
num_layers,
)
self.projection = torch.nn.Linear(model_dims, vocab_size)
def forward(self, x):
mask = create_additive_causal_mask(x.shape[1], device=x.device)
x = self.embedding(x)
x = self.transformer(x, mask=mask)
x = self.projection(x)
return x
def main(args, device):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
    # Seed Torch and NumPy (batch shuffling):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
# Load vocab and dataset:
vocab, train, valid, test = datasets.ptb()
# Initialize model:
transformer = TransformerLM(len(vocab), args.num_blocks, args.num_heads, args.dim)
transformer = transformer.to(device)
optim = torch.optim.SGD(transformer.parameters(), lr=args.learning_rate, momentum=0)
nparams = sum(
p.numel() for n, p in transformer.named_parameters() if "embedding" not in n
)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
@torch.no_grad()
def eval_fn(dataset):
inputs, targets = to_samples(context_size, dataset)
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(lambda x: torch.from_numpy(x.astype(int)).to(device), [bx, by])
logits = transformer(bx)
losses = torch.nn.functional.cross_entropy(
logits.flatten(0, 1), by.flatten(), reduction="none"
)
losses = losses.view(-1, by.shape[-1]).mean(-1)
loss += losses.sum().item()
return loss / len(targets)
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(
lambda x: torch.from_numpy(x.astype(int)).to(device), [inputs, targets]
)
optim.zero_grad()
logits = transformer(inputs)
loss = torch.nn.functional.cross_entropy(
logits.flatten(0, 1), targets.flatten()
)
loss.backward()
optim.step()
losses.append(loss.item())
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
test_loss = eval_fn(test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
main(args, device="mps" if args.gpu else "cpu")