Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-18 01:50:16 +08:00)
Data parallel helper (#1407)
Committed by GitHub
Parent: 8d68a3e805
Commit: 914409fef9
@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.
 
-from functools import wraps
-from typing import Callable, Optional
+from functools import reduce, wraps
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 
+from ..utils import tree_flatten, tree_map, tree_unflatten
 from .layers.base import Module
 
 
@@ -68,3 +69,93 @@ def checkpoint(module: Module, fn: Optional[Callable] = None):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
 
     return wrapped_checkpointed_fn
+
+
+def average_gradients(
+    gradients: Any,
+    group: Optional[mx.distributed.Group] = None,
+    all_reduce_size: int = 32 * 1024**2,
+    communication_type: Optional[mx.Dtype] = None,
+):
+    """Average the gradients across the distributed processes in the passed group.
+
+    This helper enables concatenating several gradients of small arrays into one
+    big all-reduce call for better networking performance.
+
+    Args:
+        gradients (Any): The Python tree containing the gradients (it should
+            have the same structure across processes).
+        group (Optional[mlx.core.distributed.Group]): The group of processes to
+            average the gradients. If set to ``None`` the global group is used.
+            Default: ``None``.
+        all_reduce_size (int): Group arrays until their size in bytes exceeds
+            this number. Perform one communication step per group of arrays. If
+            less than or equal to 0, array grouping is disabled. Default: ``32MiB``.
+        communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
+            type before performing the communication. Typically cast to a
+            smaller float to reduce the communication size. Default: ``None``.
+    """
+    group = group or mx.distributed.init()
+    N = group.size()
+
+    if N == 1:
+        return gradients
+
+    def _average(x):
+        dt = x.dtype
+        x = x.astype(communication_type) if communication_type is not None else x
+        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+
+    if all_reduce_size <= 0:
+        return tree_map(_average, gradients)
+
+    else:
+        flat_grads = tree_flatten(gradients)
+        if len(flat_grads) == 0:
+            return gradients
+
+        # Extract some info for the gradients
+        keys = [k for k, _ in flat_grads]
+        shapes = [v.shape for _, v in flat_grads]
+        sizes = [v.size for _, v in flat_grads]
+        dtypes = [v.dtype for _, v in flat_grads]
+
+        # We can't group them if they have mixed types
+        if not all(dt == dtypes[0] for dt in dtypes):
+            return average_gradients(gradients, group, 0, communication_type)
+        itemsize = (
+            communication_type.size
+            if communication_type is not None
+            else dtypes[0].size
+        )
+
+        # Gather the gradients in groups that are just above or equal to all_reduce_size
+        grad_groups = []
+        grad_group = []
+        grad_group_size = 0
+        for i in range(len(keys)):
+            grad_group.append(i)
+            grad_group_size += sizes[i] * itemsize
+            if grad_group_size >= all_reduce_size:
+                grad_groups.append(grad_group)
+                grad_group = []
+                grad_group_size = 0
+        if grad_group:
+            grad_groups.append(grad_group)
+            grad_group = []
+
+        # Concatenate-reduce-split
+        new_flat_grads = []
+        for grad_group in grad_groups:
+            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
+            big_grad = mx.concatenate(
+                [flat_grads[i][1].reshape(-1) for i in grad_group]
+            )
+            big_grad = _average(big_grad)
+            big_grad = mx.split(big_grad, indices[1:-1])
+            new_flat_grads.extend(
+                (keys[j], big_grad[i].reshape(shapes[j]))
+                for i, j in enumerate(grad_group)
+            )
+
+        return tree_unflatten(new_flat_grads)
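For context, below is a minimal sketch of how this helper fits into a data-parallel training step: each process computes gradients on its own shard of the batch, averages them across the group, and only then applies the optimizer update. The sketch is not part of the commit; the model, optimizer, loss function, and dummy data are illustrative placeholders, the diffed file is assumed to be mlx/nn/utils.py (consistent with the checkpoint helper and the relative imports shown), and average_gradients is imported from that module. When run as a single process, the call is a no-op and returns the gradients unchanged.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.nn.utils import average_gradients  # assumed module path for this diff

# Hypothetical model and optimizer, used only to place average_gradients()
# in a typical training step.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
optimizer = optim.SGD(learning_rate=1e-2)

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

def train_step(x, y):
    # Gradients for this process' shard of the batch.
    loss, grads = loss_and_grad_fn(model, x, y)
    # Average across all processes: small arrays are concatenated into
    # ~32 MiB chunks (one all_sum per chunk) and, in this example, cast to
    # bfloat16 for the communication only; the result keeps each gradient's
    # original dtype.
    grads = average_gradients(grads, communication_type=mx.bfloat16)
    optimizer.update(model, grads)
    return loss

# Dummy shard of data for illustration.
x = mx.random.normal((32, 512))
y = mx.random.randint(0, 10, (32,))
loss = train_step(x, y)
mx.eval(loss, model.parameters())

Leaving all_reduce_size at its default groups many small gradient arrays into roughly 32 MiB all-reduce calls; passing all_reduce_size=0 falls back to one all-reduce per array, and omitting communication_type keeps the communication in each gradient's own dtype.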