mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Distributed support cifar (#1301)
This commit is contained in:
committed by
GitHub
parent
f621218ff5
commit
e7267d30f8
@@ -1,3 +1,4 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.data.datasets import load_cifar10
|
||||
|
||||
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
|
||||
x = x.astype("float32") / 255.0
|
||||
return (x - mean) / std
|
||||
|
||||
group = mx.distributed.init()
|
||||
|
||||
tr_iter = (
|
||||
tr.shuffle()
|
||||
.partition_if(group.size() > 1, group.size(), group.rank())
|
||||
.to_stream()
|
||||
.image_random_h_flip("image", prob=0.5)
|
||||
.pad("image", 0, 4, 4, 0.0)
|
||||
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
|
||||
)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
|
||||
test_iter = (
|
||||
test.to_stream()
|
||||
.partition_if(group.size() > 1, group.size(), group.rank())
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
)
|
||||
|
||||
return tr_iter, test_iter
|
||||
|
||||
Reference in New Issue
Block a user