2024-01-31 08:04:45 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2024-01-31 08:04:45 +08:00
|
|
|
from functools import wraps
|
2023-11-30 02:30:41 +08:00
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
|
2024-01-31 08:04:45 +08:00
|
|
|
from .layers.base import Module
|
2023-11-30 02:30:41 +08:00
|
|
|
|
2024-01-31 08:04:45 +08:00
|
|
|
|
|
|
|
def value_and_grad(model: Module, fn: Callable):
    """Build a function that evaluates ``fn`` and differentiates it with
    respect to the trainable parameters of ``model``.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters the
            gradients are computed for.
        fn (Callable): The scalar function to evaluate and differentiate.

    Returns:
        A callable that accepts the same arguments as ``fn`` and returns a
        ``(value, grad)`` pair: the value of ``fn`` and the gradients wrt
        the trainable parameters of ``model``.
    """

    def inner_fn(params, *args, **kwargs):
        # ``fn`` does not take the parameters as an argument, so they are
        # loaded into the model before ``fn`` is called.
        model.update(params)
        return fn(*args, **kwargs)

    # The parameters are the leading argument of ``inner_fn``; the
    # remaining arguments are forwarded to ``fn`` untouched.
    value_grad_fn = mx.value_and_grad(inner_fn)

    @wraps(fn)
    def wrapped_value_grad_fn(*args, **kwargs):
        # Fetch the trainable parameters at call time rather than at
        # transformation time.
        trainable = model.trainable_parameters()
        value, grad = value_grad_fn(trainable, *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn
|
2024-01-31 08:04:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
def checkpoint(module: Module, fn: Callable = None):
    """Wrap a callable so that it performs gradient checkpointing with
    respect to the trainable parameters of ``module`` (and the callable's
    inputs).

    Args:
        module (mlx.nn.Module): The module for whose parameters gradient
            checkpointing is performed.
        fn (Callable, optional): The function to checkpoint. When omitted,
            ``module`` itself is used as the callable.

    Returns:
        A callable that saves the inputs and outputs during the forward pass
        and recomputes all intermediate states during the backward pass.
    """
    # Fall back to the module object itself (not ``module.__call__``) so
    # that a ``__call__`` monkey-patched later on is still the one invoked.
    fn = module if fn is None else fn

    def inner_fn(params, *args, **kwargs):
        # Load the given parameters into the module before running ``fn``.
        module.update(params)
        return fn(*args, **kwargs)

    checkpointed_fn = mx.checkpoint(inner_fn)

    @wraps(fn)
    def wrapped_checkpointed_fn(*args, **kwargs):
        # Fetch the trainable parameters at call time and pass them ahead
        # of the caller's own arguments.
        trainable = module.trainable_parameters()
        return checkpointed_fn(trainable, *args, **kwargs)

    return wrapped_checkpointed_fn
|