Data parallel helper (#1407)

parent 8d68a3e805
commit 914409fef9
@@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) {
   }
 }
 
+template <typename T>
+void simple_sum(
+    void* input,
+    void* accumulator,
+    int* len,
+    MPI_Datatype* datatype) {
+  T* in = (T*)input;
+  T* acc = (T*)accumulator;
+  int N = *len;
+
+  while (N-- > 0) {
+    *acc += *in;
+    acc++;
+    in++;
+  }
+}
+template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
+
 struct MPIWrapper {
   MPIWrapper() {
+    initialized_ = false;
+
     libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
     if (libmpi_handle_ == nullptr) {
       return;
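MPI has no built-in reduction for 16-bit floats, so this hunk registers simple_sum as a custom MPI_Op: MPI calls it with a chunk of incoming values and an accumulator buffer of the same length and expects an in-place element-wise add. A minimal Python sketch of that contract (illustration only, not part of the diff):

def simple_sum(input_buf, accumulator, length):
    # Same loop as the C++ template: *acc++ += *in++ repeated `length` times.
    for i in range(length):
        accumulator[i] += input_buf[i]

acc = [1.0, 2.0, 3.0]
simple_sum([0.5, 0.5, 0.5], acc, 3)
print(acc)  # [1.5, 2.5, 3.5]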
@@ -50,6 +71,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Allgather, all_gather);
     LOAD_SYMBOL(MPI_Send, send);
     LOAD_SYMBOL(MPI_Recv, recv);
+    LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
+    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
+    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -79,7 +103,24 @@ struct MPIWrapper {
     if (!is_available()) {
       return false;
     }
-    return init(nullptr, nullptr) == MPI_SUCCESS;
+    bool success = init(nullptr, nullptr) == MPI_SUCCESS;
+
+    // Initialize custom types and ops
+    if (success && !initialized_) {
+      // Custom float16 dtypes
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
+      mpi_type_commit(&mpi_float16_);
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
+      mpi_type_commit(&mpi_bfloat16_);
+
+      // Custom sum ops
+      mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
+      mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
+
+      initialized_ = true;
+    }
+
+    return success;
   }
 
   void finalize_safe() {
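Note that mpi_type_contiguous(2, mpi_uint8_, ...) describes float16 and bfloat16 to MPI as two opaque bytes; MPI only needs the element size to move the data, while the actual arithmetic happens in the real type inside the custom sum ops registered just above. A quick check of that byte width (plain Python, illustration only):

import struct

raw = struct.pack("<e", 1.5)     # one IEEE half-precision value
print(len(raw))                  # 2 -> matches MPI_Type_contiguous(2, uint8, ...)
print(struct.unpack("<e", raw))  # (1.5,)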
@@ -117,14 +158,22 @@ struct MPIWrapper {
     case complex64:
       return mpi_complex_;
     case float16:
+      return mpi_float16_;
     case bfloat16:
-      throw std::runtime_error("MPI doesn't support 16-bit floats");
+      return mpi_bfloat16_;
     }
   }
 
-  MPI_Op op_sum() {
-    return op_sum_;
+  MPI_Op op_sum(const array& arr) {
+    switch (arr.dtype()) {
+      case float16:
+        return op_sum_f16_;
+      case bfloat16:
+        return op_sum_bf16_;
+      default:
+        return op_sum_;
+    }
   }
 
   void* libmpi_handle_;
@@ -152,6 +201,8 @@ struct MPIWrapper {
 
   // Ops
   MPI_Op op_sum_;
+  MPI_Op op_sum_f16_;
+  MPI_Op op_sum_bf16_;
 
   // Datatypes
   MPI_Datatype mpi_bool_;
@@ -165,6 +216,16 @@ struct MPIWrapper {
   MPI_Datatype mpi_uint64_;
   MPI_Datatype mpi_float_;
   MPI_Datatype mpi_complex_;
+  MPI_Datatype mpi_float16_;
+  MPI_Datatype mpi_bfloat16_;
+
+ private:
+  bool initialized_;
+
+  // Private API
+  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
+  int (*mpi_type_commit)(MPI_Datatype*);
+  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
 };
 
 MPIWrapper& mpi() {
@@ -276,7 +337,7 @@ void all_sum(Group group, const array& input_, array& output) {
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
-      mpi().op_sum(),
+      mpi().op_sum(input),
       to_comm(group));
 }
 
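At the Python level the effect of the new type and op plumbing is that 16-bit float arrays can now go through all_sum instead of raising "MPI doesn't support 16-bit floats". A minimal sketch, assuming the MPI backend is available and the script is launched with more than one process:

import mlx.core as mx

group = mx.distributed.init()
x = mx.ones(4, dtype=mx.float16) * group.rank()
y = mx.distributed.all_sum(x, stream=mx.cpu)  # reduced with the custom float16 sum op
print(y)  # sum of the ranks, still float16, identical on every process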
@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.
 
-from functools import wraps
-from typing import Callable, Optional
+from functools import reduce, wraps
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 
+from ..utils import tree_flatten, tree_map, tree_unflatten
 from .layers.base import Module
 
 
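The helper added below leans on tree_flatten / tree_map / tree_unflatten from mlx.utils to treat an arbitrary Python tree of gradients as a flat list of (key, array) pairs. A small illustration of those utilities (not part of this diff):

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

grads = {"w": mx.zeros((2, 2)), "layers": [{"b": mx.ones(3)}]}
flat = tree_flatten(grads)
print([k for k, _ in flat])                    # ['w', 'layers.0.b']
print(tree_unflatten(flat)["layers"][0]["b"])  # restores the original structure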
@@ -68,3 +69,93 @@ def checkpoint(module: Module, fn: Optional[Callable] = None):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 
     return wrapped_checkpointed_fn
+
+
+def average_gradients(
+    gradients: Any,
+    group: Optional[mx.distributed.Group] = None,
+    all_reduce_size: int = 32 * 1024**2,
+    communication_type: Optional[mx.Dtype] = None,
+):
+    """Average the gradients across the distributed processes in the passed group.
+
+    This helper enables concatenating several gradients of small arrays to one
+    big all reduce call for better networking performance.
+
+    Args:
+        gradients (Any): The Python tree containing the gradients (it should
+            have the same structure across processes)
+        group (Optional[mlx.core.distributed.Group]): The group of processes to
+            average the gradients. If set to ``None`` the global group is used.
+            Default: ``None``.
+        all_reduce_size (int): Group arrays until their size in bytes exceeds
+            this number. Perform one communication step per group of arrays. If
+            less or equal to 0 array grouping is disabled. Default: ``32MiB``.
+        communication_type (Optional[mlx.core.Dtype]): If provided cast to this
+            type before performing the communication. Typically cast to a
+            smaller float to reduce the communication size. Default: ``None``.
+    """
+    group = group or mx.distributed.init()
+    N = group.size()
+
+    if N == 1:
+        return gradients
+
+    def _average(x):
+        dt = x.dtype
+        x = x.astype(communication_type) if communication_type is not None else x
+        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+
+    if all_reduce_size <= 0:
+        return tree_map(_average, gradients)
+
+    else:
+        flat_grads = tree_flatten(gradients)
+        if len(flat_grads) == 0:
+            return gradients
+
+        # Extract some info for the gradient
+        keys = [k for k, _ in flat_grads]
+        shapes = [v.shape for _, v in flat_grads]
+        sizes = [v.size for _, v in flat_grads]
+        dtypes = [v.dtype for _, v in flat_grads]
+
+        # We can't group them if they have mixed types
+        if not all(dt == dtypes[0] for dt in dtypes):
+            return average_gradients(gradients, group, 0, communication_type)
+        itemsize = (
+            communication_type.size
+            if communication_type is not None
+            else dtypes[0].size
+        )
+
+        # Gather the gradients in groups that are just above or equal to all_reduce_size
+        grad_groups = []
+        grad_group = []
+        grad_group_size = 0
+        for i in range(len(keys)):
+            grad_group.append(i)
+            grad_group_size += sizes[i] * itemsize
+            if grad_group_size >= all_reduce_size:
+                grad_groups.append(grad_group)
+                grad_group = []
+                grad_group_size = 0
+        if grad_group:
+            grad_groups.append(grad_group)
+            grad_group = []
+
+        # Concatenate-reduce-split
+        new_flat_grads = []
+        for grad_group in grad_groups:
+            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
+            big_grad = mx.concatenate(
+                [flat_grads[i][1].reshape(-1) for i in grad_group]
+            )
+            big_grad = _average(big_grad)
+            big_grad = mx.split(big_grad, indices[1:-1])
+            new_flat_grads.extend(
+                (keys[j], big_grad[i].reshape(shapes[j]))
+                for i, j in enumerate(grad_group)
+            )
+
+        return tree_unflatten(new_flat_grads)
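A hypothetical use of the new helper in a data-parallel training step; model, loss_fn, optimizer and the batch variables are placeholders, not part of this commit:

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import average_gradients

def train_step(model, optimizer, X, y):
    loss_and_grad_fn = nn.value_and_grad(model, lambda m: loss_fn(m, X, y))
    loss, grads = loss_and_grad_fn(model)
    # Every process computed grads on its own shard of the batch;
    # average them (optionally in float16) before applying the update.
    grads = average_gradients(grads, communication_type=mx.float16)
    optimizer.update(model, grads)
    return loss

Casting to float16 halves the bytes on the wire, while _average casts the reduced result back to each gradient's original dtype before dividing by the group size.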
@@ -4,6 +4,7 @@ import unittest
 
 import mlx.core as mx
 import mlx_tests
+from mlx.nn.utils import average_gradients
 
 
 class TestDistributed(mlx_tests.MLXTestCase):
@@ -110,6 +111,59 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
 
+    def test_average_gradients(self):
+        original_all_sum = mx.distributed.all_sum
+        n_calls = 0
+        xtype = None
+
+        def new_all_sum(x, **kwargs):
+            nonlocal n_calls
+            nonlocal xtype
+
+            n_calls += 1
+            if xtype is not None:
+                self.assertEqual(xtype, x.dtype)
+
+            return original_all_sum(x, **kwargs)
+
+        mx.distributed.all_sum = new_all_sum
+
+        try:
+            grads = [mx.ones(10) for i in range(10)]
+            new_grads = average_gradients(grads)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 1)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=0)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 10)
+
+            n_calls = 0
+            xtype = mx.float16
+            new_grads = average_gradients(
+                grads, all_reduce_size=2 * 50, communication_type=mx.float16
+            )
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+        finally:
+            mx.distributed.all_sum = original_all_sum
+
 
 if __name__ == "__main__":
     unittest.main()
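The expected n_calls values follow from the grouping arithmetic: each mx.ones(10) gradient is 10 x 4 = 40 bytes in float32, so the default 32 MiB bucket reduces all ten in one call, a 200-byte bucket (4 * 50) closes after five gradients (two calls), all_reduce_size=0 disables grouping entirely (ten calls), and with communication_type=mx.float16 the itemsize drops to 2 bytes so the 100-byte bucket (2 * 50) again closes after five gradients. A back-of-the-envelope check (plain Python, illustration only):

sizes = [10] * 10                     # ten gradients of mx.ones(10)

def n_calls(limit, itemsize=4):       # 4 bytes per float32 element
    calls, acc = 0, 0
    for size in sizes:
        acc += size * itemsize
        if acc >= limit:
            calls, acc = calls + 1, 0
    return calls + (1 if acc else 0)

print(n_calls(32 * 1024**2))          # 1  (default bucket, single call)
print(n_calls(4 * 50))                # 2  (200-byte bucket closes every 5 grads)
print(n_calls(2 * 50, itemsize=2))    # 2  (float16 halves the itemsize)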