mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
updates + format
This commit is contained in:
@@ -4,13 +4,15 @@ import math
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
|
||||
tr = load_cifar10(root=root)
|
||||
num_tr_samples = tr.size()
|
||||
|
||||
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
||||
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
||||
|
||||
def normalize(x):
|
||||
x = x.astype("float32") / 255.0
|
||||
return (x - mean) / std
|
||||
|
||||
tr_iter = (
|
||||
tr.shuffle()
|
||||
.to_stream()
|
||||
@@ -18,22 +20,11 @@ def get_cifar10(batch_size, root=None):
|
||||
.pad("image", 0, 4, 4, 0.0)
|
||||
.pad("image", 1, 4, 4, 0.0)
|
||||
.image_random_crop("image", 32, 32)
|
||||
.key_transform("image", lambda x: (x.astype("float32") / 255.0))
|
||||
.key_transform("image", lambda x: (x - mean) / std)
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
num_test_samples = test.size()
|
||||
|
||||
test_iter = (
|
||||
test.to_stream()
|
||||
.key_transform("image", lambda x: (x.astype("float32") / 255.0))
|
||||
.key_transform("image", lambda x: (x - mean) / std)
|
||||
.batch(batch_size)
|
||||
)
|
||||
|
||||
num_tr_steps_per_epoch = num_tr_samples // batch_size
|
||||
num_test_steps_per_epoch = num_test_samples // batch_size
|
||||
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
|
||||
|
||||
return tr_iter, test_iter
|
||||
|
||||
Reference in New Issue
Block a user