diff --git a/transformer_lm/main.py b/transformer_lm/main.py index dc725cbe..7ff5b73f 100644 --- a/transformer_lm/main.py +++ b/transformer_lm/main.py @@ -8,7 +8,6 @@ import datasets import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim -import numpy as np from mlx.utils import tree_flatten @@ -40,26 +39,21 @@ class TransformerLM(nn.Module): def to_samples(context_size, dataset): - tokens = dataset.size window_size = context_size + 1 # include target - samples = tokens - window_size + 1 - X = np.lib.stride_tricks.as_strided( - dataset, - shape=(samples, window_size), - strides=(dataset.itemsize, dataset.itemsize), - ) - return X[:, :-1], X[:, 1:] + samples = dataset.size // window_size + dataset = dataset[: samples * window_size] + return mx.array(dataset.reshape(samples, -1)) def iterate_batches(batch_size, context_size, dataset): - inputs, targets = to_samples(context_size, dataset) + inputs = to_samples(context_size, dataset) s = 0 while True: if s == 0: # Reset permutation: - perm = np.random.permutation(inputs.shape[0]) + perm = mx.random.permutation(inputs.shape[0]) ids = perm[s : s + batch_size] - yield inputs[ids], targets[ids] + yield inputs[ids] s += batch_size if s >= inputs.shape[0]: s = 0 @@ -84,45 +78,42 @@ def main(args): ) print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters") - def loss_fn(model, x, y, reduce=True): + def loss_fn(model, inputs, reduction="mean"): + x, y = inputs[..., :-1], inputs[..., 1:] logits = model(x) - losses = nn.losses.cross_entropy(logits, y) - return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2)) + return nn.losses.cross_entropy(logits, y, reduction=reduction) optimizer = optim.AdamW( learning_rate=args.learning_rate, weight_decay=args.weight_decay ) def eval_fn(dataset): - inputs, targets = map(mx.array, to_samples(context_size, dataset)) + inputs = to_samples(context_size, dataset) loss = 0 - for s in range(0, targets.shape[0], batch_size): - bx, by = inputs[s : s + batch_size], targets[s : s + batch_size] - bx, by = map(mx.array, (bx, by)) - losses = loss_fn(model, bx, by, reduce=False) - loss += mx.sum(losses).item() - return loss / len(targets) + for s in range(0, inputs.shape[0], batch_size): + losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum") + loss += losses.item() + return loss / (inputs.size - inputs.shape[0]) state = [model.state, optimizer.state] @partial(mx.compile, inputs=state, outputs=state) - def step(inputs, targets): + def step(inputs): loss_and_grad_fn = nn.value_and_grad(model, loss_fn) - loss, grads = loss_and_grad_fn(model, inputs, targets) + loss, grads = loss_and_grad_fn(model, inputs) optimizer.update(model, grads) return loss train_iterator = iterate_batches(batch_size, context_size, train) losses = [] tic = time.perf_counter() - for it, (inputs, targets) in zip(range(args.num_iters), train_iterator): - inputs, targets = map(mx.array, (inputs, targets)) + for it, inputs in zip(range(args.num_iters), train_iterator): optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate - loss = step(inputs, targets) + loss = step(inputs) mx.eval(state) losses.append(loss.item()) if (it + 1) % steps_per_report == 0: - train_loss = np.mean(losses) + train_loss = sum(losses) / len(losses) toc = time.perf_counter() print( f"Iter {it + 1}: Train loss {train_loss:.3f}, "