MLX
 
Loading...
Searching...
No Matches
transforms.h
Go to the documentation of this file.
// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
#include <functional>
#include <optional>
#include <utility>
#include <vector>

#include "mlx/array.h"
8
9namespace mlx::core {
10
11void async_eval(std::vector<array> outputs);
12
13template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
14void async_eval(Arrays&&... outputs) {
15 async_eval(std::vector<array>{std::forward<Arrays>(outputs)...});
16}
17
18void eval(std::vector<array> outputs);
19
20template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
21void eval(Arrays&&... outputs) {
22 eval(std::vector<array>{std::forward<Arrays>(outputs)...});
23}
24
32std::pair<std::vector<array>, std::vector<array>> vjp(
33 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
34 const std::vector<array>& primals,
35 const std::vector<array>& cotangents);
36
40std::pair<array, array> vjp(
41 const std::function<array(const array&)>& fun,
42 const array& primal,
43 const array& cotangent);
44
52std::pair<std::vector<array>, std::vector<array>> jvp(
53 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
54 const std::vector<array>& primals,
55 const std::vector<array>& tangents);
56
60std::pair<array, array> jvp(
61 const std::function<array(const array&)>& fun,
62 const array& primal,
63 const array& tangent);
64
65// Return type of general value_and_grad: a function which takes an input
66// vector of arrays and returns a pair of vectors of arrays one for the
67// values and one for the gradients wrt the first value.
69 std::function<std::pair<std::vector<array>, std::vector<array>>(
70 const std::vector<array>&)>;
71using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
72 const std::vector<array>&)>;
73
79 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
80 const std::vector<int>& argnums);
81
87 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
88 int argnum = 0) {
89 return value_and_grad(fun, std::vector<int>{argnum});
90}
91
96std::function<std::pair<array, array>(const array&)> inline value_and_grad(
97 const std::function<array(const array&)>& fun) {
98 return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
99}
100
102 const std::function<array(const std::vector<array>&)>& fun,
103 const std::vector<int>& argnums) {
104 return [fun, argnums](auto inputs) {
105 auto result = value_and_grad(
106 [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
107 argnums)(inputs);
108
109 return std::make_pair(result.first[0], result.second);
110 };
111}
112
114 const std::function<array(const std::vector<array>&)>& fun,
115 int argnum = 0) {
116 return value_and_grad(fun, std::vector<int>{argnum});
117}
118
127std::function<std::vector<array>(const std::vector<array>&)> inline grad(
128 const std::function<array(const std::vector<array>&)>& fun,
129 const std::vector<int>& argnums) {
130 auto fn = value_and_grad(fun, argnums);
131 return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
132}
133
142std::function<std::vector<array>(const std::vector<array>&)> inline grad(
143 const std::function<array(const std::vector<array>&)>& fun,
144 int argnum = 0) {
145 return grad(fun, std::vector<int>{argnum});
146}
147
151std::function<array(const array&)> inline grad(
152 const std::function<array(const array&)>& fun) {
153 auto fn = value_and_grad(fun);
154 return [fn](const array& input) { return fn(input).second; };
155}
156
160std::function<array(const array&)> vmap(
161 const std::function<array(const array&)>& fun,
162 int in_axis = 0,
163 int out_axis = 0);
164
168std::function<array(const array&, const array&)> vmap(
169 const std::function<array(const array&, const array&)>& fun,
170 int in_axis_a = 0,
171 int in_axis_b = 0,
172 int out_axis = 0);
173
183std::function<std::vector<array>(const std::vector<array>&)> vmap(
184 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
185 const std::vector<int>& in_axes = {},
186 const std::vector<int>& out_axes = {});
187
197std::function<std::vector<array>(const std::vector<array>&)> custom_function(
198 std::function<std::vector<array>(const std::vector<array>&)> fun,
199 std::optional<std::function<std::vector<array>(
200 const std::vector<array>&,
201 const std::vector<array>&,
202 const std::vector<array>&)>> fun_vjp = std::nullopt,
203 std::optional<std::function<std::vector<array>(
204 const std::vector<array>&,
205 const std::vector<array>&,
206 const std::vector<int>&)>> fun_jvp = std::nullopt,
207 std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
208 const std::vector<array>&,
209 const std::vector<int>&)>> fun_vmap = std::nullopt);
210
215std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
216 std::function<std::vector<array>(const std::vector<array>&)> fun,
217 std::function<std::vector<array>(
218 const std::vector<array>&,
219 const std::vector<array>&,
220 const std::vector<array>&)> fun_vjp);
221
226std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
227 std::function<std::vector<array>(const std::vector<array>&)> fun);
228
229} // namespace mlx::core
Definition array.h:24
Definition allocator.h:7
void async_eval(std::vector< array > outputs)
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
std::function< std::pair< array, std::vector< array > >( const std::vector< array > &)> SimpleValueAndGradFn
Definition transforms.h:71
std::function< std::vector< array >(const std::vector< array > &)> grad(const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
Returns a function which computes the gradient of the input function with respect to a vector of input arguments.
Definition transforms.h:127
std::function< std::vector< array >(const std::vector< array > &)> checkpoint(std::function< std::vector< array >(const std::vector< array > &)> fun)
Checkpoint the gradient of a function.
void eval(std::vector< array > outputs)
std::function< std::vector< array >(const std::vector< array > &)> custom_function(std::function< std::vector< array >(const std::vector< array > &)> fun, std::optional< std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> > fun_vjp=std::nullopt, std::optional< std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> > fun_jvp=std::nullopt, std::optional< std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> > fun_vmap=std::nullopt)
Redefine the transformations of fun according to the provided functions.
std::function< std::vector< array >(const std::vector< array > &)> custom_vjp(std::function< std::vector< array >(const std::vector< array > &)> fun, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp)
Return a function that behaves exactly like fun but if the vjp of the results is computed fun_vjp will be used instead of the vjp of fun.
std::function< std::pair< std::vector< array >, std::vector< array > >( const std::vector< array > &)> ValueAndGradFn
Definition transforms.h:68
ValueAndGradFn value_and_grad(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
Returns a function which computes the value and gradient of the input function with respect to a vector of input arguments.
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:630