# Copyright © 2023-2024 Apple Inc.

from functools import reduce, wraps
from typing import Any, Callable, Optional

import mlx.core as mx

from ..utils import tree_flatten, tree_map, tree_unflatten
from .layers.base import Module


def value_and_grad(model: Module, fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """

    def inner_fn(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    @wraps(fn)
    def wrapped_value_grad_fn(*args, **kwargs):
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn
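
# Example usage (a minimal sketch kept in a comment; ``model``, ``optimizer``,
# ``x`` and ``y`` are placeholder names that are not defined in this module):
#
#     def loss_fn(x, y):
#         return mx.mean((model(x) - y) ** 2)
#
#     loss_and_grad = value_and_grad(model, loss_fn)
#     loss, grads = loss_and_grad(x, y)
#     optimizer.update(model, grads)
#     mx.eval(model.parameters(), optimizer.state)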


def checkpoint(module: Module, fn: Optional[Callable] = None):
    """Transform the passed callable to one that performs gradient
    checkpointing with respect to the trainable parameters of the module (and
    the callable's inputs).

    Args:
        module (mlx.nn.Module): The module for whose parameters we will be
            performing gradient checkpointing.
        fn (Callable, optional): The function to checkpoint. If not provided it
            defaults to the provided module.

    Returns:
        A callable that saves the inputs and outputs during the forward pass
        and recomputes all intermediate states during the backward pass.
    """
    if fn is None:
        # Capturing module instead of module.__call__ allows someone to
        # monkey-patch __call__ later on and the correct method will be used
        fn = module

    def inner_fn(params, *args, **kwargs):
        module.update(params)
        return fn(*args, **kwargs)

    checkpointed_fn = mx.checkpoint(inner_fn)

    @wraps(fn)
    def wrapped_checkpointed_fn(*args, **kwargs):
        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

    return wrapped_checkpointed_fn
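
# Example usage (a minimal sketch kept in a comment; ``layer`` is a placeholder
# for any ``mlx.nn.Module`` instance and ``x`` for an input array):
#
#     ckpt_layer = checkpoint(layer)
#     y = ckpt_layer(x)
#
# The checkpointed call keeps only the inputs and outputs of ``layer`` and
# recomputes its intermediate activations when gradients flow through it.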


def average_gradients(
    gradients: Any,
    group: Optional[mx.distributed.Group] = None,
    all_reduce_size: int = 32 * 1024**2,
    communication_type: Optional[mx.Dtype] = None,
):
    """Average the gradients across the distributed processes in the passed group.

    This helper enables concatenating several gradients of small arrays to one
    big all reduce call for better networking performance.

    Args:
        gradients (Any): The Python tree containing the gradients (it should
            have the same structure across processes)
        group (Optional[mlx.core.distributed.Group]): The group of processes to
            average the gradients. If set to ``None`` the global group is used.
            Default: ``None``.
        all_reduce_size (int): Group arrays until their size in bytes exceeds
            this number. Perform one communication step per group of arrays. If
            less than or equal to 0, array grouping is disabled.
            Default: ``32MiB``.
        communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
            type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
    """
    group = group or mx.distributed.init()
    N = group.size()

    if N == 1:
        return gradients

    def _average(x):
        dt = x.dtype
        x = x.astype(communication_type) if communication_type is not None else x
        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N

    if all_reduce_size <= 0:
        return tree_map(_average, gradients)

    else:
        flat_grads = tree_flatten(gradients)
        if len(flat_grads) == 0:
            return gradients

        # Extract some info for the gradients
        keys = [k for k, _ in flat_grads]
        shapes = [v.shape for _, v in flat_grads]
        sizes = [v.size for _, v in flat_grads]
        dtypes = [v.dtype for _, v in flat_grads]

        # We can't group them if they have mixed types
        if not all(dt == dtypes[0] for dt in dtypes):
            return average_gradients(gradients, group, 0, communication_type)
        itemsize = (
            communication_type.size
            if communication_type is not None
            else dtypes[0].size
        )

        # Gather the gradients in groups that are just above or equal to all_reduce_size
        grad_groups = []
        grad_group = []
        grad_group_size = 0
        for i in range(len(keys)):
            grad_group.append(i)
            grad_group_size += sizes[i] * itemsize
            if grad_group_size >= all_reduce_size:
                grad_groups.append(grad_group)
                grad_group = []
                grad_group_size = 0
        if grad_group:
            grad_groups.append(grad_group)
            grad_group = []

        # Concatenate-reduce-split
        new_flat_grads = []
        for grad_group in grad_groups:
            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
            big_grad = mx.concatenate(
                [flat_grads[i][1].reshape(-1) for i in grad_group]
            )
            big_grad = _average(big_grad)
            big_grad = mx.split(big_grad, indices[1:-1])
            new_flat_grads.extend(
                (keys[j], big_grad[i].reshape(shapes[j]))
                for i, j in enumerate(grad_group)
            )

        return tree_unflatten(new_flat_grads)
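
# Example usage in a data-parallel training step (a minimal sketch kept in a
# comment; assumes the script runs under a distributed setup, e.g. launched
# with ``mlx.launch``, and that ``loss_and_grad``, ``model``, ``optimizer``,
# ``x`` and ``y`` are placeholder names defined elsewhere):
#
#     loss, grads = loss_and_grad(x, y)
#     grads = average_gradients(grads)  # every process now holds the same tree
#     optimizer.update(model, grads)
#
# Passing ``communication_type=mx.bfloat16`` casts gradients to a smaller
# float before the all reduce to cut the communication volume.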