2025-03-06 05:33:15 +08:00
|
|
|
import mlx.core as mx
|
2024-02-09 05:00:41 +08:00
|
|
|
import numpy as np
|
2023-12-13 02:01:06 +08:00
|
|
|
from mlx.data.datasets import load_cifar10
|
|
|
|
|
|
|
|
|
|
|
|
def get_cifar10(batch_size, root=None):
    """Build the CIFAR-10 training and test data pipelines.

    Args:
        batch_size: Number of samples per batch, passed to ``.batch()`` on
            both pipelines.
        root: Forwarded unchanged to ``load_cifar10`` as its ``root``
            argument for both the training and the test split.

    Returns:
        A ``(tr_iter, test_iter)`` tuple. The training pipeline is shuffled,
        augmented (random horizontal flip, padding on axes 0 and 1 followed
        by a random 32x32 crop), normalized, batched and prefetched. The
        test pipeline is only normalized and batched. Both are partitioned
        by rank when the distributed group has more than one member.
    """
    tr = load_cifar10(root=root)

    # Per-channel normalization constants, shaped (1, 1, 3) so they broadcast
    # over the trailing axis of the image (assumes channel-last images --
    # TODO confirm against load_cifar10). They are created as float32 on
    # purpose: with NumPy's default float64, `(x - mean) / std` would promote
    # every image back to float64 and undo the float32 cast in normalize().
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 1, 3))
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 1, 3))

    def normalize(x):
        """Scale ``x`` by 1/255, then standardize with ``mean`` and ``std``."""
        x = x.astype("float32") / 255.0
        return (x - mean) / std

    # Query the distributed group once; both pipelines use the same values
    # to decide whether, and how, to partition the data.
    group = mx.distributed.init()
    world_size = group.size()
    rank = group.rank()
    is_distributed = world_size > 1

    tr_iter = (
        tr.shuffle()
        .partition_if(is_distributed, world_size, rank)
        .to_stream()
        .image_random_h_flip("image", prob=0.5)
        # Pad axes 0 and 1 by 4 on each side with 0.0, then crop back to
        # 32x32 at a random offset.
        .pad("image", 0, 4, 4, 0.0)
        .pad("image", 1, 4, 4, 0.0)
        .image_random_crop("image", 32, 32)
        .key_transform("image", normalize)
        .batch(batch_size)
        .prefetch(4, 4)
    )

    test = load_cifar10(root=root, train=False)
    test_iter = (
        test.to_stream()
        .partition_if(is_distributed, world_size, rank)
        .key_transform("image", normalize)
        .batch(batch_size)
    )

    return tr_iter, test_iter
|