Data parallel helper (#1407)
This commit is contained in:
parent 8d68a3e805
commit 914409fef9
@@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) {
   }
 }
 
+template <typename T>
+void simple_sum(
+    void* input,
+    void* accumulator,
+    int* len,
+    MPI_Datatype* datatype) {
+  T* in = (T*)input;
+  T* acc = (T*)accumulator;
+  int N = *len;
+
+  while (N-- > 0) {
+    *acc += *in;
+    acc++;
+    in++;
+  }
+}
+template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
+
 struct MPIWrapper {
   MPIWrapper() {
     initialized_ = false;
+
     libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
     if (libmpi_handle_ == nullptr) {
       return;
@@ -50,6 +71,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Allgather, all_gather);
     LOAD_SYMBOL(MPI_Send, send);
     LOAD_SYMBOL(MPI_Recv, recv);
+    LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
+    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
+    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -79,7 +103,24 @@ struct MPIWrapper {
     if (!is_available()) {
       return false;
     }
-    return init(nullptr, nullptr) == MPI_SUCCESS;
+    bool success = init(nullptr, nullptr) == MPI_SUCCESS;
+
+    // Initialize custom types and ops
+    if (success && !initialized_) {
+      // Custom float16 dtypes
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
+      mpi_type_commit(&mpi_float16_);
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
+      mpi_type_commit(&mpi_bfloat16_);
+
+      // Custom sum ops
+      mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
+      mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
+
+      initialized_ = true;
+    }
+
+    return success;
   }
 
   void finalize_safe() {
@@ -117,13 +158,21 @@ struct MPIWrapper {
       case complex64:
         return mpi_complex_;
       case float16:
+        return mpi_float16_;
      case bfloat16:
-        throw std::runtime_error("MPI doesn't support 16-bit floats");
+        return mpi_bfloat16_;
     }
   }
 
-  MPI_Op op_sum() {
-    return op_sum_;
+  MPI_Op op_sum(const array& arr) {
+    switch (arr.dtype()) {
+      case float16:
+        return op_sum_f16_;
+      case bfloat16:
+        return op_sum_bf16_;
+      default:
+        return op_sum_;
+    }
   }
 
   void* libmpi_handle_;
@@ -152,6 +201,8 @@ struct MPIWrapper {
 
   // Ops
   MPI_Op op_sum_;
+  MPI_Op op_sum_f16_;
+  MPI_Op op_sum_bf16_;
 
   // Datatypes
   MPI_Datatype mpi_bool_;
@@ -165,6 +216,16 @@ struct MPIWrapper {
   MPI_Datatype mpi_uint64_;
   MPI_Datatype mpi_float_;
   MPI_Datatype mpi_complex_;
+  MPI_Datatype mpi_float16_;
+  MPI_Datatype mpi_bfloat16_;
+
+ private:
+  bool initialized_;
+
+  // Private API
+  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
+  int (*mpi_type_commit)(MPI_Datatype*);
+  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
 };
 
 MPIWrapper& mpi() {
@@ -276,7 +337,7 @@ void all_sum(Group group, const array& input_, array& output) {
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
-      mpi().op_sum(),
+      mpi().op_sum(input),
      to_comm(group));
 }
 
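The C++ changes above register contiguous 2-byte MPI datatypes and custom sum ops, so float16 and bfloat16 arrays now go through the MPI all-reduce path instead of raising an error. A minimal sketch of what this enables from the Python side (a sketch only, assuming the MPI backend is available and the process is launched under MPI; the file name is hypothetical):

# Minimal sketch: all-reduce a float16 array across ranks.
# Assumes the script is launched with something like `mpirun -np 2 python sum16.py`,
# where `sum16.py` is a hypothetical file name used only for this example.
import mlx.core as mx

group = mx.distributed.init()                 # initialize the distributed backend
x = mx.ones((4,), dtype=mx.float16)           # 16-bit input, handled by the new custom MPI sum op
y = mx.distributed.all_sum(x, stream=mx.cpu)  # elementwise sum over all ranks
print(group.rank(), y)                        # every rank sees the same summed values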
@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.
 
-from functools import wraps
-from typing import Callable, Optional
+from functools import reduce, wraps
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 
 from ..utils import tree_flatten, tree_map, tree_unflatten
 from .layers.base import Module
 
+
@@ -68,3 +69,93 @@ def checkpoint(module: Module, fn: Optional[Callable] = None):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 
     return wrapped_checkpointed_fn
+
+
+def average_gradients(
+    gradients: Any,
+    group: Optional[mx.distributed.Group] = None,
+    all_reduce_size: int = 32 * 1024**2,
+    communication_type: Optional[mx.Dtype] = None,
+):
+    """Average the gradients across the distributed processes in the passed group.
+
+    This helper enables concatenating several gradients of small arrays to one
+    big all reduce call for better networking performance.
+
+    Args:
+        gradients (Any): The Python tree containing the gradients (it should
+            have the same structure across processes)
+        group (Optional[mlx.core.distributed.Group]): The group of processes to
+            average the gradients. If set to ``None`` the global group is used.
+            Default: ``None``.
+        all_reduce_size (int): Group arrays until their size in bytes exceeds
+            this number. Perform one communication step per group of arrays. If
+            less or equal to 0 array grouping is disabled. Default: ``32MiB``.
+        communication_type (Optional[mlx.core.Dtype]): If provided cast to this
+            type before performing the communication. Typically cast to a
+            smaller float to reduce the communication size. Default: ``None``.
+    """
+    group = group or mx.distributed.init()
+    N = group.size()
+
+    if N == 1:
+        return gradients
+
+    def _average(x):
+        dt = x.dtype
+        x = x.astype(communication_type) if communication_type is not None else x
+        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+
+    if all_reduce_size <= 0:
+        return tree_map(_average, gradients)
+
+    else:
+        flat_grads = tree_flatten(gradients)
+        if len(flat_grads) == 0:
+            return gradients
+
+        # Extract some info for the gradient
+        keys = [k for k, _ in flat_grads]
+        shapes = [v.shape for _, v in flat_grads]
+        sizes = [v.size for _, v in flat_grads]
+        dtypes = [v.dtype for _, v in flat_grads]
+
+        # We can't group them if they have mixed types
+        if not all(dt == dtypes[0] for dt in dtypes):
+            return average_gradients(gradients, group, 0, communication_type)
+        itemsize = (
+            communication_type.size
+            if communication_type is not None
+            else dtypes[0].size
+        )
+
+        # Gather the gradients in groups that are just above or equal to all_reduce_size
+        grad_groups = []
+        grad_group = []
+        grad_group_size = 0
+        for i in range(len(keys)):
+            grad_group.append(i)
+            grad_group_size += sizes[i] * itemsize
+            if grad_group_size >= all_reduce_size:
+                grad_groups.append(grad_group)
+                grad_group = []
+                grad_group_size = 0
+        if grad_group:
+            grad_groups.append(grad_group)
+            grad_group = []
+
+        # Concatenate-reduce-split
+        new_flat_grads = []
+        for grad_group in grad_groups:
+            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
+            big_grad = mx.concatenate(
+                [flat_grads[i][1].reshape(-1) for i in grad_group]
+            )
+            big_grad = _average(big_grad)
+            big_grad = mx.split(big_grad, indices[1:-1])
+            new_flat_grads.extend(
+                (keys[j], big_grad[i].reshape(shapes[j]))
+                for i, j in enumerate(grad_group)
+            )
+
+        return tree_unflatten(new_flat_grads)
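For orientation, here is a sketch of how the new helper slots into a data-parallel training step. The model, optimizer, and loss function names below are placeholders and not part of this commit; only average_gradients and its defaults come from the code above.

# Hypothetical training step; `model`, `optimizer`, and `loss_fn` stand in for
# whatever the surrounding training script defines.
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import average_gradients

def train_step(model, optimizer, loss_fn, X, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, X, y)
    # Average the local gradients across all processes before the update; small
    # arrays are concatenated into ~32 MiB all-reduce calls by default.
    grads = average_gradients(grads)
    optimizer.update(model, grads)
    return loss

Passing communication_type=mx.float16 would roughly halve the bytes on the wire while the returned gradients keep their original dtype, since _average casts back after the all_sum.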
@@ -4,6 +4,7 @@ import unittest
 
 import mlx.core as mx
 import mlx_tests
+from mlx.nn.utils import average_gradients
 
 
 class TestDistributed(mlx_tests.MLXTestCase):
@@ -110,6 +111,59 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
 
+    def test_average_gradients(self):
+        original_all_sum = mx.distributed.all_sum
+        n_calls = 0
+        xtype = None
+
+        def new_all_sum(x, **kwargs):
+            nonlocal n_calls
+            nonlocal xtype
+
+            n_calls += 1
+            if xtype is not None:
+                self.assertEqual(xtype, x.dtype)
+
+            return original_all_sum(x, **kwargs)
+
+        mx.distributed.all_sum = new_all_sum
+
+        try:
+            grads = [mx.ones(10) for i in range(10)]
+            new_grads = average_gradients(grads)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 1)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=0)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 10)
+
+            n_calls = 0
+            xtype = mx.float16
+            new_grads = average_gradients(
+                grads, all_reduce_size=2 * 50, communication_type=mx.float16
+            )
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+        finally:
+            mx.distributed.all_sum = original_all_sum
+
 
 if __name__ == "__main__":
     unittest.main()