Add grad checkpointing and PE in the transformer example (#387)

* Add grad checkpointing and PE in the transformer example

* Remove other frameworks from LM example

* Remove the other frameworks from MNIST example

* Improve the transformer LM example

* Fix black and change LR
Angelos Katharopoulos
2024-02-01 13:04:03 -08:00
committed by GitHub
parent ec14583c2a
commit e9b32747b4
8 changed files with 36 additions and 946 deletions
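
In short, the example's TransformerLM now adds sinusoidal positional encodings to its token embeddings and uses a pre-norm TransformerEncoder that can optionally recompute activations in the backward pass (gradient checkpointing), while training moves from plain SGD to AdamW with weight decay and a linear learning-rate warmup. The sketch below assembles the new forward pass from the same layers the diff uses; it is only an illustration, and the sizes are made-up toy values rather than the example's defaults.

    import mlx.core as mx
    import mlx.nn as nn

    # Toy sizes for illustration only
    vocab_size, num_layers, dims, num_heads = 1024, 2, 128, 4

    embedding = nn.Embedding(vocab_size, dims)
    pe = nn.SinusoidalPositionalEncoding(dims)
    encoder = nn.TransformerEncoder(
        num_layers, dims, num_heads, norm_first=True, checkpoint=True
    )
    out_proj = nn.Linear(dims, vocab_size)

    tokens = mx.random.randint(0, vocab_size, (2, 16))   # (batch, sequence) token ids
    L = tokens.shape[1]
    mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
    x = embedding(tokens) + pe(mx.arange(L))             # add positional information
    logits = out_proj(encoder(x, mask))
    print(logits.shape)                                  # (2, 16, 1024)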


@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 import math
 import time
@@ -12,16 +12,28 @@ from mlx.utils import tree_flatten
 class TransformerLM(nn.Module):
-    def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        checkpoint: bool,
+    ):
         super().__init__()
         self.embedding = nn.Embedding(vocab_size, dims)
-        self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
+        self.pe = nn.SinusoidalPositionalEncoding(dims)
+        self.transformer = nn.TransformerEncoder(
+            num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint
+        )
         self.out_proj = nn.Linear(dims, vocab_size)

     def __call__(self, x):
-        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+        L = x.shape[1]
+        mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
         x = self.embedding(x)
+        x = x + self.pe(mx.arange(L))
         x = self.transformer(x, mask)
         return self.out_proj(x)
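
The new positional-encoding step is worth a closer look: nn.SinusoidalPositionalEncoding maps a 1-D array of positions to one encoding vector per position, and adding that to the (batch, sequence, dims) embeddings broadcasts over the batch dimension. A small, self-contained illustration with arbitrary sizes:

    import mlx.core as mx
    import mlx.nn as nn

    dims, L = 8, 5                        # arbitrary toy sizes
    pe = nn.SinusoidalPositionalEncoding(dims)

    enc = pe(mx.arange(L))                # one sinusoidal vector per position
    print(enc.shape)                      # (5, 8)

    x = mx.zeros((2, L, dims))            # stand-in for the token embeddings
    print((x + enc).shape)                # broadcasts over the batch -> (2, 5, 8)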
@@ -67,14 +79,18 @@ def main(args):
     vocab, train, valid, test = datasets.load_dataset(args.dataset)

     # Initialize model:
-    model = TransformerLM(len(vocab), args.num_blocks, args.dim, args.num_heads)
+    model = TransformerLM(
+        len(vocab), args.num_blocks, args.dim, args.num_heads, args.checkpoint
+    )
     mx.eval(model.parameters())
     nparams = sum(
         x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

-    optimizer = optim.SGD(learning_rate=args.learning_rate)
+    optimizer = optim.AdamW(
+        learning_rate=args.learning_rate, weight_decay=args.weight_decay
+    )
     loss_and_grad_fn = nn.value_and_grad(model, model.loss)

     def eval_fn(model, dataset):
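
With weight decay now a separate hyperparameter, the optimizer becomes AdamW (decoupled weight decay) rather than SGD. A minimal sketch of the new setup, assuming the example imports mlx.optimizers as optim (which the optim. prefix in the diff suggests); the values mirror the new defaults, and the learning rate set here is only a starting point, since the warmup in the next hunk reassigns it every iteration:

    import mlx.optimizers as optim

    # AdamW with decoupled weight decay, mirroring the new defaults in the diff
    optimizer = optim.AdamW(learning_rate=3e-4, weight_decay=1e-5)

    # The training loop overwrites this attribute each step to apply warmup
    optimizer.learning_rate = 0.0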
@@ -93,7 +109,9 @@ def main(args):
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
         loss, grads = loss_and_grad_fn(inputs, targets)
-        model.update(optimizer.apply_gradients(grads, model))
+        optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
+        optimizer.update(model, grads)
+        del grads
         mx.eval(loss, model.parameters())
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
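
The training step changes in two ways: the learning rate is linearly warmed up over the first lr_warmup iterations before optimizer.update applies the gradients, and `del grads` drops the gradient arrays before mx.eval, presumably so lazy evaluation does not keep them alive and peak memory stays lower (especially useful alongside checkpointing). The warmup expression itself is just a clamped linear ramp; a small sketch using the new defaults:

    # Linear warmup: ramp from 0 to the target LR over `lr_warmup` steps, then hold.
    # (Same expression as in the loop above; the defaults mirror the diff.)
    def warmup_lr(it, base_lr=3e-4, lr_warmup=200):
        return min(1, it / lr_warmup) * base_lr

    print(warmup_lr(0))      # 0.0
    print(warmup_lr(100))    # 0.00015 (halfway through warmup)
    print(warmup_lr(1000))   # 0.0003  (after warmup, constant at base_lr)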
@@ -156,12 +174,21 @@ if __name__ == "__main__":
         default=16,
         help="Number of heads used for multi-head attention",
     )
+    parser.add_argument(
+        "--checkpoint", action="store_true", help="Perform gradient checkpointing"
+    )
     parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
     parser.add_argument(
         "--num_iters", type=int, default=100000, help="Iterations to train for."
     )
     parser.add_argument(
-        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
+        "--learning_rate", type=float, default=3e-4, help="SGD learning rate."
     )
+    parser.add_argument(
+        "--weight_decay", type=float, default=1e-5, help="Set the weight decay"
+    )
+    parser.add_argument(
+        "--lr_warmup", type=int, default=200, help="LR linear warmup iterations"
+    )
     parser.add_argument(
         "--steps_per_report",