// Copyright © 2023-2024 Apple Inc. #pragma once #include "mlx/array.h" namespace mlx::core { void async_eval(std::vector outputs); void eval(std::vector outputs); template > void eval(Arrays&&... outputs) { eval(std::vector{std::forward(outputs)...}); } /** * Computes the output and vector-Jacobian product (VJP) of a function. * * Computes the vector-Jacobian product of the vector of cotangents with the * Jacobian of the function evaluated at the primals. Returns a pair of * vectors of output arrays and VJP arrays. **/ std::pair, std::vector> vjp( const std::function(const std::vector&)>& fun, const std::vector& primals, const std::vector& cotangents); /** * Computes the output and vector-Jacobian product (VJP) of a unary function. */ std::pair vjp( const std::function& fun, const array& primal, const array& cotangent); /** * Computes the output and Jacobian-vector product (JVP) of a function. * * Computes the Jacobian-vector product of the Jacobian of the function * evaluated at the primals with the vector of tangents. Returns a pair of * vectors of output arrays and JVP arrays. **/ std::pair, std::vector> jvp( const std::function(const std::vector&)>& fun, const std::vector& primals, const std::vector& tangents); /** * Computes the output and Jacobian-vector product (JVP) of a unary function. */ std::pair jvp( const std::function& fun, const array& primal, const array& tangent); // Return type of general value_and_grad: a function which takes an input // vector of arrays and returns a pair of vectors of arrays one for the // values and one for the gradients wrt the first value. using ValueAndGradFn = std::function, std::vector>( const std::vector&)>; using SimpleValueAndGradFn = std::function>( const std::vector&)>; /** * Returns a function which computes the value and gradient of the input * function with respect to a vector of input arrays. **/ ValueAndGradFn value_and_grad( const std::function(const std::vector&)>& fun, const std::vector& argnums); /** * Returns a function which computes the value and gradient of the input * function with respect to a single input array. **/ ValueAndGradFn inline value_and_grad( const std::function(const std::vector&)>& fun, int argnum = 0) { return value_and_grad(fun, std::vector{argnum}); } /** * Returns a function which computes the value and gradient of the unary * input function. **/ std::function(const array&)> inline value_and_grad( const std::function& fun) { return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); }; } SimpleValueAndGradFn inline value_and_grad( const std::function&)>& fun, const std::vector& argnums) { return [fun, argnums](auto inputs) { auto result = value_and_grad( [fun](auto inputs) { return std::vector{fun(inputs)}; }, argnums)(inputs); return std::make_pair(result.first[0], result.second); }; } SimpleValueAndGradFn inline value_and_grad( const std::function&)>& fun, int argnum = 0) { return value_and_grad(fun, std::vector{argnum}); } /** * Returns a function which computes the gradient of the input function with * respect to a vector of input arrays. * * The function being differentiated takes a vector of arrays and returns an * array. The vector of `argnums` specifies which the arguments to compute * the gradient with respect to. At least one argument must be specified. **/ std::function(const std::vector&)> inline grad( const std::function&)>& fun, const std::vector& argnums) { auto fn = value_and_grad(fun, argnums); return [fn](const std::vector& inputs) { return fn(inputs).second; }; } /** * Returns a function which computes the gradient of the input function with * respect to a single input array. * * The function being differentiated takes a vector of arrays and returns an * array. The optional `argnum` index specifies which the argument to compute * the gradient with respect to and defaults to 0. **/ std::function(const std::vector&)> inline grad( const std::function&)>& fun, int argnum = 0) { return grad(fun, std::vector{argnum}); } /** * Returns a function which computes the gradient of the unary input function. **/ std::function inline grad( const std::function& fun) { auto fn = value_and_grad(fun); return [fn](const array& input) { return fn(input).second; }; } /** * Automatically vectorize a unary function over the requested axes. */ std::function vmap( const std::function& fun, int in_axis = 0, int out_axis = 0); /** * Automatically vectorize a binary function over the requested axes. */ std::function vmap( const std::function& fun, int in_axis_a = 0, int in_axis_b = 0, int out_axis = 0); /** * Automatically vectorize a function over the requested axes. * * The input function to `vmap` takes as an argument a vector of arrays and * returns a vector of arrays. Optionally specify the axes to vectorize over * with `in_axes` and `out_axes`, otherwise a default of 0 is used. * Returns a vectorized function with the same signature as the input * function. */ std::function(const std::vector&)> vmap( const std::function(const std::vector&)>& fun, const std::vector& in_axes = {}, const std::vector& out_axes = {}); /** * Return the results of calling fun with args but if their vjp is computed it * will be computed by fun_vjp. */ std::function(const std::vector&)> custom_vjp( std::function(const std::vector&)> fun, std::function( const std::vector&, const std::vector&, const std::vector&)> fun_vjp); /** * Checkpoint the gradient of a function. Namely, discard all intermediate * state and recalculate it when we need to compute the gradient. */ std::function(const std::vector&)> checkpoint( std::function(const std::vector&)> fun); } // namespace mlx::core