added CIFAR10 + ResNet example

This commit is contained in:
Sarthak Yadav
2023-12-12 19:01:06 +01:00
parent 9ec14f7d47
commit ea9a4f878a
5 changed files with 310 additions and 0 deletions

39
cifar/dataset.py Normal file
View File

@@ -0,0 +1,39 @@
import mlx.core as mx
from mlx.data.datasets import load_cifar10
import math
def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root)
num_tr_samples = tr.size()
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
tr_iter = (
tr.shuffle()
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
.pad("image", 1, 4, 4, 0.0)
.image_random_crop("image", 32, 32)
.key_transform("image", lambda x: (x.astype("float32") / 255.0))
.key_transform("image", lambda x: (x - mean) / std)
.batch(batch_size)
)
test = load_cifar10(root=root, train=False)
num_test_samples = test.size()
test_iter = (
test.to_stream()
.key_transform("image", lambda x: (x.astype("float32") / 255.0))
.key_transform("image", lambda x: (x - mean) / std)
.batch(batch_size)
)
num_tr_steps_per_epoch = num_tr_samples // batch_size
num_test_steps_per_epoch = num_test_samples // batch_size
return tr_iter, test_iter, num_tr_steps_per_epoch, num_test_steps_per_epoch