use more standard window strategy

Awni Hannun 2025-02-14 06:53:31 -08:00
parent ec30dc3538
commit 7efc5f8c5e


@@ -8,7 +8,6 @@ import datasets
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import numpy as np
 from mlx.utils import tree_flatten
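Note: the numpy import can go because its two remaining uses in this file are replaced below — strided windowing by a plain reshape and np.random.permutation by mx.random.permutation. A minimal illustrative sketch (not part of the commit), assuming mx.random.permutation of an integer n yields a shuffled index array over 0..n-1, as the hunk below relies on:

import mlx.core as mx

# Shuffle sample indices without numpy, mirroring the change in iterate_batches.
num_samples = 8
perm = mx.random.permutation(num_samples)  # random ordering of the integers 0..7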
@@ -40,26 +39,21 @@ class TransformerLM(nn.Module):
 def to_samples(context_size, dataset):
-    tokens = dataset.size
     window_size = context_size + 1  # include target
-    samples = tokens - window_size + 1
-    X = np.lib.stride_tricks.as_strided(
-        dataset,
-        shape=(samples, window_size),
-        strides=(dataset.itemsize, dataset.itemsize),
-    )
-    return X[:, :-1], X[:, 1:]
+    samples = dataset.size // window_size
+    dataset = dataset[: samples * window_size]
+    return mx.array(dataset.reshape(samples, -1))
 
 
 def iterate_batches(batch_size, context_size, dataset):
-    inputs, targets = to_samples(context_size, dataset)
+    inputs = to_samples(context_size, dataset)
     s = 0
     while True:
         if s == 0:
             # Reset permutation:
-            perm = np.random.permutation(inputs.shape[0])
+            perm = mx.random.permutation(inputs.shape[0])
         ids = perm[s : s + batch_size]
-        yield inputs[ids], targets[ids]
+        yield inputs[ids]
         s += batch_size
         if s >= inputs.shape[0]:
             s = 0
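For context, a toy sketch (not part of the commit) of what the new windowing produces: the token stream is split into non-overlapping windows of context_size + 1 tokens and any trailing remainder is dropped, instead of the old scheme's one overlapping strided window per starting offset:

import mlx.core as mx

context_size = 3
window_size = context_size + 1          # each window holds the inputs plus the target
tokens = mx.arange(10)                  # toy stand-in for a tokenized dataset

samples = tokens.size // window_size    # 2 full windows; tokens 8 and 9 are dropped
windows = tokens[: samples * window_size].reshape(samples, -1)
# windows == [[0, 1, 2, 3],
#             [4, 5, 6, 7]]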
@@ -84,45 +78,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
-    def loss_fn(model, x, y, reduce=True):
+    def loss_fn(model, inputs, reduction="mean"):
+        x, y = inputs[..., :-1], inputs[..., 1:]
         logits = model(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+        return nn.losses.cross_entropy(logits, y, reduction=reduction)
 
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
 
     def eval_fn(dataset):
-        inputs, targets = map(mx.array, to_samples(context_size, dataset))
+        inputs = to_samples(context_size, dataset)
         loss = 0
-        for s in range(0, targets.shape[0], batch_size):
-            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
-            bx, by = map(mx.array, (bx, by))
-            losses = loss_fn(model, bx, by, reduce=False)
-            loss += mx.sum(losses).item()
-        return loss / len(targets)
+        for s in range(0, inputs.shape[0], batch_size):
+            losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
+            loss += losses.item()
+        return loss / (inputs.size - inputs.shape[0])
 
     state = [model.state, optimizer.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def step(inputs, targets):
+    def step(inputs):
         loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        loss, grads = loss_and_grad_fn(model, inputs)
         optimizer.update(model, grads)
         return loss
 
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
-    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
-        inputs, targets = map(mx.array, (inputs, targets))
+    for it, inputs in zip(range(args.num_iters), train_iterator):
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        loss = step(inputs, targets)
+        loss = step(inputs)
         mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
-            train_loss = np.mean(losses)
+            train_loss = sum(losses) / len(losses)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "