|
void | mlx::core::async_eval (std::vector< array > outputs) |
|
void | mlx::core::eval (std::vector< array > outputs) |
|
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>> |
void | mlx::core::eval (Arrays &&... outputs) |
|
std::pair< std::vector< array >, std::vector< array > > | mlx::core::vjp (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents) |
| Computes the output and vector-Jacobian product (VJP) of a function.
|
|
std::pair< array, array > | mlx::core::vjp (const std::function< array(const array &)> &fun, const array &primal, const array &cotangent) |
| Computes the output and vector-Jacobian product (VJP) of a unary function.
|
|
std::pair< std::vector< array >, std::vector< array > > | mlx::core::jvp (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents) |
| Computes the output and Jacobian-vector product (JVP) of a function.
|
|
std::pair< array, array > | mlx::core::jvp (const std::function< array(const array &)> &fun, const array &primal, const array &tangent) |
| Computes the output and Jacobian-vector product (JVP) of a unary function.
|
|
ValueAndGradFn | mlx::core::value_and_grad (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< int > &argnums) |
| Returns a function which computes the value and gradient of the input function with respect to a vector of input arrays.
|
|
ValueAndGradFn | mlx::core::value_and_grad (const std::function< std::vector< array >(const std::vector< array > &)> &fun, int argnum=0) |
| Returns a function which computes the value and gradient of the input function with respect to a single input array.
|
|
SimpleValueAndGradFn | mlx::core::value_and_grad (const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums) |
|
SimpleValueAndGradFn | mlx::core::value_and_grad (const std::function< array(const std::vector< array > &)> &fun, int argnum=0) |
|
std::function< std::vector< array >(const std::vector< array > &)> | mlx::core::grad (const std::function< array(const std::vector< array > &)> &fun, int argnum=0) |
| Returns a function which computes the gradient of the input function with respect to a single input array.
|
|
std::function< array(const array &)> | mlx::core::grad (const std::function< array(const array &)> &fun) |
| Returns a function which computes the gradient of the unary input function.
|
|
std::function< array(const array &, const array &)> | mlx::core::vmap (const std::function< array(const array &, const array &)> &fun, int in_axis_a=0, int in_axis_b=0, int out_axis=0) |
| Automatically vectorize a binary function over the requested axes.
|
|
std::function< std::vector< array >(const std::vector< array > &)> | mlx::core::vmap (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< int > &in_axes={}, const std::vector< int > &out_axes={}) |
| Automatically vectorize a function over the requested axes.
|
|
|
std::function< std::pair< array, array >(const array &)> | mlx::core::value_and_grad (const std::function< array(const array &)> &fun) |
| Returns a function which computes the value and gradient of the unary input function.
|
|
std::function< std::vector< array >(const std::vector< array > &)> | mlx::core::grad (const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums) |
| Returns a function which computes the gradient of the input function with respect to a vector of input arrays.
|
|
std::function< array(const array &)> | mlx::core::vmap (const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0) |
| Automatically vectorize a unary function over the requested axes.
|
|
std::function< std::vector< array >(const std::vector< array > &)> | mlx::core::custom_vjp (std::function< std::vector< array >(const std::vector< array > &)> fun, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp) |
| Returns the results of calling fun with args; if their VJP is computed, it will be computed by fun_vjp instead.
|
|
std::function< std::vector< array >(const std::vector< array > &)> | mlx::core::checkpoint (std::function< std::vector< array >(const std::vector< array > &)> fun) |
| Checkpoint the gradient of a function.
|
|