mlx/mlx/transforms.h
Awni Hannun d75ae52ecd
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00

198 lines
6.8 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <utility>

#include "mlx/array.h"
namespace mlx::core {
/**
 * Evaluates the arrays in `outputs`.
 *
 * NOTE(review): only the declaration is visible in this header; the exact
 * evaluation semantics (e.g. whether the call blocks until the results are
 * ready) should be confirmed against the definition.
 */
void eval(const std::vector<array>& outputs);
/**
 * Variadic convenience overload of `eval`.
 *
 * Gathers the arguments into a `std::vector<array>` and hands that vector to
 * the overload taking a vector, so callers can write `eval(a, b, c)`.
 */
template <typename... Arrays>
void eval(Arrays... outputs) {
  const std::vector<array> gathered{std::forward<Arrays>(outputs)...};
  eval(gathered);
}
/**
 * Computes the output and vector-Jacobian product (VJP) of a function.
 *
 * Computes the vector-Jacobian product of the vector of cotangents with the
 * Jacobian of the function evaluated at the primals.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param primals Inputs at which `fun` (and its Jacobian) is evaluated.
 * @param cotangents Cotangents multiplied with the Jacobian.
 *        NOTE(review): presumably one cotangent per output of `fun`;
 *        confirm against the definition.
 * @return A pair whose first element is the vector of output arrays and
 *         whose second element is the vector of VJP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents);
/**
 * Computes the output and vector-Jacobian product (VJP) of a unary function.
 *
 * Single-array counterpart of the vector overload of `vjp`.
 *
 * @param fun Function mapping one array to one array.
 * @param primal Input at which `fun` is evaluated.
 * @param cotangent Cotangent multiplied with the Jacobian.
 * @return A pair holding the output array first and the VJP array second.
 */
std::pair<array, array> vjp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& cotangent);
/**
 * Computes the output and Jacobian-vector product (JVP) of a function.
 *
 * Computes the Jacobian-vector product of the Jacobian of the function
 * evaluated at the primals with the vector of tangents.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param primals Inputs at which `fun` (and its Jacobian) is evaluated.
 * @param tangents Tangents multiplied by the Jacobian.
 *        NOTE(review): presumably one tangent per primal; confirm against
 *        the definition.
 * @return A pair whose first element is the vector of output arrays and
 *         whose second element is the vector of JVP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> jvp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents);
/**
 * Computes the output and Jacobian-vector product (JVP) of a unary function.
 *
 * Single-array counterpart of the vector overload of `jvp`.
 *
 * @param fun Function mapping one array to one array.
 * @param primal Input at which `fun` is evaluated.
 * @param tangent Tangent multiplied by the Jacobian.
 * @return A pair holding the output array first and the JVP array second.
 */
std::pair<array, array> jvp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& tangent);
// Return type of the general value_and_grad: a function which takes an input
// vector of arrays and returns a pair of vectors of arrays, the first holding
// the values and the second holding the gradients wrt the first value.
using ValueAndGradFn =
std::function<std::pair<std::vector<array>, std::vector<array>>(
const std::vector<array>&)>;
// Return type of the value_and_grad overloads whose input function returns a
// single array: a function which takes an input vector of arrays and returns
// a pair of the single value and the vector of gradients.
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
const std::vector<array>&)>;
/**
 * Returns a function which computes the value and gradient of the input
 * function with respect to a vector of input arrays.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param argnums Indices of the inputs of `fun` to compute the gradient with
 *        respect to.
 * @return A `ValueAndGradFn`, i.e. a callable taking the input arrays and
 *         returning a pair of (values, gradients).
 **/
ValueAndGradFn value_and_grad(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& argnums);
/**
 * Returns a function which computes the value and gradient of the input
 * function with respect to a single input array.
 *
 * `argnum` is the index of that input and defaults to 0. The work is
 * delegated to the overload taking a vector of argument indices.
 **/
inline ValueAndGradFn value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    int argnum = 0) {
  const std::vector<int> single_argnum{argnum};
  return value_and_grad(fun, single_argnum);
}
/**
 * Returns a function which computes the value and gradient of the unary
 * input function.
 *
 * The returned callable passes its argument to the unary `vjp` overload
 * together with a cotangent of `array(1.0f)` and returns the resulting pair
 * unchanged (output first, gradient second).
 **/
inline std::function<std::pair<array, array>(const array&)> value_and_grad(
    const std::function<array(const array&)>& fun) {
  // Bind the input by const reference: the wrapping std::function already
  // passes a `const array&`, so a by-value parameter would only add a copy
  // on every call.
  return [fun](const array& primal) -> std::pair<array, array> {
    return vjp(fun, primal, array(1.0f));
  };
}
/**
 * Returns a function which computes the value and gradient of a function
 * that maps a vector of arrays to a single array.
 *
 * `fun` is wrapped so that it returns a one-element vector, differentiated
 * with the general (vector-returning) overload of `value_and_grad`, and the
 * single value is unpacked again. `argnums` holds the indices of the inputs
 * to compute the gradient with respect to.
 **/
