mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add a compile option to nn.value_and_grad
This commit is contained in:
		@@ -1,11 +1,23 @@
 | 
				
			|||||||
# Copyright © 2023 Apple Inc.
 | 
					# Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Callable
 | 
					from contextlib import contextmanager
 | 
				
			||||||
 | 
					from typing import Any, Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import mlx.core as mx
 | 
					import mlx.core as mx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
 | 
					@contextmanager
					def updated_model(model: "mlx.nn.Module", *parameters: Any):
					    """Temporarily apply one or more parameter sets to ``model``.

					    The result of ``model.parameters()`` is captured on entry. Each item
					    in ``parameters`` is then passed to ``model.update`` in the order
					    given, and ``model`` itself is yielded to the ``with`` body. On exit,
					    including when the body raises, the captured parameters are passed
					    back to ``model.update``.

					    Args:
					        model (mlx.nn.Module): The model to update for the duration of
					                               the ``with`` block
					        parameters (Any): Parameter sets applied via ``model.update``,
					                          in order

					    Yields:
					        The same ``model`` object, with the given parameters applied
					    """
					    # Capture the current parameters before anything is overwritten.
					    saved_parameters = model.parameters()
					    try:
					        for overrides in parameters:
					            model.update(overrides)
					        yield model
					    finally:
					        # Always hand the captured parameters back to the model, even
					        # if an update or the ``with`` body raised.
					        model.update(saved_parameters)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def value_and_grad(model: "mlx.nn.Module", fn: Callable, compile: bool = False):
 | 
				
			||||||
    """Transform the passed function ``fn`` to a function that computes the
 | 
					    """Transform the passed function ``fn`` to a function that computes the
 | 
				
			||||||
    gradients of ``fn`` wrt the model's trainable parameters and also its
 | 
					    gradients of ``fn`` wrt the model's trainable parameters and also its
 | 
				
			||||||
    value.
 | 
					    value.
 | 
				
			||||||
@@ -14,20 +26,29 @@ def value_and_grad(model: "mlx.nn.Module", fn: Callable):
 | 
				
			|||||||
        model (mlx.nn.Module): The model whose trainable parameters to compute
 | 
					        model (mlx.nn.Module): The model whose trainable parameters to compute
 | 
				
			||||||
                               gradients for
 | 
					                               gradients for
 | 
				
			||||||
        fn (Callable): The scalar function to compute gradients for
 | 
					        fn (Callable): The scalar function to compute gradients for
 | 
				
			||||||
 | 
					        compile (bool): Whether to "compile" the function before returning it.
 | 
				
			||||||
 | 
					                        Default: ``False``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Returns:
 | 
					    Returns:
 | 
				
			||||||
        A callable that returns the value of ``fn`` and the gradients wrt the
 | 
					        A callable that returns the value of ``fn`` and the gradients wrt the
 | 
				
			||||||
        trainable parameters of ``model``
 | 
					        trainable parameters of ``model``
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def inner_fn(params, *args, **kwargs):
 | 
					    def inner_fn(trainable_parameters, other_parameters, *args, **kwargs):
 | 
				
			||||||
        model.update(params)
 | 
					        with updated_model(model, other_parameters, trainable_parameters):
 | 
				
			||||||
            return fn(*args, **kwargs)
 | 
					            return fn(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    value_grad_fn = mx.value_and_grad(inner_fn)
 | 
					    value_grad_fn = mx.value_and_grad(inner_fn)
 | 
				
			||||||
 | 
					    if compile:
 | 
				
			||||||
 | 
					        value_grad_fn = mx.compile(value_grad_fn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def wrapped_value_grad_fn(*args, **kwargs):
 | 
					    def wrapped_value_grad_fn(*args, **kwargs):
 | 
				
			||||||
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
 | 
					        value, grad = value_grad_fn(
 | 
				
			||||||
 | 
					            model.trainable_parameters(),
 | 
				
			||||||
 | 
					            model.non_trainable_parameters(),
 | 
				
			||||||
 | 
					            *args,
 | 
				
			||||||
 | 
					            **kwargs,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        return value, grad
 | 
					        return value, grad
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return wrapped_value_grad_fn
 | 
					    return wrapped_value_grad_fn
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user