Update a few examples to use compile (#420)

* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout (see the sketch after this list)

* add a bit of prefetching
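
The compile-related bullets above all follow the same MLX pattern: wrap the training step in mx.compile and declare every piece of implicit state, including the global PRNG state that dropout consumes, as both inputs and outputs. The GCN diff itself is not shown on this page, so what follows is only a hedged sketch of that pattern with an illustrative stand-in model (the layer sizes, optimizer, and loss are assumptions, not taken from the commit):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial

# Stand-in for the GCN: any model with dropout exercises the RNG state.
model = nn.Sequential(nn.Linear(16, 32), nn.Dropout(0.5), nn.Linear(32, 4))
optimizer = optim.Adam(learning_rate=1e-3)
mx.eval(model.parameters())

def loss_fn(model, x, y):
    return mx.mean(nn.losses.cross_entropy(model(x), y))

# Dropout draws from the global PRNG, so mx.random.state must be tracked
# alongside the model and optimizer state for the compiled step to stay correct.
state = [model.state, optimizer.state, mx.random.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

x, y = mx.random.normal((8, 16)), mx.array([0, 1, 2, 3, 0, 1, 2, 3])
loss = step(x, y)
mx.eval(state)  # also advances mx.random.state, so dropout masks vary per step

Because the RNG state flows through the compiled graph, dropout behaves as it would in eager mode rather than replaying the same mask on every call.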
Author:    Awni Hannun
Date:      2024-02-08 13:00:41 -08:00
Committer: GitHub
Parent:    da7adae5ec
Commit:    f45a1ab83c

17 changed files with 164 additions and 118 deletions


@@ -23,6 +23,7 @@ def mnist(batch_size, img_size, root=None):
         .image_resize("image", h=img_size[0], w=img_size[1])
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )
 
     # iterator over test set

@@ -2,14 +2,15 @@
 import argparse
 import time
+from functools import partial
 from pathlib import Path
 
 import dataset
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import model
 import numpy as np
+import vae
 from mlx.utils import tree_flatten
 from PIL import Image
@@ -53,44 +54,6 @@ def loss_fn(model, X):
     return recon_loss + kl_div
 
 
-def train_epoch(model, data, optimizer, epoch):
-    loss_acc = 0.0
-    throughput_acc = 0.0
-
-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-
-    # Iterate over training batches
-    for batch_count, batch in enumerate(data):
-        X = mx.array(batch["image"])
-        throughput_tic = time.perf_counter()
-
-        # Forward pass + backward pass + update
-        loss, grads = loss_and_grad_fn(model, X)
-        optimizer.update(model, grads)
-
-        # Evaluate updated model parameters
-        mx.eval(model.parameters(), optimizer.state)
-
-        throughput_toc = time.perf_counter()
-        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
-        loss_acc += loss.item()
-
-        if batch_count > 0 and (batch_count % 10 == 0):
-            print(
-                " | ".join(
-                    [
-                        f"Epoch {epoch:4d}",
-                        f"Loss {(loss_acc / batch_count):10.2f}",
-                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
-                        f"Batch {batch_count:5d}",
-                    ]
-                ),
-                end="\r",
-            )
-
-    return loss_acc, throughput_acc, batch_count
-
-
 def reconstruct(model, batch, out_file):
     # Reconstruct a single batch only
     images = mx.array(batch["image"])
@@ -127,10 +90,10 @@ def main(args):
     save_dir.mkdir(parents=True, exist_ok=True)
 
     # Load the model
-    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
-    mx.eval(vae.parameters())
+    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
+    mx.eval(model.parameters())
 
-    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
+    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
     print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))
 
     optimizer = optim.AdamW(learning_rate=args.lr)
@@ -139,19 +102,53 @@ def main(args):
     train_batch = next(train_iter)
     test_batch = next(test_iter)
 
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(X):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, X)
+        optimizer.update(model, grads)
+        return loss
+
     for e in range(1, args.epochs + 1):
         # Reset iterators and stats at the beginning of each epoch
         train_iter.reset()
-        vae.train()
+        model.train()
 
         # Train one epoch
         tic = time.perf_counter()
-        loss_acc, throughput_acc, batch_count = train_epoch(
-            vae, train_iter, optimizer, e
-        )
-        toc = time.perf_counter()
-
-        vae.eval()
+        loss_acc = 0.0
+        throughput_acc = 0.0
+
+        # Iterate over training batches
+        for batch_count, batch in enumerate(train_iter):
+            X = mx.array(batch["image"])
+            throughput_tic = time.perf_counter()
+
+            # Forward pass + backward pass + update
+            loss = step(X)
+
+            # Evaluate updated model parameters
+            mx.eval(state)
+
+            throughput_toc = time.perf_counter()
+            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
+            loss_acc += loss.item()
+
+            if batch_count > 0 and (batch_count % 10 == 0):
+                print(
+                    " | ".join(
+                        [
+                            f"Epoch {e:4d}",
+                            f"Loss {(loss_acc / batch_count):10.2f}",
+                            f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
+                            f"Batch {batch_count:5d}",
+                        ]
+                    ),
+                    end="\r",
+                )
+
+        toc = time.perf_counter()
+
         print(
             " | ".join(
@@ -163,14 +160,17 @@ def main(args):
                 ]
             )
         )
 
+        model.eval()
+
         # Reconstruct a batch of training and test images
-        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
-        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
+        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
+        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")
 
         # Generate images
-        generate(vae, save_dir / f"generated_{e:03d}.png")
+        generate(model, save_dir / f"generated_{e:03d}.png")
 
-    vae.save_weights(str(save_dir / "weights.npz"))
+    model.save_weights(str(save_dir / "weights.npz"))
 
 if __name__ == "__main__":
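
The step function above captures the model and optimizer as closures, and optimizer.update mutates them in place; mx.compile assumes pure functions, so that implicit state has to be declared through inputs and outputs, and mx.eval(state) replaces the old mx.eval(model.parameters(), optimizer.state) to force the lazily-computed update. A self-contained sketch of the same pattern with a toy model (the Linear layer, SGD optimizer, and MSE loss here are illustrative, not from this diff):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial

model = nn.Linear(4, 4)  # toy model for illustration
optimizer = optim.SGD(learning_rate=0.1)
mx.eval(model.parameters())

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

# Declare the implicitly-updated state so compilation sees the mutations.
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

x, y = mx.random.normal((8, 4)), mx.random.normal((8, 4))
loss = step(x, y)
mx.eval(state)  # materialize the parameter update before the next batch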


@@ -1,4 +1,4 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
 Pillow