use more standard window strategy (#1287)

Awni Hannun 2025-02-19 06:22:51 -08:00 committed by GitHub
parent 96bf37008e
commit 1cbf5cdac7


@@ -8,7 +8,6 @@ import datasets
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import numpy as np
 from mlx.utils import tree_flatten
@@ -40,26 +39,21 @@ class TransformerLM(nn.Module):
 def to_samples(context_size, dataset):
-    tokens = dataset.size
     window_size = context_size + 1  # include target
-    samples = tokens - window_size + 1
-    X = np.lib.stride_tricks.as_strided(
-        dataset,
-        shape=(samples, window_size),
-        strides=(dataset.itemsize, dataset.itemsize),
-    )
-    return X[:, :-1], X[:, 1:]
+    samples = dataset.size // window_size
+    dataset = dataset[: samples * window_size]
+    return mx.array(dataset.reshape(samples, -1))
 
 
 def iterate_batches(batch_size, context_size, dataset):
-    inputs, targets = to_samples(context_size, dataset)
+    inputs = to_samples(context_size, dataset)
     s = 0
     while True:
         if s == 0:
             # Reset permutation:
-            perm = np.random.permutation(inputs.shape[0])
+            perm = mx.random.permutation(inputs.shape[0])
         ids = perm[s : s + batch_size]
-        yield inputs[ids], targets[ids]
+        yield inputs[ids]
         s += batch_size
         if s >= inputs.shape[0]:
             s = 0
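For context, the old `to_samples` used `np.lib.stride_tricks.as_strided` to materialize every overlapping window over the token stream; the new version chops the stream into non-overlapping windows of `context_size + 1` tokens and returns them as a single `mx.array`, which is the more standard strategy the commit title refers to. A minimal sketch of the new behavior on a toy token stream (the `numpy` toy data and the sizes are illustrative, not part of the diff):

import mlx.core as mx
import numpy as np


def to_samples(context_size, dataset):
    # Non-overlapping windows of context_size + 1 tokens; the tail tokens
    # that don't fill a complete window are dropped.
    window_size = context_size + 1  # include target
    samples = dataset.size // window_size
    dataset = dataset[: samples * window_size]
    return mx.array(dataset.reshape(samples, -1))


# Toy stream of 10 token ids with context_size=3 -> windows of length 4.
stream = np.arange(10)
windows = to_samples(3, stream)
print(windows.shape)  # (2, 4): [[0 1 2 3], [4 5 6 7]]; tokens 8 and 9 are dropped

Each row now carries both input and target: the first `context_size` tokens feed the model and the tokens shifted by one are the labels, which is why the updated `loss_fn` in the next hunk splits the window internally instead of receiving separate input and target arrays.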
@@ -84,45 +78,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
-    def loss_fn(model, x, y, reduce=True):
+    def loss_fn(model, inputs, reduction="mean"):
+        x, y = inputs[..., :-1], inputs[..., 1:]
         logits = model(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+        return nn.losses.cross_entropy(logits, y, reduction=reduction)
 
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
 
     def eval_fn(dataset):
-        inputs, targets = map(mx.array, to_samples(context_size, dataset))
+        inputs = to_samples(context_size, dataset)
         loss = 0
-        for s in range(0, targets.shape[0], batch_size):
-            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
-            bx, by = map(mx.array, (bx, by))
-            losses = loss_fn(model, bx, by, reduce=False)
-            loss += mx.sum(losses).item()
-        return loss / len(targets)
+        for s in range(0, inputs.shape[0], batch_size):
+            losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
+            loss += losses.item()
+        return loss / (inputs.size - inputs.shape[0])
 
     state = [model.state, optimizer.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def step(inputs, targets):
+    def step(inputs):
         loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        loss, grads = loss_and_grad_fn(model, inputs)
         optimizer.update(model, grads)
         return loss
 
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
-    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
-        inputs, targets = map(mx.array, (inputs, targets))
+    for it, inputs in zip(range(args.num_iters), train_iterator):
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        loss = step(inputs, targets)
+        loss = step(inputs)
         mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
-            train_loss = np.mean(losses)
+            train_loss = sum(losses) / len(losses)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "