use more standard window strategy (#1287)

This commit is contained in:
Awni Hannun 2025-02-19 06:22:51 -08:00 committed by GitHub
parent 96bf37008e
commit 1cbf5cdac7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,7 +8,6 @@ import datasets
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten
@ -40,26 +39,21 @@ class TransformerLM(nn.Module):
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
samples = dataset.size // window_size
dataset = dataset[: samples * window_size]
return mx.array(dataset.reshape(samples, -1))
def iterate_batches(batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
inputs = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
perm = np.random.permutation(inputs.shape[0])
perm = mx.random.permutation(inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
yield inputs[ids]
s += batch_size
if s >= inputs.shape[0]:
s = 0
@ -84,45 +78,42 @@ def main(args):
)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
def loss_fn(model, x, y, reduce=True):
def loss_fn(model, inputs, reduction="mean"):
x, y = inputs[..., :-1], inputs[..., 1:]
logits = model(x)
losses = nn.losses.cross_entropy(logits, y)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
return nn.losses.cross_entropy(logits, y, reduction=reduction)
optimizer = optim.AdamW(
learning_rate=args.learning_rate, weight_decay=args.weight_decay
)
def eval_fn(dataset):
inputs, targets = map(mx.array, to_samples(context_size, dataset))
inputs = to_samples(context_size, dataset)
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(mx.array, (bx, by))
losses = loss_fn(model, bx, by, reduce=False)
loss += mx.sum(losses).item()
return loss / len(targets)
for s in range(0, inputs.shape[0], batch_size):
losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
loss += losses.item()
return loss / (inputs.size - inputs.shape[0])
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(inputs, targets):
def step(inputs):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, inputs, targets)
loss, grads = loss_and_grad_fn(model, inputs)
optimizer.update(model, grads)
return loss
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(mx.array, (inputs, targets))
for it, inputs in zip(range(args.num_iters), train_iterator):
optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
loss = step(inputs, targets)
loss = step(inputs)
mx.eval(state)
losses.append(loss.item())
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
train_loss = sum(losses) / len(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "