mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Update a few examples to use compile (#420)
* update a few examples to use compile * update mnist * add compile to vae and rename some stuff for simplicity * update reqs * use state in eval * GCN example with RNG + dropout * add a bit of prefetching
This commit is contained in:
@@ -1,14 +1,12 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.data.datasets import load_cifar10
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root)
|
||||
|
||||
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
||||
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
||||
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
||||
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
||||
|
||||
def normalize(x):
|
||||
x = x.astype("float32") / 255.0
|
||||
@@ -23,6 +21,7 @@ def get_cifar10(batch_size, root=None):
|
||||
.image_random_crop("image", 32, 32)
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
.prefetch(4, 4)
|
||||
)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
|
Reference in New Issue
Block a user