inline SimpleValueAndGradFn value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  return [fun, argnums](const std::vector<array>& inputs)
             -> std::pair<array, std::vector<array>> {
    // Adapt `fun` to the vector-in/vector-out signature expected by the
    // general overload. Arguments are taken by const reference so the input
    // vector is not copied on the way through.
    auto vector_fun = [fun](const std::vector<array>& args) {
      return std::vector<array>{fun(args)};
    };
    auto result = value_and_grad(vector_fun, argnums)(inputs);
    // `result` is a local about to go out of scope: move its pieces out
    // rather than copying the gradient vector.
    return std::make_pair(
        std::move(result.first[0]), std::move(result.second));
  };
}
/**
 * Returns a function which computes the value and gradient of a function
 * that maps a vector of arrays to a single array, with respect to a single
 * input array.
 *
 * `argnum` is the index of that input and defaults to 0. The work is
 * delegated to the overload taking a vector of argument indices.
 **/
inline SimpleValueAndGradFn value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  const std::vector<int> single_argnum{argnum};
  return value_and_grad(fun, single_argnum);
}
/**
 * Returns a function which computes the gradient of the input function with
 * respect to a vector of input arrays.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The vector of `argnums` specifies which arguments to compute the
 * gradient with respect to. At least one argument must be specified.
 *
 * The result is built on top of `value_and_grad`: the returned callable
 * evaluates it and keeps only the gradient half of the pair.
 **/
inline std::function<std::vector<array>(const std::vector<array>&)> grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  return [value_and_grad_fn = value_and_grad(fun, argnums)](
             const std::vector<array>& inputs) -> std::vector<array> {
    return value_and_grad_fn(inputs).second;
  };
}
/**
 * Returns a function which computes the gradient of the input function with
 * respect to a single input array.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The optional `argnum` index specifies which argument to compute the
 * gradient with respect to and defaults to 0. The work is delegated to the
 * overload taking a vector of argument indices.
 **/
inline std::function<std::vector<array>(const std::vector<array>&)> grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  const std::vector<int> single_argnum{argnum};
  return grad(fun, single_argnum);
}
/**
 * Returns a function which computes the gradient of the unary input function.
 *
 * Built on top of the unary `value_and_grad`: the returned callable
 * evaluates it and keeps only the gradient half of the pair.
 **/
inline std::function<array(const array&)> grad(
    const std::function<array(const array&)>& fun) {
  return [value_and_grad_fn = value_and_grad(fun)](
             const array& input) -> array {
    return value_and_grad_fn(input).second;
  };
}
/**
 * Automatically vectorize a unary function over the requested axes.
 *
 * @param fun Function mapping one array to one array.
 * @param in_axis Axis of the input to vectorize over (defaults to 0).
 * @param out_axis Axis of the output to vectorize over (defaults to 0).
 * @return A function with the same signature as `fun`.
 */
std::function<array(const array&)> vmap(
const std::function<array(const array&)>& fun,
int in_axis = 0,
int out_axis = 0);
/**
 * Automatically vectorize a binary function over the requested axes.
 *
 * @param fun Function mapping two arrays to one array.
 * @param in_axis_a Axis of the first input to vectorize over (defaults to 0).
 * @param in_axis_b Axis of the second input to vectorize over (defaults
 *        to 0).
 * @param out_axis Axis of the output to vectorize over (defaults to 0).
 * @return A function with the same signature as `fun`.
 */
std::function<array(const array&, const array&)> vmap(
const std::function<array(const array&, const array&)>& fun,
int in_axis_a = 0,
int in_axis_b = 0,
int out_axis = 0);
/**
 * Automatically vectorize a function over the requested axes.
 *
 * The input function to `vmap` takes as an argument a vector of arrays and
 * returns a vector of arrays. Optionally specify the axes to vectorize over
 * with `in_axes` and `out_axes`, otherwise a default of 0 is used.
 * Returns a vectorized function with the same signature as the input
 * function.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param in_axes Axes of the inputs to vectorize over; an empty vector (the
 *        default) selects the default of 0.
 * @param out_axes Axes of the outputs to vectorize over; an empty vector
 *        (the default) selects the default of 0.
 */
std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& in_axes = {},
const std::vector<int>& out_axes = {});
/**
 * Return the results of calling fun with args but if their vjp is computed it
 * will be computed by fun_vjp.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param fun_vjp Replacement VJP rule; takes three vectors of arrays and
 *        returns a vector of arrays.
 *        NOTE(review): the meaning and order of the three arguments are not
 *        visible in this header (presumably primals, cotangents and
 *        outputs); confirm against the definition.
 * @return A function with the same signature as `fun`.
 */
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp);
/**
 * Checkpoint the gradient of a function. Namely, discard all intermediate
 * state and recalculate it when we need to compute the gradient.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @return A function with the same signature as `fun`.
 */
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun);
} // namespace mlx::core