Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Distributed support cifar (#1301)

commit e7267d30f8 (parent f621218ff5)
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
 At the time of writing, `mlx` doesn't have built-in learning rate schedules.
 We intend to update this example once these features are added.
 
+## Distributed training
+
+The example also supports distributed data parallel training. You can launch a
+distributed training as follows:
+
+```shell
+$ cat >hostfile.json
+[
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
+]
+$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
+```
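For context on the command above: every process started by `mlx.launch` ends up calling `mx.distributed.init()` (see the `main.py` and `dataset.py` changes below) and can then query its `rank()` and `size()` and use collectives such as `mx.distributed.all_sum`. A minimal sanity-check script, assuming the same hostfile as above (the file name `check_group.py` and the printed message are illustrative, not part of this commit):

```python
# check_group.py: hypothetical sanity check, not part of this commit.
# Launch it the same way as main.py above, e.g.:
#   mlx.launch --verbose --hostfile hostfile.json check_group.py
import mlx.core as mx

group = mx.distributed.init()

# Every rank contributes 1.0; the all-reduced sum equals the group size.
total = mx.distributed.all_sum(mx.array(1.0))
print(f"rank {group.rank()} of {group.size()}: all_sum(1.0) = {total.item()}", flush=True)
```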
@@ -1,3 +1,4 @@
+import mlx.core as mx
 import numpy as np
 from mlx.data.datasets import load_cifar10
 
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
         x = x.astype("float32") / 255.0
         return (x - mean) / std
 
+    group = mx.distributed.init()
+
     tr_iter = (
         tr.shuffle()
+        .partition_if(group.size() > 1, group.size(), group.rank())
         .to_stream()
         .image_random_h_flip("image", prob=0.5)
         .pad("image", 0, 4, 4, 0.0)
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
     )
 
     test = load_cifar10(root=root, train=False)
-    test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
+    test_iter = (
+        test.to_stream()
+        .partition_if(group.size() > 1, group.size(), group.rank())
+        .key_transform("image", normalize)
+        .batch(batch_size)
+    )
 
     return tr_iter, test_iter
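The two `partition_if(group.size() > 1, group.size(), group.rank())` calls shard the train and test streams so that each rank only sees its own slice of CIFAR-10 when more than one process is running; with a single process the streams are left untouched. A rough pure-Python illustration of such a split (the strided assignment is an assumption made for illustration, not taken from `mlx.data`'s implementation):

```python
# Illustration only: how a dataset of 10 samples could be split across 2 ranks.
# Assumes a strided partition; mlx.data's partition_if may order samples differently.
samples = list(range(10))
world_size, rank = 2, 0

shard = samples[rank::world_size] if world_size > 1 else samples
print(shard)  # rank 0 -> [0, 2, 4, 6, 8], rank 1 -> [1, 3, 5, 7, 9]
```

Each rank still draws `batch_size` examples per step from its shard, so the effective global batch grows with the number of ranks.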
@@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
 parser.add_argument("--cpu", action="store_true", help="use cpu only")
 
 
+def print_zero(group, *args, **kwargs):
+    if group.rank() != 0:
+        return
+    flush = kwargs.pop("flush", True)
+    print(*args, **kwargs, flush=flush)
+
+
 def eval_fn(model, inp, tgt):
     return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
 
@@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc
 
-    losses = []
-    accs = []
-    samples_per_sec = []
+    world = mx.distributed.init()
+    losses = 0
+    accuracies = 0
+    samples_per_sec = 0
+    count = 0
+
+    def average_stats(stats, count):
+        if world.size() == 1:
+            return [s / count for s in stats]
+
+        with mx.stream(mx.cpu):
+            stats = mx.distributed.all_sum(mx.array(stats))
+            count = mx.distributed.all_sum(count)
+        return (stats / count).tolist()
 
     state = [model.state, optimizer.state]
 
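The `average_stats` helper added above converts per-rank running sums into group-wide averages with `mx.distributed.all_sum`, falling back to a plain division when only one process is running. A small standalone sketch of the same pattern (the accumulated numbers below are made up for illustration):

```python
# Sketch of the metric-averaging pattern from train_epoch; values are illustrative.
import mlx.core as mx

world = mx.distributed.init()

# Pretend this rank accumulated these sums over `count` batches.
losses, accuracies, count = 12.3, 7.5, 10

stats = mx.array([losses, accuracies])
if world.size() > 1:
    # Run the collectives on the CPU stream, mirroring the diff.
    with mx.stream(mx.cpu):
        stats = mx.distributed.all_sum(stats)
        count = mx.distributed.all_sum(count)

avg_loss, avg_acc = (stats / count).tolist()
print(f"avg loss {avg_loss:.3f}, avg acc {avg_acc:.3f}")
```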
@@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
     def step(inp, tgt):
         train_step_fn = nn.value_and_grad(model, train_step)
         (loss, acc), grads = train_step_fn(model, inp, tgt)
+        grads = nn.utils.average_gradients(grads)
         optimizer.update(model, grads)
         return loss, acc
 
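The single added line above is where data parallelism enters the optimizer step: every rank computes gradients on its own shard, and `nn.utils.average_gradients` averages them across the group before `optimizer.update`. A minimal sketch of that idea for one array (the helper itself operates on the whole gradient tree; the code below is only an illustration):

```python
# Illustration of averaging a single gradient array across ranks.
# nn.utils.average_gradients handles the full gradient tree; this is just the idea.
import mlx.core as mx

world = mx.distributed.init()

grad = mx.ones((4,))  # stand-in for one per-rank gradient array
if world.size() > 1:
    grad = mx.distributed.all_sum(grad) / world.size()
mx.eval(grad)
```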
@@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch):
         y = mx.array(batch["label"])
         tic = time.perf_counter()
         loss, acc = step(x, y)
-        mx.eval(state)
+        mx.eval(loss, acc, state)
         toc = time.perf_counter()
-        loss = loss.item()
-        acc = acc.item()
-        losses.append(loss)
-        accs.append(acc)
-        throughput = x.shape[0] / (toc - tic)
-        samples_per_sec.append(throughput)
+        losses += loss.item()
+        accuracies += acc.item()
+        samples_per_sec += x.shape[0] / (toc - tic)
+        count += 1
         if batch_counter % 10 == 0:
-            print(
+            l, a, s = average_stats(
+                [losses, accuracies, world.size() * samples_per_sec],
+                count,
+            )
+            print_zero(
+                world,
                 " | ".join(
                     (
                         f"Epoch {epoch:02d} [{batch_counter:03d}]",
-                        f"Train loss {loss:.3f}",
-                        f"Train acc {acc:.3f}",
-                        f"Throughput: {throughput:.2f} images/second",
+                        f"Train loss {l:.3f}",
+                        f"Train acc {a:.3f}",
+                        f"Throughput: {s:.2f} images/second",
                     )
-                )
+                ),
             )
 
-    mean_tr_loss = mx.mean(mx.array(losses))
-    mean_tr_acc = mx.mean(mx.array(accs))
-    samples_per_sec = mx.mean(mx.array(samples_per_sec))
-    return mean_tr_loss, mean_tr_acc, samples_per_sec
+    return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
 
 
 def test_epoch(model, test_iter, epoch):
-    accs = []
+    accuracies = 0
+    count = 0
     for batch_counter, batch in enumerate(test_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         acc = eval_fn(model, x, y)
-        acc_value = acc.item()
-        accs.append(acc_value)
-    mean_acc = mx.mean(mx.array(accs))
-    return mean_acc
+        accuracies += acc.item()
+        count += 1
+
+    with mx.stream(mx.cpu):
+        accuracies = mx.distributed.all_sum(accuracies)
+        count = mx.distributed.all_sum(count)
+    return (accuracies / count).item()
 
 
 def main(args):
     mx.random.seed(args.seed)
 
+    # Initialize the distributed group and report the nodes that showed up
+    world = mx.distributed.init()
+    if world.size() > 1:
+        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
+
     model = getattr(resnet, args.arch)()
 
-    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
+    print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
 
     optimizer = optim.Adam(learning_rate=args.lr)
 
     train_data, test_data = get_cifar10(args.batch_size)
     for epoch in range(args.epochs):
         tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
-        print(
+        print_zero(
+            world,
             " | ".join(
                 (
                     f"Epoch: {epoch}",
-                    f"avg. Train loss {tr_loss.item():.3f}",
-                    f"avg. Train acc {tr_acc.item():.3f}",
-                    f"Throughput: {throughput.item():.2f} images/sec",
+                    f"avg. Train loss {tr_loss:.3f}",
+                    f"avg. Train acc {tr_acc:.3f}",
+                    f"Throughput: {throughput:.2f} images/sec",
                 )
-            )
+            ),
         )
 
         test_acc = test_epoch(model, test_data, epoch)
-        print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
+        print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
 
         train_data.reset()
         test_data.reset()