Distributed support cifar (#1301)

This commit is contained in:
Angelos Katharopoulos
2025-03-05 13:33:15 -08:00
committed by GitHub
parent f621218ff5
commit e7267d30f8
3 changed files with 84 additions and 32 deletions

View File

@@ -1,3 +1,4 @@
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
x = x.astype("float32") / 255.0
return (x - mean) / std
group = mx.distributed.init()
tr_iter = (
tr.shuffle()
.partition_if(group.size() > 1, group.size(), group.rank())
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
)
test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
test_iter = (
test.to_stream()
.partition_if(group.size() > 1, group.size(), group.rank())
.key_transform("image", normalize)
.batch(batch_size)
)
return tr_iter, test_iter