Mirror of https://github.com/ml-explore/mlx-examples.git
Update a few examples to use compile (#420)
* update a few examples to use compile
* update mnist
* add compile to vae and rename some stuff for simplicity
* update reqs
* use state in eval
* GCN example with RNG + dropout
* add a bit of prefetching
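The common thread is mx.compile, which traces a function's graph of MLX array operations on the first call and reuses the optimized graph on later calls. A minimal sketch of the stateless case (names, shapes, and values are illustrative, not from this diff):

    import math

    import mlx.core as mx

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    compiled_gelu = mx.compile(gelu)

    x = mx.random.uniform(shape=(4, 16))
    y = compiled_gelu(x)  # first call traces and compiles; later calls reuse the graph
    mx.eval(y)            # MLX is lazy, so force materialization before timing anything

Training steps are the stateful case: they mutate model parameters and optimizer state, which is what the inputs=/outputs= arguments to mx.compile handle in the cvae diff below.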
@@ -23,6 +23,7 @@ def mnist(batch_size, img_size, root=None):
         .image_resize("image", h=img_size[0], w=img_size[1])
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )
 
     # iterator over test set
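The only change in this hunk is the added .prefetch(4, 4), the "add a bit of prefetching" item from the commit message. In mlx.data, prefetch(prefetch_size, num_threads) fills a background buffer so batch preparation overlaps with compute. A hedged, self-contained sketch of such a pipeline (the loader, image size, and batch size are placeholders, not this example's actual values):

    from mlx.data.datasets import load_mnist

    def normalize(x):
        # mlx.data hands key_transform a NumPy-like array
        return x.astype("float32") / 255.0

    train = (
        load_mnist(train=True)
        .shuffle()
        .to_stream()
        .image_resize("image", h=32, w=32)
        .key_transform("image", normalize)
        .batch(128)
        .prefetch(4, 4)  # keep up to 4 batches ready, fetched by 4 threads
    )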
cvae/main.py (104 changed lines)
@@ -2,14 +2,15 @@
 import argparse
 import time
+from functools import partial
 from pathlib import Path
 
 import dataset
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import model
 import numpy as np
+import vae
 from mlx.utils import tree_flatten
 from PIL import Image
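Besides the model/vae module rename, the one functional addition is from functools import partial, which lets mx.compile be applied as a decorator while still receiving its state arguments. Two equivalent spellings, sketched with a toy state list standing in for [model.state, optimizer.state]:

    from functools import partial

    import mlx.core as mx

    state = [mx.array(0.0)]  # stand-in for [model.state, optimizer.state]

    # Spelling 1: wrap after defining the function normally
    def step(x):
        state[0] = state[0] + x
        return state[0]

    step = mx.compile(step, inputs=state, outputs=state)

    # Spelling 2 (used in this commit): partial keeps the decorator form
    @partial(mx.compile, inputs=state, outputs=state)
    def step2(x):
        state[0] = state[0] + x
        return state[0]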
@@ -53,44 +54,6 @@ def loss_fn(model, X):
     return recon_loss + kl_div
 
 
-def train_epoch(model, data, optimizer, epoch):
-    loss_acc = 0.0
-    throughput_acc = 0.0
-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-
-    # Iterate over training batches
-    for batch_count, batch in enumerate(data):
-        X = mx.array(batch["image"])
-
-        throughput_tic = time.perf_counter()
-
-        # Forward pass + backward pass + update
-        loss, grads = loss_and_grad_fn(model, X)
-        optimizer.update(model, grads)
-
-        # Evaluate updated model parameters
-        mx.eval(model.parameters(), optimizer.state)
-
-        throughput_toc = time.perf_counter()
-        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
-        loss_acc += loss.item()
-
-        if batch_count > 0 and (batch_count % 10 == 0):
-            print(
-                " | ".join(
-                    [
-                        f"Epoch {epoch:4d}",
-                        f"Loss {(loss_acc / batch_count):10.2f}",
-                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
-                        f"Batch {batch_count:5d}",
-                    ]
-                ),
-                end="\r",
-            )
-
-    return loss_acc, throughput_acc, batch_count
-
-
 def reconstruct(model, batch, out_file):
     # Reconstruct a single batch only
     images = mx.array(batch["image"])
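train_epoch is deleted rather than compiled in place: mx.compile treats any array a function captures from an enclosing scope as a constant unless it is declared via inputs/outputs, so the step function is easiest to define inline in main(), right next to the state list it captures. A sketch of the pitfall with a toy captured array (illustrative only):

    import mlx.core as mx

    w = mx.array(1.0)

    @mx.compile
    def f(x):
        return w * x  # w is captured and baked into the graph at trace time

    f(mx.array(2.0))    # uses w == 1.0
    w = mx.array(10.0)  # rebinding w does not update the compiled graph
    f(mx.array(2.0))    # still computes with the old constant

Declaring the mutable arrays with inputs=state and outputs=state, as the hunks below do, is what keeps parameter updates visible across calls.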
@@ -127,10 +90,10 @@ def main(args):
     save_dir.mkdir(parents=True, exist_ok=True)
 
     # Load the model
-    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
-    mx.eval(vae.parameters())
+    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
+    mx.eval(model.parameters())
 
-    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
+    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
     print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))
 
     optimizer = optim.AdamW(learning_rate=args.lr)
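Aside from the rename, the context lines show the parameter-count idiom: tree_flatten flattens the nested parameter dict into (name, array) pairs. A toy illustration (the module is a placeholder, not the CVAE):

    import mlx.nn as nn
    from mlx.utils import tree_flatten

    lin = nn.Linear(4, 2)
    for name, arr in tree_flatten(lin.parameters()):
        print(name, arr.shape)  # e.g. weight (2, 4) and bias (2,)
    num_params = sum(a.size for _, a in tree_flatten(lin.trainable_parameters()))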
@@ -139,19 +102,53 @@ def main(args):
     train_batch = next(train_iter)
     test_batch = next(test_iter)
 
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(X):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, X)
+        optimizer.update(model, grads)
+        return loss
+
     for e in range(1, args.epochs + 1):
         # Reset iterators and stats at the beginning of each epoch
         train_iter.reset()
-        vae.train()
+        model.train()
 
         # Train one epoch
         tic = time.perf_counter()
-        loss_acc, throughput_acc, batch_count = train_epoch(
-            vae, train_iter, optimizer, e
-        )
-        toc = time.perf_counter()
+        loss_acc = 0.0
+        throughput_acc = 0.0
 
-        vae.eval()
+        # Iterate over training batches
+        for batch_count, batch in enumerate(train_iter):
+            X = mx.array(batch["image"])
+            throughput_tic = time.perf_counter()
+
+            # Forward pass + backward pass + update
+            loss = step(X)
+
+            # Evaluate updated model parameters
+            mx.eval(state)
+
+            throughput_toc = time.perf_counter()
+            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
+            loss_acc += loss.item()
+
+            if batch_count > 0 and (batch_count % 10 == 0):
+                print(
+                    " | ".join(
+                        [
+                            f"Epoch {e:4d}",
+                            f"Loss {(loss_acc / batch_count):10.2f}",
+                            f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
+                            f"Batch {batch_count:5d}",
+                        ]
+                    ),
+                    end="\r",
+                )
+        toc = time.perf_counter()
 
         print(
             " | ".join(
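This hunk is the heart of the commit: step() mutates model parameters and optimizer state, so those arrays are declared through inputs=state and outputs=state, and the per-batch mx.eval(model.parameters(), optimizer.state) becomes mx.eval(state) (the "use state in eval" item in the commit message). Reduced to a self-contained sketch with a toy model and toy data standing in for the CVAE and the MNIST pipeline:

    from functools import partial

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    mx.eval(model.parameters())
    optimizer = optim.AdamW(learning_rate=1e-3)

    def loss_fn(model, X, y):
        return nn.losses.mse_loss(model(X), y)

    # Declare the mutable arrays so compile tracks them across calls
    # instead of freezing them as constants.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(X, y):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)
        return loss

    X = mx.random.normal((32, 8))
    y = mx.random.normal((32, 1))
    for _ in range(10):
        loss = step(X, y)
        mx.eval(state)  # materialize the lazily-updated parameters each iteration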
@@ -163,14 +160,17 @@ def main(args):
                 ]
             )
         )
 
+        model.eval()
+
+        # Reconstruct a batch of training and test images
-        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
-        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
+        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
+        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")
 
         # Generate images
-        generate(vae, save_dir / f"generated_{e:03d}.png")
+        generate(model, save_dir / f"generated_{e:03d}.png")
 
-    vae.save_weights(str(save_dir / "weights.npz"))
+    model.save_weights(str(save_dir / "weights.npz"))
 
 
 if __name__ == "__main__":
@@ -1,4 +1,4 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
 Pillow
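The version floor moves from 0.0.9 to 0.2, presumably because the mx.compile state-capture API used above is only available in newer mlx releases. To pick up the bump (the path is an assumption; this hunk's filename is not shown on this page):

    pip install -r cvae/requirements.txt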