mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Add a compile option to nn.value_and_grad
This commit is contained in:
parent
db40990d33
commit
1e4d3c7fb2
@ -1,11 +1,23 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
from typing import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
|
||||
@contextmanager
def updated_model(model: "mlx.nn.Module", *parameters: Any):
    """Temporarily apply ``parameters`` to ``model``.

    The model's current parameters are captured first. Each entry of
    ``parameters`` is then passed to ``model.update`` in the order given, the
    model is yielded, and on exit the captured parameters are passed back to
    ``model.update``.

    Args:
        model (mlx.nn.Module): The model to update in place.
        *parameters (Any): Parameter collections handed to ``model.update``,
            one call per entry, left to right.

    Yields:
        The same ``model`` object, with ``parameters`` applied.
    """
    # Snapshot before touching the model so there is something to restore.
    old_state = model.parameters()
    try:
        for p in parameters:
            model.update(p)
        yield model
    finally:
        # Runs on normal exit and when the ``with`` body (or one of the
        # updates above) raises, so the caller's model is put back either way.
        model.update(old_state)
|
||||
|
||||
|
||||
def value_and_grad(model: "mlx.nn.Module", fn: Callable, compile: bool = False):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
                               gradients for
        fn (Callable): The scalar function to compute gradients for
        compile (bool): Whether to "compile" the function before returning it.
                        Default: ``False``.

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """

    def inner_fn(trainable_parameters, other_parameters, *args, **kwargs):
        # Both parameter sets arrive as explicit arguments and are applied to
        # the model only for the duration of the call (non-trainable first,
        # then trainable); ``updated_model`` restores the previous parameters
        # afterwards.
        # NOTE(review): trainable parameters are the first argument,
        # presumably because ``mx.value_and_grad`` differentiates wrt its
        # first argument by default -- confirm against the MLX docs.
        with updated_model(model, other_parameters, trainable_parameters):
            return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)
    if compile:
        value_grad_fn = mx.compile(value_grad_fn)

    def wrapped_value_grad_fn(*args, **kwargs):
        # Read the model's parameters at call time so each invocation uses
        # whatever the model currently holds.
        value, grad = value_grad_fn(
            model.trainable_parameters(),
            model.non_trainable_parameters(),
            *args,
            **kwargs,
        )
        return value, grad

    return wrapped_value_grad_fn
|
||||
|
Loading…
Reference in New Issue
Block a user