mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Update a few examples to use compile (#420)
* update a few examples to use compile * update mnist * add compile to vae and rename some stuff for simplicity * update reqs * use state in eval * GCN example with RNG + dropout * add a bit of prefetching
This commit is contained in:
@@ -1,14 +1,12 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.data.datasets import load_cifar10
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root)
|
||||
|
||||
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
||||
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
||||
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
||||
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
||||
|
||||
def normalize(x):
|
||||
x = x.astype("float32") / 255.0
|
||||
@@ -23,6 +21,7 @@ def get_cifar10(batch_size, root=None):
|
||||
.image_random_crop("image", 32, 32)
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
.prefetch(4, 4)
|
||||
)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -33,19 +34,25 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
|
||||
return loss, acc
|
||||
|
||||
train_step_fn = nn.value_and_grad(model, train_step)
|
||||
|
||||
losses = []
|
||||
accs = []
|
||||
samples_per_sec = []
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(inp, tgt):
|
||||
train_step_fn = nn.value_and_grad(model, train_step)
|
||||
(loss, acc), grads = train_step_fn(model, inp, tgt)
|
||||
optimizer.update(model, grads)
|
||||
return loss, acc
|
||||
|
||||
for batch_counter, batch in enumerate(train_iter):
|
||||
x = mx.array(batch["image"])
|
||||
y = mx.array(batch["label"])
|
||||
tic = time.perf_counter()
|
||||
(loss, acc), grads = train_step_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
loss, acc = step(x, y)
|
||||
mx.eval(state)
|
||||
toc = time.perf_counter()
|
||||
loss = loss.item()
|
||||
acc = acc.item()
|
||||
|
@@ -1,3 +1,3 @@
|
||||
mlx>=0.0.9
|
||||
mlx>=0.2
|
||||
mlx-data
|
||||
numpy
|
||||
numpy
|
||||
|
Reference in New Issue
Block a user