Data parallel helper (#1407)

Angelos Katharopoulos
2024-09-16 18:17:21 -07:00
committed by GitHub
parent 8d68a3e805
commit 914409fef9
3 changed files with 213 additions and 7 deletions


@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.

-from functools import wraps
-from typing import Callable, Optional
+from functools import reduce, wraps
+from typing import Any, Callable, Optional

 import mlx.core as mx

+from ..utils import tree_flatten, tree_map, tree_unflatten
 from .layers.base import Module
@@ -68,3 +69,93 @@ def checkpoint(module: Module, fn: Optional[Callable] = None):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

     return wrapped_checkpointed_fn
+
+
+def average_gradients(
+    gradients: Any,
+    group: Optional[mx.distributed.Group] = None,
+    all_reduce_size: int = 32 * 1024**2,
+    communication_type: Optional[mx.Dtype] = None,
+):
+    """Average the gradients across the distributed processes in the passed group.
+
+    This helper enables concatenating several gradients of small arrays to one
+    big all reduce call for better networking performance.
+
+    Args:
+        gradients (Any): The Python tree containing the gradients (it should
+            have the same structure across processes)
+        group (Optional[mlx.core.distributed.Group]): The group of processes to
+            average the gradients. If set to ``None`` the global group is used.
+            Default: ``None``.
+        all_reduce_size (int): Group arrays until their size in bytes exceeds
+            this number. Perform one communication step per group of arrays. If
+            less or equal to 0 array grouping is disabled. Default: ``32MiB``.
+        communication_type (Optional[mlx.core.Dtype]): If provided cast to this
+            type before performing the communication. Typically cast to a
+            smaller float to reduce the communication size. Default: ``None``.
+    """
+    group = group or mx.distributed.init()
+    N = group.size()
+
+    if N == 1:
+        return gradients
+
+    def _average(x):
+        dt = x.dtype
+        x = x.astype(communication_type) if communication_type is not None else x
+        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+
+    if all_reduce_size <= 0:
+        return tree_map(_average, gradients)
+    else:
+        flat_grads = tree_flatten(gradients)
+        if len(flat_grads) == 0:
+            return gradients
+
+        # Extract some info for the gradient
+        keys = [k for k, _ in flat_grads]
+        shapes = [v.shape for _, v in flat_grads]
+        sizes = [v.size for _, v in flat_grads]
+        dtypes = [v.dtype for _, v in flat_grads]
+
+        # We can't group them if they have mixed types
+        if not all(dt == dtypes[0] for dt in dtypes):
+            return average_gradients(gradients, group, 0, communication_type)
+        itemsize = (
+            communication_type.size
+            if communication_type is not None
+            else dtypes[0].size
+        )
+
+        # Gather the gradients in groups that are just above or equal to all_reduce_size
+        grad_groups = []
+        grad_group = []
+        grad_group_size = 0
+        for i in range(len(keys)):
+            grad_group.append(i)
+            grad_group_size += sizes[i] * itemsize
+            if grad_group_size >= all_reduce_size:
+                grad_groups.append(grad_group)
+                grad_group = []
+                grad_group_size = 0
+        if grad_group:
+            grad_groups.append(grad_group)
+            grad_group = []
+
+        # Concatenate-reduce-split
+        new_flat_grads = []
+        for grad_group in grad_groups:
+            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
+            big_grad = mx.concatenate(
+                [flat_grads[i][1].reshape(-1) for i in grad_group]
+            )
+            big_grad = _average(big_grad)
+            big_grad = mx.split(big_grad, indices[1:-1])
+            new_flat_grads.extend(
+                (keys[j], big_grad[i].reshape(shapes[j]))
+                for i, j in enumerate(grad_group)
+            )
+
+        return tree_unflatten(new_flat_grads)
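
For context, a minimal usage sketch follows (not part of this commit). It shows how average_gradients would typically be called in a distributed training step, averaging the gradients across processes right before the optimizer update. The model, loss function, optimizer, and data below are illustrative placeholders, and the sketch assumes the helper is exported as mlx.nn.average_gradients; otherwise it can be imported directly from the module changed in this diff.

# Minimal usage sketch (not part of this commit). Model, loss, optimizer,
# and data are illustrative placeholders.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(128, 10)
optimizer = optim.SGD(learning_rate=1e-2)


def loss_fn(model, x, y):
    # Mean squared error over the batch
    return mx.mean(nn.losses.mse_loss(model(x), y))


loss_and_grad_fn = nn.value_and_grad(model, loss_fn)


def step(x, y):
    loss, grads = loss_and_grad_fn(model, x, y)
    # Average the per-process gradients before the update; with the default
    # all_reduce_size, small arrays are concatenated into ~32 MiB buckets so
    # fewer all_sum calls hit the network.
    grads = nn.average_gradients(grads)
    optimizer.update(model, grads)
    return loss


x = mx.random.normal((16, 128))
y = mx.random.normal((16, 10))
loss = step(x, y)
mx.eval(loss, model.parameters(), optimizer.state)

Because the helper returns the gradients unchanged when the group size is 1, the same script runs correctly as a single process or when launched across several processes (for example via mpirun, assuming a working MLX distributed backend).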