Mirror of https://github.com/ml-explore/mlx-examples.git
Add grad checkpointing and PE in the transformer example (#387)
* Add grad checkpointing and PE in the transformer example
* Remove other frameworks from LM example
* Remove the other frameworks from MNIST example
* Improve the transformer LM example
* Fix black and change LR
Committed by GitHub
Parent: ec14583c2a
Commit: e9b32747b4
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
 import time
@@ -12,16 +12,28 @@ from mlx.utils import tree_flatten
 
 
 class TransformerLM(nn.Module):
-    def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        checkpoint: bool,
+    ):
         super().__init__()
 
         self.embedding = nn.Embedding(vocab_size, dims)
-        self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
+        self.pe = nn.SinusoidalPositionalEncoding(dims)
+        self.transformer = nn.TransformerEncoder(
+            num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint
+        )
         self.out_proj = nn.Linear(dims, vocab_size)
 
     def __call__(self, x):
-        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+        L = x.shape[1]
+        mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
         x = self.embedding(x)
+        x = x + self.pe(mx.arange(L))
         x = self.transformer(x, mask)
         return self.out_proj(x)
 
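For reference, the encoding added via nn.SinusoidalPositionalEncoding is the classic fixed sin/cos scheme from the original Transformer paper. The sketch below is a hand-rolled illustration only; the helper name sinusoidal_pe and the exact frequency scaling are assumptions and may not match MLX's defaults.

# Illustrative sketch only: a hand-rolled sinusoidal positional encoding.
# The exact scaling/ordering in nn.SinusoidalPositionalEncoding may differ.
import math
import mlx.core as mx

def sinusoidal_pe(positions: mx.array, dims: int, base: float = 10000.0) -> mx.array:
    # positions: (L,) positions, e.g. mx.arange(L); returns an (L, dims) encoding.
    half = dims // 2
    # Geometrically spaced frequencies from 1 down to roughly 1/base.
    freqs = mx.exp(-math.log(base) * mx.arange(half) / half)
    angles = mx.expand_dims(positions, 1) * mx.expand_dims(freqs, 0)  # (L, half)
    # First half of the features uses sin, the second half cos.
    return mx.concatenate([mx.sin(angles), mx.cos(angles)], axis=-1)

# Usage mirroring the model's __call__ above: x = x + sinusoidal_pe(mx.arange(L), dims)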
@@ -67,14 +79,18 @@ def main(args):
     vocab, train, valid, test = datasets.load_dataset(args.dataset)
 
     # Initialize model:
-    model = TransformerLM(len(vocab), args.num_blocks, args.dim, args.num_heads)
+    model = TransformerLM(
+        len(vocab), args.num_blocks, args.dim, args.num_heads, args.checkpoint
+    )
     mx.eval(model.parameters())
     nparams = sum(
         x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
-    optimizer = optim.SGD(learning_rate=args.learning_rate)
+    optimizer = optim.AdamW(
+        learning_rate=args.learning_rate, weight_decay=args.weight_decay
+    )
     loss_and_grad_fn = nn.value_and_grad(model, model.loss)
 
     def eval_fn(model, dataset):
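The checkpoint flag passed to the model asks the encoder to recompute each layer's activations during the backward pass instead of keeping them, trading extra compute for lower peak memory. The following is a minimal sketch of that idea, assuming mlx.core exposes a checkpoint function transform (mx.checkpoint); treat the wrapper name and the exact call as assumptions rather than the library's documented API.

# Hedged sketch: wrap a single layer so its internals are recomputed on the
# backward pass (assumes an mx.checkpoint transform exists; otherwise read as pseudocode).
import mlx.core as mx

def checkpointed_call(layer, x, mask):
    # Only the inputs/outputs of `layer` are kept alive for the backward pass;
    # activations inside the layer are recomputed when gradients are needed.
    return mx.checkpoint(lambda x_: layer(x_, mask))(x)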
@@ -93,7 +109,9 @@ def main(args):
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
         loss, grads = loss_and_grad_fn(inputs, targets)
-        model.update(optimizer.apply_gradients(grads, model))
+        optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
+        optimizer.update(model, grads)
+        del grads
         mx.eval(loss, model.parameters())
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
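The new learning-rate line implements a linear warmup: the rate scales from 0 up to args.learning_rate over the first args.lr_warmup iterations and stays constant afterwards. A minimal sketch with the parser defaults introduced below (3e-4 and 200):

# Linear warmup schedule as used in the training loop above.
def warmup_lr(it: int, base_lr: float = 3e-4, warmup: int = 200) -> float:
    return min(1, it / warmup) * base_lr

# warmup_lr(0)   == 0.0       (start of training)
# warmup_lr(100) == 1.5e-4    (halfway through warmup)
# warmup_lr(200) == 3e-4      (full learning rate from here on)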
@@ -156,12 +174,21 @@ if __name__ == "__main__":
         default=16,
         help="Number of heads used for multi-head attention",
     )
+    parser.add_argument(
+        "--checkpoint", action="store_true", help="Perform gradient checkpointing"
+    )
     parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
     parser.add_argument(
         "--num_iters", type=int, default=100000, help="Iterations to train for."
     )
     parser.add_argument(
-        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
+        "--learning_rate", type=float, default=3e-4, help="SGD learning rate."
     )
+    parser.add_argument(
+        "--weight_decay", type=float, default=1e-5, help="Set the weight decay"
+    )
+    parser.add_argument(
+        "--lr_warmup", type=int, default=200, help="LR linear warmup iterations"
+    )
     parser.add_argument(
         "--steps_per_report",
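For completeness, a small standalone sketch of how the new flags behave when parsed; the flag names and defaults come from the diff above, while the example argument list is hypothetical.

# Standalone sketch of the new command-line options added by this commit.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", action="store_true", help="Perform gradient checkpointing")
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--weight_decay", type=float, default=1e-5)
parser.add_argument("--lr_warmup", type=int, default=200)

# Hypothetical invocation: only --checkpoint is passed, everything else uses the new defaults.
args = parser.parse_args(["--checkpoint"])
assert args.checkpoint is True and args.learning_rate == 3e-4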