mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Add a compile option to nn.value_and_grad
This commit is contained in:
parent
db40990d33
commit
1e4d3c7fb2
@ -1,11 +1,23 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from typing import Callable
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
|
@contextmanager
def updated_model(model: "mlx.nn.Module", *parameters: Any):
    """Context manager that runs its body with ``parameters`` applied to
    ``model`` and re-applies the previous parameters on exit.

    Args:
        model (mlx.nn.Module): The model to update for the duration of the
            ``with`` block.
        *parameters (Any): Zero or more values, each passed to
            ``model.update`` in the order given.

    Yields:
        The same ``model`` object, after all updates have been applied.
    """
    # Take the snapshot before touching the model so that the ``finally``
    # clause always has the pre-update state to hand back.
    saved_parameters = model.parameters()
    try:
        for tree in parameters:
            model.update(tree)
        yield model
    finally:
        # Runs on normal exit and when the body (or an update) raises.
        model.update(saved_parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def value_and_grad(model: "mlx.nn.Module", fn: Callable, compile: bool = False):
|
||||||
"""Transform the passed function ``fn`` to a function that computes the
|
"""Transform the passed function ``fn`` to a function that computes the
|
||||||
gradients of ``fn`` wrt the model's trainable parameters and also its
|
gradients of ``fn`` wrt the model's trainable parameters and also its
|
||||||
value.
|
value.
|
||||||
@ -14,20 +26,29 @@ def value_and_grad(model: "mlx.nn.Module", fn: Callable):
|
|||||||
model (mlx.nn.Module): The model whose trainable parameters to compute
|
model (mlx.nn.Module): The model whose trainable parameters to compute
|
||||||
gradients for
|
gradients for
|
||||||
fn (Callable): The scalar function to compute gradients for
|
fn (Callable): The scalar function to compute gradients for
|
||||||
|
compile (bool): Whether to "compile" the function before returning it.
|
||||||
|
Default: ``False``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A callable that returns the value of ``fn`` and the gradients wrt the
|
A callable that returns the value of ``fn`` and the gradients wrt the
|
||||||
trainable parameters of ``model``
|
trainable parameters of ``model``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def inner_fn(params, *args, **kwargs):
|
def inner_fn(trainable_parameters, other_parameters, *args, **kwargs):
|
||||||
model.update(params)
|
with updated_model(model, other_parameters, trainable_parameters):
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
value_grad_fn = mx.value_and_grad(inner_fn)
|
value_grad_fn = mx.value_and_grad(inner_fn)
|
||||||
|
if compile:
|
||||||
|
value_grad_fn = mx.compile(value_grad_fn)
|
||||||
|
|
||||||
def wrapped_value_grad_fn(*args, **kwargs):
|
def wrapped_value_grad_fn(*args, **kwargs):
|
||||||
value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
|
value, grad = value_grad_fn(
|
||||||
|
model.trainable_parameters(),
|
||||||
|
model.non_trainable_parameters(),
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
return value, grad
|
return value, grad
|
||||||
|
|
||||||
return wrapped_value_grad_fn
|
return wrapped_value_grad_fn
|
||||||
|
Loading…
Reference in New Issue
Block a user