# Copyright © 2023-2024 Apple Inc.

from functools import reduce, wraps
from typing import Any, Callable, Optional

import mlx.core as mx

from ..utils import tree_flatten, tree_map, tree_unflatten
from .layers.base import Module


def value_and_grad(model: Module, fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """

    def inner_fn(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    @wraps(fn)
    def wrapped_value_grad_fn(*args, **kwargs):
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn
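

# A minimal usage sketch (not executed here). The names ``MLP``, ``x`` and
# ``y`` are hypothetical stand-ins, not part of this module:
#
#     model = MLP()
#
#     def loss_fn(x, y):
#         return mx.mean((model(x) - y) ** 2)
#
#     loss_and_grad_fn = value_and_grad(model, loss_fn)
#     loss, grads = loss_and_grad_fn(x, y)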


def checkpoint(module: Module, fn: Optional[Callable] = None):
    """Transform the passed callable to one that performs gradient
    checkpointing with respect to the trainable parameters of the module (and
    the callable's inputs).

    Args:
        module (mlx.nn.Module): The module for whose parameters we will be
            performing gradient checkpointing.
        fn (Callable, optional): The function to checkpoint. If not provided,
            it defaults to the module itself.

    Returns:
        A callable that saves the inputs and outputs during the forward pass
        and recomputes all intermediate states during the backward pass.
    """
    if fn is None:
        # Capturing module instead of module.__call__ allows someone to
        # monkey-patch __call__ later on and the correct method will be used
        fn = module

    def inner_fn(params, *args, **kwargs):
        module.update(params)
        return fn(*args, **kwargs)

    checkpointed_fn = mx.checkpoint(inner_fn)

    @wraps(fn)
    def wrapped_checkpointed_fn(*args, **kwargs):
        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

    return wrapped_checkpointed_fn
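

# A minimal usage sketch (not executed here). It assumes ``import mlx.nn as
# nn`` and an input array ``x`` on the caller's side; activations inside the
# layer are recomputed during the backward pass instead of being stored:
#
#     layer = nn.Linear(512, 512)
#     checkpointed_layer = checkpoint(layer)
#     y = checkpointed_layer(x)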


def average_gradients(
    gradients: Any,
    group: Optional[mx.distributed.Group] = None,
    all_reduce_size: int = 32 * 1024**2,
    communication_type: Optional[mx.Dtype] = None,
):
    """Average the gradients across the distributed processes in the passed group.

    This helper enables concatenating several gradients of small arrays into
    one big all reduce call for better networking performance.

    Args:
        gradients (Any): The Python tree containing the gradients (it should
            have the same structure across processes)
        group (Optional[mlx.core.distributed.Group]): The group of processes to
            average the gradients over. If set to ``None`` the global group is
            used. Default: ``None``.
        all_reduce_size (int): Group arrays until their size in bytes exceeds
            this number. Perform one communication step per group of arrays. If
            less than or equal to 0, array grouping is disabled.
            Default: ``32MiB``.
        communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
            type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
    """
    group = group or mx.distributed.init()
    N = group.size()

    if N == 1:
        return gradients

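    # All-reduce a single array on the CPU stream, optionally casting to
    # ``communication_type`` for the communication and back to the original
    # dtype before dividing by the group size.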
    def _average(x):
        dt = x.dtype
        x = x.astype(communication_type) if communication_type is not None else x
        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N

    if all_reduce_size <= 0:
        return tree_map(_average, gradients)
    else:
        flat_grads = tree_flatten(gradients)
        if len(flat_grads) == 0:
            return gradients

        # Extract some info for the gradients
        keys = [k for k, _ in flat_grads]
        shapes = [v.shape for _, v in flat_grads]
        sizes = [v.size for _, v in flat_grads]
        dtypes = [v.dtype for _, v in flat_grads]

        # We can't group them if they have mixed types
        if not all(dt == dtypes[0] for dt in dtypes):
            return average_gradients(gradients, group, 0, communication_type)
        itemsize = (
            communication_type.size
            if communication_type is not None
            else dtypes[0].size
        )
        # Gather the gradients in groups that are just above or equal to all_reduce_size
        grad_groups = []
        grad_group = []
        grad_group_size = 0
        for i in range(len(keys)):
            grad_group.append(i)
            grad_group_size += sizes[i] * itemsize
            if grad_group_size >= all_reduce_size:
                grad_groups.append(grad_group)
                grad_group = []
                grad_group_size = 0
        if grad_group:
            grad_groups.append(grad_group)
            grad_group = []
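        # Worked example (assuming float32 gradients and the default 32 MiB
        # threshold): arrays of 12 MiB, 12 MiB, 12 MiB and 4 MiB produce the
        # groups [0, 1, 2] (36 MiB >= 32 MiB triggers a flush) and [3] (the
        # leftover flushed after the loop), i.e. two all_sum calls instead of
        # four.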

        # Concatenate-reduce-split
        new_flat_grads = []
        for grad_group in grad_groups:
            # Cumulative element offsets within the concatenated buffer; the
            # interior offsets become the split points after the reduction.
            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
            big_grad = mx.concatenate(
                [flat_grads[i][1].reshape(-1) for i in grad_group]
            )
            big_grad = _average(big_grad)
            big_grad = mx.split(big_grad, indices[1:-1])
            new_flat_grads.extend(
                (keys[j], big_grad[i].reshape(shapes[j]))
                for i, j in enumerate(grad_group)
            )

        return tree_unflatten(new_flat_grads)
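

# A minimal data-parallel training-step sketch (not executed here). The names
# ``model``, ``loss_fn``, ``optimizer``, ``x`` and ``y`` are hypothetical
# stand-ins, and the script is assumed to be launched with multiple processes
# (e.g. via MPI):
#
#     loss_and_grad_fn = value_and_grad(model, loss_fn)
#     loss, grads = loss_and_grad_fn(x, y)
#     grads = average_gradients(grads)  # every process now holds the same averaged tree
#     optimizer.update(model, grads)
#     mx.eval(model.parameters(), optimizer.state)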