mirror of https://github.com/ml-explore/mlx-examples.git
use more standard window strategy (#1287)
parent 96bf37008e
commit 1cbf5cdac7
@@ -8,7 +8,6 @@ import datasets
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import numpy as np
 from mlx.utils import tree_flatten
 
 
@@ -40,26 +39,21 @@ class TransformerLM(nn.Module):
 
 
 def to_samples(context_size, dataset):
-    tokens = dataset.size
     window_size = context_size + 1  # include target
-    samples = tokens - window_size + 1
-    X = np.lib.stride_tricks.as_strided(
-        dataset,
-        shape=(samples, window_size),
-        strides=(dataset.itemsize, dataset.itemsize),
-    )
-    return X[:, :-1], X[:, 1:]
+    samples = dataset.size // window_size
+    dataset = dataset[: samples * window_size]
+    return mx.array(dataset.reshape(samples, -1))
 
 
 def iterate_batches(batch_size, context_size, dataset):
-    inputs, targets = to_samples(context_size, dataset)
+    inputs = to_samples(context_size, dataset)
     s = 0
     while True:
         if s == 0:
             # Reset permutation:
-            perm = np.random.permutation(inputs.shape[0])
+            perm = mx.random.permutation(inputs.shape[0])
         ids = perm[s : s + batch_size]
-        yield inputs[ids], targets[ids]
+        yield inputs[ids]
         s += batch_size
         if s >= inputs.shape[0]:
             s = 0
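For reference, a standalone sketch of the two sampling strategies this hunk swaps: the old code built one heavily overlapping window per token position with NumPy stride tricks, while the new code chops the token stream into disjoint windows of context_size + 1 tokens and leaves the input/target shift to the loss. The toy token stream below is made up purely for illustration; NumPy appears only so the old stride-tricks path can be shown alongside the new one.

import numpy as np

tokens = np.arange(10)          # toy token stream: 0..9
context_size = 3
window_size = context_size + 1  # include the target token

# Old strategy: every position starts a window, so samples overlap heavily.
n_overlap = tokens.size - window_size + 1
overlapping = np.lib.stride_tricks.as_strided(
    tokens, shape=(n_overlap, window_size), strides=(tokens.itemsize, tokens.itemsize)
)
# overlapping[0] == [0, 1, 2, 3], overlapping[1] == [1, 2, 3, 4], ...

# New strategy: chop the stream into disjoint chunks and drop the remainder.
n_chunks = tokens.size // window_size
non_overlapping = tokens[: n_chunks * window_size].reshape(n_chunks, -1)
# non_overlapping == [[0, 1, 2, 3], [4, 5, 6, 7]]; each chunk is later split
# into inputs and targets by dropping its last / first token, as in the new loss_fn.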
@@ -84,45 +78,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
-    def loss_fn(model, x, y, reduce=True):
+    def loss_fn(model, inputs, reduction="mean"):
+        x, y = inputs[..., :-1], inputs[..., 1:]
         logits = model(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+        return nn.losses.cross_entropy(logits, y, reduction=reduction)
 
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
 
     def eval_fn(dataset):
-        inputs, targets = map(mx.array, to_samples(context_size, dataset))
+        inputs = to_samples(context_size, dataset)
         loss = 0
-        for s in range(0, targets.shape[0], batch_size):
-            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
-            bx, by = map(mx.array, (bx, by))
-            losses = loss_fn(model, bx, by, reduce=False)
-            loss += mx.sum(losses).item()
-        return loss / len(targets)
+        for s in range(0, inputs.shape[0], batch_size):
+            losses = loss_fn(model, inputs[s : s + batch_size], reduction="sum")
+            loss += losses.item()
+        return loss / (inputs.size - inputs.shape[0])
 
     state = [model.state, optimizer.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def step(inputs, targets):
+    def step(inputs):
         loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        loss, grads = loss_and_grad_fn(model, inputs)
         optimizer.update(model, grads)
         return loss
 
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
-    for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
-        inputs, targets = map(mx.array, (inputs, targets))
+    for it, inputs in zip(range(args.num_iters), train_iterator):
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        loss = step(inputs, targets)
+        loss = step(inputs)
         mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
-            train_loss = np.mean(losses)
+            train_loss = sum(losses) / len(losses)
             toc = time.perf_counter()
            print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
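This last hunk folds the shift-by-one into loss_fn and normalizes the evaluation loss by the number of target tokens. A minimal sketch of that bookkeeping on a toy MLX array (the array contents are made up and only stand in for a batch of token ids):

import mlx.core as mx

batch = mx.arange(2 * 5).reshape(2, 5)  # two rows of context_size + 1 = 5 tokens each

# The new loss_fn shifts inside the loss: predict token t+1 from tokens up to t.
x, y = batch[..., :-1], batch[..., 1:]
print(x.shape, y.shape)                  # (2, 4) and (2, 4)

# eval_fn sums per-token losses, so dividing by the number of targets gives a mean.
# Each row of window_size tokens yields window_size - 1 targets, which is
# inputs.size - inputs.shape[0] in the new eval_fn.
num_targets = batch.size - batch.shape[0]
print(num_targets)                       # 8, the same as y.size

With reduction="sum", the loss summed over a batch divided by that count is the average per-token cross entropy, matching what the old reduce=False path computed with an extra pass through NumPy.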