MLX
transforms.h File Reference
#include "mlx/array.h"


Namespaces

namespace  mlx
 
namespace  mlx::core
 

Typedefs

using mlx::core::ValueAndGradFn
 
using mlx::core::SimpleValueAndGradFn
 

Functions

void mlx::core::async_eval (std::vector< array > outputs)
 
void mlx::core::eval (std::vector< array > outputs)
 
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void mlx::core::eval (Arrays &&... outputs)
 
std::pair< std::vector< array >, std::vector< array > > mlx::core::vjp (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
 Computes the output and vector-Jacobian product (VJP) of a function.
 
std::pair< array, array > mlx::core::vjp (const std::function< array(const array &)> &fun, const array &primal, const array &cotangent)
 Computes the output and vector-Jacobian product (VJP) of a unary function.
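For a primal x and cotangent c, the unary `vjp` overload returns the pair (f(x), c * f'(x)). The following stdlib-only C++ sketch gives a rough idea of how such a pair can be produced: it records a scalar computation on a tape and then sweeps the tape backwards. It is not MLX code. `Tape`, `Var`, and `vjp_scalar` are hypothetical names, and MLX differentiates graphs of `array` values rather than scalar tapes.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical scalar reverse-mode tape (illustration only, not MLX).
struct Tape {
  struct Node {
    double value;
    int parents[2];      // indices of up to two parent nodes, -1 if absent
    double partials[2];  // local partial derivative w.r.t. each parent
  };
  std::vector<Node> nodes;

  int push(double v, int p0 = -1, double d0 = 0.0, int p1 = -1, double d1 = 0.0) {
    nodes.push_back(Node{v, {p0, p1}, {d0, d1}});
    return static_cast<int>(nodes.size()) - 1;
  }
};

// A variable is just an index into the tape.
struct Var {
  Tape* tape;
  int id;
  double value() const { return tape->nodes[id].value; }
};

Var operator+(Var a, Var b) {
  return {a.tape, a.tape->push(a.value() + b.value(), a.id, 1.0, b.id, 1.0)};
}
Var operator*(Var a, Var b) {
  // Product rule: d(ab)/da = b, d(ab)/db = a.
  return {a.tape, a.tape->push(a.value() * b.value(), a.id, b.value(), b.id, a.value())};
}
Var sin(Var a) {
  return {a.tape, a.tape->push(std::sin(a.value()), a.id, std::cos(a.value()))};
}

// Returns (f(primal), cotangent * f'(primal)): one forward run records the
// tape, then the cotangent is propagated backwards through it.
std::pair<double, double> vjp_scalar(const std::function<Var(Var)>& f,
                                     double primal, double cotangent) {
  Tape tape;
  Var x{&tape, tape.push(primal)};
  Var y = f(x);
  std::vector<double> adj(tape.nodes.size(), 0.0);
  adj[y.id] = cotangent;
  // Nodes are pushed in topological order, so a single reverse sweep suffices.
  for (int i = y.id; i >= 0; --i) {
    const Tape::Node& n = tape.nodes[i];
    for (int k = 0; k < 2; ++k) {
      if (n.parents[k] >= 0) adj[n.parents[k]] += adj[i] * n.partials[k];
    }
  }
  return {y.value(), adj[x.id]};
}
```

For example, `vjp_scalar([](Var x) { return x + x * x; }, 3.0, 1.0)` yields the value 12 and the VJP 1 + 2 * 3 = 7; contributions from every use of `x` are accumulated in `adj`.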
 
std::pair< std::vector< array >, std::vector< array > > mlx::core::jvp (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
 Computes the output and Jacobian-vector product (JVP) of a function.
 
std::pair< array, array > mlx::core::jvp (const std::function< array(const array &)> &fun, const array &primal, const array &tangent)
 Computes the output and Jacobian-vector product (JVP) of a unary function.
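A JVP runs in the opposite direction from a VJP: the tangent travels forward alongside the primal, so the output and the JVP come out of one forward pass. A minimal sketch using dual numbers, assuming scalar inputs; `Dual` and `jvp_scalar` are hypothetical names and are not part of MLX.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <utility>

// Hypothetical dual number: `val` carries the primal, `tan` the tangent.
struct Dual {
  double val;
  double tan;
};

Dual operator+(Dual a, Dual b) { return {a.val + b.val, a.tan + b.tan}; }
Dual operator*(Dual a, Dual b) {
  // Product rule applied to the tangent part.
  return {a.val * b.val, a.val * b.tan + b.val * a.tan};
}
Dual sin(Dual a) { return {std::sin(a.val), std::cos(a.val) * a.tan}; }

// Returns (f(primal), f'(primal) * tangent) from a single forward pass.
std::pair<double, double> jvp_scalar(const std::function<Dual(Dual)>& f,
                                     double primal, double tangent) {
  Dual out = f(Dual{primal, tangent});
  return {out.val, out.tan};
}
```

For instance, `jvp_scalar([](Dual x) { return x * x; }, 3.0, 1.0)` returns the pair (9, 6).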
 
ValueAndGradFn mlx::core::value_and_grad (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
 Returns a function which computes the value and gradient of the input function with respect to a vector of input arrays.
 
ValueAndGradFn mlx::core::value_and_grad (const std::function< std::vector< array >(const std::vector< array > &)> &fun, int argnum=0)
 Returns a function which computes the value and gradient of the input function with respect to a single input array.
 
SimpleValueAndGradFn mlx::core::value_and_grad (const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
 
SimpleValueAndGradFn mlx::core::value_and_grad (const std::function< array(const std::vector< array > &)> &fun, int argnum=0)
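The `argnums` overloads differentiate only the selected inputs. The sketch below mimics that contract for scalar inputs with a hypothetical `value_and_grad_sketch` helper. It assumes the gradients are returned in `argnums` order. For brevity it swaps in forward-mode dual numbers, one pass per selected argument. This is a choice made for the sketch and does not describe how MLX computes gradients.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical forward-mode dual number (illustration only, not MLX).
struct Dual {
  double val;
  double tan;
};
Dual operator+(Dual a, Dual b) { return {a.val + b.val, a.tan + b.tan}; }
Dual operator*(Dual a, Dual b) {
  return {a.val * b.val, a.val * b.tan + b.val * a.tan};
}

using ScalarFn = std::function<Dual(const std::vector<Dual>&)>;
// Loosely mirrors SimpleValueAndGradFn: inputs -> (value, gradients).
using ValueAndGradSketch = std::function<std::pair<double, std::vector<double>>(
    const std::vector<double>&)>;

ValueAndGradSketch value_and_grad_sketch(ScalarFn fun, std::vector<int> argnums) {
  return [fun, argnums](const std::vector<double>& inputs) {
    // A pass with all tangents set to zero yields just the value.
    std::vector<Dual> primals;
    for (double x : inputs) primals.push_back({x, 0.0});
    double value = fun(primals).val;

    // One pass per requested argument, seeding a unit tangent on it.
    std::vector<double> grads;
    for (int argnum : argnums) {
      assert(argnum >= 0 && argnum < static_cast<int>(inputs.size()));
      std::vector<Dual> seeded = primals;
      seeded[argnum].tan = 1.0;
      grads.push_back(fun(seeded).tan);
    }
    return std::make_pair(value, grads);
  };
}
```

With f(x, y, z) = x * y + z, inputs (2, 3, 5), and `argnums = {0, 2}`, the returned function gives the value 11 and the gradients (3, 1).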
 
std::function< std::vector< array >(const std::vector< array > &)> mlx::core::grad (const std::function< array(const std::vector< array > &)> &fun, int argnum=0)
 Returns a function which computes the gradient of the input function with respect to a single input array.
 
std::function< array(const array &)> mlx::core::grad (const std::function< array(const array &)> &fun)
 Returns a function which computes the gradient of the unary input function.
 
std::function< array(const array &, const array &)> mlx::core::vmap (const std::function< array(const array &, const array &)> &fun, int in_axis_a=0, int in_axis_b=0, int out_axis=0)
 Automatically vectorize a binary function over the requested axes.
 
std::function< std::vector< array >(const std::vector< array > &)> mlx::core::vmap (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< int > &in_axes={}, const std::vector< int > &out_axes={})
 Automatically vectorize a function over the requested axes.
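Conceptually, `vmap` applies the function to each slice along the input axes and stacks the results along the output axis. The following sketch spells out those reference semantics for the binary overload. It is restricted to a hypothetical 2-D `Matrix` type and 1-D slices. MLX transforms the computation itself rather than looping over slices, so `Matrix`, `slice`, and `vmap_sketch` are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical 2-D row-major matrix (illustration only, not mlx::core::array).
struct Matrix {
  int rows;
  int cols;
  Vec data;
  double at(int r, int c) const { return data[r * cols + c]; }
};

// Slice along `axis`: axis 0 yields row i, axis 1 yields column i.
Vec slice(const Matrix& m, int axis, int i) {
  Vec out;
  int n = (axis == 0) ? m.cols : m.rows;
  for (int j = 0; j < n; ++j) out.push_back(axis == 0 ? m.at(i, j) : m.at(j, i));
  return out;
}

// Reference semantics only: apply `fun` to matching slices of `a` and `b`
// along their mapped axes, then stack the 1-D results along `out_axis`.
std::function<Matrix(const Matrix&, const Matrix&)> vmap_sketch(
    std::function<Vec(const Vec&, const Vec&)> fun, int in_axis_a = 0,
    int in_axis_b = 0, int out_axis = 0) {
  return [=](const Matrix& a, const Matrix& b) {
    int n = (in_axis_a == 0) ? a.rows : a.cols;
    assert(n == ((in_axis_b == 0) ? b.rows : b.cols));  // mapped sizes must match
    std::vector<Vec> results;
    for (int i = 0; i < n; ++i)
      results.push_back(fun(slice(a, in_axis_a, i), slice(b, in_axis_b, i)));
    int m = results.empty() ? 0 : static_cast<int>(results[0].size());
    Matrix out{out_axis == 0 ? n : m, out_axis == 0 ? m : n,
               Vec(static_cast<std::size_t>(n) * m)};
    for (int i = 0; i < n; ++i)
      for (int j = 0; j < m; ++j) {
        // out_axis 0: result i becomes row i; out_axis 1: it becomes column i.
        if (out_axis == 0)
          out.data[i * m + j] = results[i][j];
        else
          out.data[j * n + i] = results[i][j];
      }
    return out;
  };
}
```

With `a` of shape 2x3 and `b` of shape 3x2, calling `vmap_sketch(add, 0, 1)` pairs each row of `a` with the matching column of `b`, where `add` is any elementwise function of two vectors.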
 

Variables

std::function< std::pair< array, array >(const array &)> mlx::core::value_and_grad (const std::function< array(const array &)> &fun)
 Returns a function which computes the value and gradient of the unary input function.
 
std::function< std::vector< array >(const std::vector< array > &)> mlx::core::grad (const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
 Returns a function which computes the gradient of the input function with respect to a vector of input arrays.
 
std::function< array(const array &)> mlx::core::vmap (const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
 Automatically vectorize a unary function over the requested axes.
 
std::function< std::vector< array >(const std::vector< array > &)> mlx::core::custom_vjp (std::function< std::vector< array >(const std::vector< array > &)> fun, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp)
 Returns a function that produces the same results as calling fun on its arguments, except that if their VJP is computed, it is computed by fun_vjp.
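The point of `custom_vjp` is that the forward computation stays the same while the backward rule is replaced. The stdlib-only sketch below shows that split with a hypothetical `Differentiable` pair and a hypothetical `custom_vjp_sketch` helper. Its default rule is a swapped-in finite-difference approximation that holds only for elementwise functions. Following the `fun_vjp` signature above, the backward rule receives three vectors. The sketch assumes they arrive in the order (primals, cotangents, outputs).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
using Fn = std::function<Vec(const Vec&)>;
// Same shape as fun_vjp above, with the assumed argument order
// (primals, cotangents, outputs) -> cotangents for the inputs.
using VjpFn = std::function<Vec(const Vec&, const Vec&, const Vec&)>;

// Hypothetical pairing of a forward function with its backward rule.
struct Differentiable {
  Fn fun;
  VjpFn vjp;
};

// Default rule for this sketch only: central finite differences, valid for
// elementwise functions (output i depends only on input i).
Differentiable make_differentiable(Fn fun) {
  VjpFn fd = [fun](const Vec& primals, const Vec& cotangents, const Vec&) {
    const double h = 1e-6;
    Vec grads(primals.size());
    for (std::size_t i = 0; i < primals.size(); ++i) {
      Vec lo = primals;
      Vec hi = primals;
      lo[i] -= h;
      hi[i] += h;
      grads[i] = cotangents[i] * (fun(hi)[i] - fun(lo)[i]) / (2 * h);
    }
    return grads;
  };
  return {fun, fd};
}

// The forward pass is unchanged; only the backward rule is replaced.
Differentiable custom_vjp_sketch(Fn fun, VjpFn fun_vjp) {
  return {fun, fun_vjp};
}
```

A typical use is rounding. Its derivative is zero almost everywhere, so a custom rule that passes the cotangent through unchanged (a straight-through estimator) keeps a gradient signal while the forward results remain rounded.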
 
std::function< std::vector< array >(const std::vector< array > &)> mlx::core::checkpoint (std::function< std::vector< array >(const std::vector< array > &)> fun)
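 Checkpoint the gradient of a function: intermediate values from the forward pass are discarded and recomputed when the gradient is needed, trading extra computation for lower memory use.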
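The trade-off can be made concrete with a toy scalar example for h(x) = g(g(x)). One variant keeps the intermediate u = g(x) for the backward pass. The checkpointed variant keeps only x and recomputes u when the chain rule needs it. Both return the same value and derivative, and the checkpointed one costs one extra call to `g`. The functions `g` and `dg` and the counter `g_calls` are hypothetical. MLX applies this idea to array graphs, not to hand-written scalar code.

```cpp
#include <cassert>
#include <cmath>
#include <utility>

// Hypothetical counter so the recomputation is observable.
int g_calls = 0;
double g(double x) {
  ++g_calls;
  return std::sin(x);
}
double dg(double x) { return std::cos(x); }

// Stores the intermediate u = g(x) from the forward pass.
std::pair<double, double> value_and_grad_stored(double x) {
  double u = g(x);            // kept alive for the backward pass
  double y = g(u);
  return {y, dg(u) * dg(x)};  // chain rule: h'(x) = g'(u) * g'(x)
}

// Checkpointed: only the input x survives the forward pass, and the
// intermediate is recomputed when the gradient is needed.
std::pair<double, double> value_and_grad_checkpointed(double x) {
  double y = g(g(x));  // intermediate discarded after the forward pass
  double u = g(x);     // recomputed for the backward pass
  return {y, dg(u) * dg(x)};
}
```

In a deep network the stored intermediates are large activations. Discarding them and recomputing during the backward pass is what reduces peak memory.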