MLX
Loading...
Searching...
No Matches
transforms.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include "mlx/array.h"
6
7namespace mlx::core {
8
9void async_eval(std::vector<array> outputs);
10
11void eval(std::vector<array> outputs);
12
13template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
14void eval(Arrays&&... outputs) {
15 eval(std::vector<array>{std::forward<Arrays>(outputs)...});
16}
17
25std::pair<std::vector<array>, std::vector<array>> vjp(
26 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
27 const std::vector<array>& primals,
28 const std::vector<array>& cotangents);
29
33std::pair<array, array> vjp(
34 const std::function<array(const array&)>& fun,
35 const array& primal,
36 const array& cotangent);
37
45std::pair<std::vector<array>, std::vector<array>> jvp(
46 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
47 const std::vector<array>& primals,
48 const std::vector<array>& tangents);
49
53std::pair<array, array> jvp(
54 const std::function<array(const array&)>& fun,
55 const array& primal,
56 const array& tangent);
57
58// Return type of general value_and_grad: a function which takes an input
59// vector of arrays and returns a pair of vectors of arrays one for the
60// values and one for the gradients wrt the first value.
62 std::function<std::pair<std::vector<array>, std::vector<array>>(
63 const std::vector<array>&)>;
64using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
65 const std::vector<array>&)>;
66
72 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
73 const std::vector<int>& argnums);
74
80 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
81 int argnum = 0) {
82 return value_and_grad(fun, std::vector<int>{argnum});
83}
84
89std::function<std::pair<array, array>(const array&)> inline value_and_grad(
90 const std::function<array(const array&)>& fun) {
91 return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
92}
93
95 const std::function<array(const std::vector<array>&)>& fun,
96 const std::vector<int>& argnums) {
97 return [fun, argnums](auto inputs) {
98 auto result = value_and_grad(
99 [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
100 argnums)(inputs);
101
102 return std::make_pair(result.first[0], result.second);
103 };
104}
105
107 const std::function<array(const std::vector<array>&)>& fun,
108 int argnum = 0) {
109 return value_and_grad(fun, std::vector<int>{argnum});
110}
111
120std::function<std::vector<array>(const std::vector<array>&)> inline grad(
121 const std::function<array(const std::vector<array>&)>& fun,
122 const std::vector<int>& argnums) {
123 auto fn = value_and_grad(fun, argnums);
124 return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
125}
126
135std::function<std::vector<array>(const std::vector<array>&)> inline grad(
136 const std::function<array(const std::vector<array>&)>& fun,
137 int argnum = 0) {
138 return grad(fun, std::vector<int>{argnum});
139}
140
144std::function<array(const array&)> inline grad(
145 const std::function<array(const array&)>& fun) {
146 auto fn = value_and_grad(fun);
147 return [fn](const array& input) { return fn(input).second; };
148}
149
153std::function<array(const array&)> vmap(
154 const std::function<array(const array&)>& fun,
155 int in_axis = 0,
156 int out_axis = 0);
157
161std::function<array(const array&, const array&)> vmap(
162 const std::function<array(const array&, const array&)>& fun,
163 int in_axis_a = 0,
164 int in_axis_b = 0,
165 int out_axis = 0);
166
176std::function<std::vector<array>(const std::vector<array>&)> vmap(
177 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
178 const std::vector<int>& in_axes = {},
179 const std::vector<int>& out_axes = {});
180
185std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
186 std::function<std::vector<array>(const std::vector<array>&)> fun,
187 std::function<std::vector<array>(
188 const std::vector<array>&,
189 const std::vector<array>&,
190 const std::vector<array>&)> fun_vjp);
191
196std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
197 std::function<std::vector<array>(const std::vector<array>&)> fun);
198
199} // namespace mlx::core
Definition array.h:20
Definition allocator.h:7
void async_eval(std::vector< array > outputs)
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
std::function< std::vector< array >(const std::vector< array > &)> checkpoint(std::function< std::vector< array >(const std::vector< array > &)> fun)
Checkpoint the gradient of a function.
std::function< std::pair< array, std::vector< array > >( const std::vector< array > &)> SimpleValueAndGradFn
Definition transforms.h:64
std::function< std::pair< array, array >(const array &)> value_and_grad(const std::function< array(const array &)> &fun)
Returns a function which computes the value and gradient of the unary input function.
Definition transforms.h:89
std::function< std::vector< array >(const std::vector< array > &)> custom_vjp(std::function< std::vector< array >(const std::vector< array > &)> fun, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp)
Return the results of calling fun with args but if their vjp is computed it will be computed by fun_v...
void eval(std::vector< array > outputs)
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
std::function< std::vector< array >(const std::vector< array > &)> grad(const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
Returns a function which computes the gradient of the input function with respect to a vector of inpu...
Definition transforms.h:120
std::function< std::pair< std::vector< array >, std::vector< array > >( const std::vector< array > &)> ValueAndGradFn
Definition transforms.h:61
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:565