Data parallel helper (#1407)

Angelos Katharopoulos 2024-09-16 18:17:21 -07:00 committed by GitHub
parent 8d68a3e805
commit 914409fef9
3 changed files with 213 additions and 7 deletions


@@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) {
}
}
template <typename T>
void simple_sum(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;
while (N-- > 0) {
*acc += *in;
acc++;
in++;
}
}
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
struct MPIWrapper {
MPIWrapper() {
initialized_ = false;
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) {
return;
@@ -50,6 +71,9 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
// Objects
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -79,7 +103,24 @@ struct MPIWrapper {
if (!is_available()) {
return false;
}
return init(nullptr, nullptr) == MPI_SUCCESS;
bool success = init(nullptr, nullptr) == MPI_SUCCESS;
// Initialize custom types and ops
if (success && !initialized_) {
// Custom float16 dtypes
mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
mpi_type_commit(&mpi_float16_);
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
mpi_type_commit(&mpi_bfloat16_);
// Custom sum ops
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
initialized_ = true;
}
return success;
}
void finalize_safe() {
@@ -117,14 +158,22 @@ struct MPIWrapper {
case complex64:
return mpi_complex_;
case float16:
return mpi_float16_;
case bfloat16:
throw std::runtime_error("MPI doesn't support 16-bit floats");
return mpi_bfloat16_;
}
}
MPI_Op op_sum() {
MPI_Op op_sum(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_sum_f16_;
case bfloat16:
return op_sum_bf16_;
default:
return op_sum_;
}
}
void* libmpi_handle_;
@@ -152,6 +201,8 @@ struct MPIWrapper {
// Ops
MPI_Op op_sum_;
MPI_Op op_sum_f16_;
MPI_Op op_sum_bf16_;
// Datatypes
MPI_Datatype mpi_bool_;
@@ -165,6 +216,16 @@ struct MPIWrapper {
MPI_Datatype mpi_uint64_;
MPI_Datatype mpi_float_;
MPI_Datatype mpi_complex_;
MPI_Datatype mpi_float16_;
MPI_Datatype mpi_bfloat16_;
private:
bool initialized_;
// Private API
int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
int (*mpi_type_commit)(MPI_Datatype*);
int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
};
MPIWrapper& mpi() {
@@ -276,7 +337,7 @@ void all_sum(Group group, const array& input_, array& output) {
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_sum(),
mpi().op_sum(input),
to_comm(group));
}
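With the custom MPI datatypes and sum ops registered above, all_sum no longer throws for 16-bit floats. A minimal sketch of the user-visible effect (an assumption for illustration, not part of the diff; it presumes the MPI backend is available and the script is launched with something like mpirun -np 2):

import mlx.core as mx

group = mx.distributed.init()  # MPI backend
x = mx.ones((4,), dtype=mx.float16)  # previously raised "MPI doesn't support 16-bit floats"
y = mx.distributed.all_sum(x, stream=mx.cpu)  # now reduced with the custom f16 sum op
print(group.rank(), y)  # each rank sees an array filled with group.size()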


@@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc.
from functools import wraps
from typing import Callable, Optional
from functools import reduce, wraps
from typing import Any, Callable, Optional
import mlx.core as mx
from ..utils import tree_flatten, tree_map, tree_unflatten
from .layers.base import Module
@@ -68,3 +69,93 @@ def checkpoint(module: Module, fn: Optional[Callable] = None):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
return wrapped_checkpointed_fn
def average_gradients(
gradients: Any,
group: Optional[mx.distributed.Group] = None,
all_reduce_size: int = 32 * 1024**2,
communication_type: Optional[mx.Dtype] = None,
):
"""Average the gradients across the distributed processes in the passed group.
This helper enables concatenating several gradients of small arrays into one
big all-reduce call for better networking performance.
Args:
gradients (Any): The Python tree containing the gradients (it should
have the same structure across processes)
group (Optional[mlx.core.distributed.Group]): The group of processes over
which to average the gradients. If set to ``None`` the global group is
used. Default: ``None``.
all_reduce_size (int): Group arrays until their size in bytes exceeds
this number. Perform one communication step per group of arrays. If
less than or equal to 0, array grouping is disabled. Default: ``32MiB``.
communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
type before performing the communication. Typically cast to a
smaller float to reduce the communication size. Default: ``None``.
"""
group = group or mx.distributed.init()
N = group.size()
if N == 1:
return gradients
def _average(x):
dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x
return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
if all_reduce_size <= 0:
return tree_map(_average, gradients)
else:
flat_grads = tree_flatten(gradients)
if len(flat_grads) == 0:
return gradients
# Extract some info for the gradients
keys = [k for k, _ in flat_grads]
shapes = [v.shape for _, v in flat_grads]
sizes = [v.size for _, v in flat_grads]
dtypes = [v.dtype for _, v in flat_grads]
# We can't group them if they have mixed types
if not all(dt == dtypes[0] for dt in dtypes):
return average_gradients(gradients, group, 0, communication_type)
itemsize = (
communication_type.size
if communication_type is not None
else dtypes[0].size
)
# Gather the gradients into groups whose total size is at least all_reduce_size
grad_groups = []
grad_group = []
grad_group_size = 0
for i in range(len(keys)):
grad_group.append(i)
grad_group_size += sizes[i] * itemsize
if grad_group_size >= all_reduce_size:
grad_groups.append(grad_group)
grad_group = []
grad_group_size = 0
if grad_group:
grad_groups.append(grad_group)
grad_group = []
# Concatenate-reduce-split
new_flat_grads = []
for grad_group in grad_groups:
indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
big_grad = mx.concatenate(
[flat_grads[i][1].reshape(-1) for i in grad_group]
)
big_grad = _average(big_grad)
big_grad = mx.split(big_grad, indices[1:-1])
new_flat_grads.extend(
(keys[j], big_grad[i].reshape(shapes[j]))
for i, j in enumerate(grad_group)
)
return tree_unflatten(new_flat_grads)
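For context, a hypothetical training-step sketch showing where average_gradients slots in (the model, optimizer, and loss function below are placeholders for illustration, not part of this commit):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.nn.utils import average_gradients

mx.distributed.init()

model = nn.Linear(8, 1)
optimizer = optim.SGD(learning_rate=0.1)

def loss_fn(model, x, y):
    return mx.mean(nn.losses.mse_loss(model(x), y))

loss_and_grad = nn.value_and_grad(model, loss_fn)

def step(x, y):
    loss, grads = loss_and_grad(model, x, y)
    # One (or a few) grouped all_sum calls instead of one per parameter;
    # optionally communicate in float16 to halve the traffic.
    grads = average_gradients(grads, communication_type=mx.float16)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state, loss)
    return loss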


@@ -4,6 +4,7 @@ import unittest
import mlx.core as mx
import mlx_tests
from mlx.nn.utils import average_gradients
class TestDistributed(mlx_tests.MLXTestCase):
@@ -110,6 +111,59 @@ class TestDistributed(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0
xtype = None
def new_all_sum(x, **kwargs):
nonlocal n_calls
nonlocal xtype
n_calls += 1
if xtype is not None:
self.assertEqual(xtype, x.dtype)
return original_all_sum(x, **kwargs)
mx.distributed.all_sum = new_all_sum
try:
grads = [mx.ones(10) for i in range(10)]
new_grads = average_gradients(grads)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 1)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=4 * 50)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=0)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 10)
n_calls = 0
xtype = mx.float16
new_grads = average_gradients(
grads, all_reduce_size=2 * 50, communication_type=mx.float16
)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
finally:
mx.distributed.all_sum = original_all_sum
if __name__ == "__main__":
unittest.main()
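For reference, the asserted call counts follow directly from the grouping rule: each gradient in the test is 10 float32 values, i.e. 40 bytes. A small standalone sketch of the arithmetic (the helper below is illustrative only, not part of the commit):

def n_all_sum_calls(n_grads, bytes_per_grad, all_reduce_size):
    # Mirror the grouping loop in average_gradients: flush a group once
    # its accumulated byte size reaches the threshold.
    calls = group_bytes = 0
    for _ in range(n_grads):
        group_bytes += bytes_per_grad
        if group_bytes >= all_reduce_size:
            calls, group_bytes = calls + 1, 0
    return calls + (1 if group_bytes else 0)  # trailing partial group

print(n_all_sum_calls(10, 40, 32 * 1024**2))  # 1: default threshold, everything fits in one group
print(n_all_sum_calls(10, 40, 4 * 50))        # 2: matches the all_reduce_size=4 * 50 case
print(n_all_sum_calls(10, 20, 2 * 50))        # 2: float16 halves bytes_per_grad to 20

With all_reduce_size=0, grouping is skipped entirely and one all_sum is issued per gradient, hence the n_calls == 10 assertion.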