MLX
Loading...
Searching...
No Matches
transforms.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <optional>
6
7#include "mlx/array.h"
8
9namespace mlx::core {
10
11void async_eval(std::vector<array> outputs);
12
13void eval(std::vector<array> outputs);
14
15template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
16void eval(Arrays&&... outputs) {
17 eval(std::vector<array>{std::forward<Arrays>(outputs)...});
18}
19
27std::pair<std::vector<array>, std::vector<array>> vjp(
28 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
29 const std::vector<array>& primals,
30 const std::vector<array>& cotangents);
31
35std::pair<array, array> vjp(
36 const std::function<array(const array&)>& fun,
37 const array& primal,
38 const array& cotangent);
39
47std::pair<std::vector<array>, std::vector<array>> jvp(
48 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
49 const std::vector<array>& primals,
50 const std::vector<array>& tangents);
51
55std::pair<array, array> jvp(
56 const std::function<array(const array&)>& fun,
57 const array& primal,
58 const array& tangent);
59
60// Return type of general value_and_grad: a function which takes an input
61// vector of arrays and returns a pair of vectors of arrays one for the
62// values and one for the gradients wrt the first value.
64 std::function<std::pair<std::vector<array>, std::vector<array>>(
65 const std::vector<array>&)>;
66using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
67 const std::vector<array>&)>;
68
// NOTE(review): source line 73 is missing from this listing. That line carries
// this declaration's return type and name. Presumably it reads
// `ValueAndGradFn value_and_grad(`: the general overload that the
// single-argnum and single-output overloads forward to with
// (fun, std::vector<int>). Confirm against transforms.h.
// `fun` maps a vector of input arrays to a vector of output arrays.
// `argnums` presumably lists the indices of the inputs to differentiate
// with respect to.
74 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
75 const std::vector<int>& argnums);
76
// NOTE(review): source line 81 is missing from this listing. That line holds
// this overload's return type and name, presumably
// `ValueAndGradFn inline value_and_grad(`. Confirm against transforms.h.
// Convenience overload for a single argument index. It wraps `argnum` (default
// 0, presumably the first input) in a one-element vector and forwards to the
// argnums overload.
82 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
83 int argnum = 0) {
84 return value_and_grad(fun, std::vector<int>{argnum});
85}
86
91std::function<std::pair<array, array>(const array&)> inline value_and_grad(
92 const std::function<array(const array&)>& fun) {
93 return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
94}
95
// NOTE(review): source line 96 is missing from this listing. That line holds
// this overload's return type and name. The lambda below returns
// std::make_pair(result.first[0], result.second), which is consistent with
// SimpleValueAndGradFn. Presumably the head is
// `inline SimpleValueAndGradFn value_and_grad(`; confirm against transforms.h.
// Single-output adaptor: lifts `fun` (vector of arrays -> one array) into a
// vector-valued function, runs the vector-valued value_and_grad on it, and
// unpacks the sole value.
97 const std::function<array(const std::vector<array>&)>& fun,
98 const std::vector<int>& argnums) {
99 return [fun, argnums](auto inputs) {
100 auto result = value_and_grad(
101 [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
102 argnums)(inputs);
103
// The wrapped function has exactly one output, so first[0] is the value.
104 return std::make_pair(result.first[0], result.second);
105 };
106}
107
// NOTE(review): source line 108 is missing from this listing. That line holds
// this overload's return type and name; presumably it returns
// SimpleValueAndGradFn, like the argnums overload it forwards to. Confirm
// against transforms.h.
// Convenience overload for a single-output function and a single argument
// index. It forwards to the argnums overload with {argnum}.
109 const std::function<array(const std::vector<array>&)>& fun,
110 int argnum = 0) {
111 return value_and_grad(fun, std::vector<int>{argnum});
112}
113
122std::function<std::vector<array>(const std::vector<array>&)> inline grad(
123 const std::function<array(const std::vector<array>&)>& fun,
124 const std::vector<int>& argnums) {
125 auto fn = value_and_grad(fun, argnums);
126 return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
127}
128
137std::function<std::vector<array>(const std::vector<array>&)> inline grad(
138 const std::function<array(const std::vector<array>&)>& fun,
139 int argnum = 0) {
140 return grad(fun, std::vector<int>{argnum});
141}
142
146std::function<array(const array&)> inline grad(
147 const std::function<array(const array&)>& fun) {
148 auto fn = value_and_grad(fun);
149 return [fn](const array& input) { return fn(input).second; };
150}
151
155std::function<array(const array&)> vmap(
156 const std::function<array(const array&)>& fun,
157 int in_axis = 0,
158 int out_axis = 0);
159
163std::function<array(const array&, const array&)> vmap(
164 const std::function<array(const array&, const array&)>& fun,
165 int in_axis_a = 0,
166 int in_axis_b = 0,
167 int out_axis = 0);
168
178std::function<std::vector<array>(const std::vector<array>&)> vmap(
179 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
180 const std::vector<int>& in_axes = {},
181 const std::vector<int>& out_axes = {});
182
192std::function<std::vector<array>(const std::vector<array>&)> custom_function(
193 std::function<std::vector<array>(const std::vector<array>&)> fun,
194 std::optional<std::function<std::vector<array>(
195 const std::vector<array>&,
196 const std::vector<array>&,
197 const std::vector<array>&)>> fun_vjp = std::nullopt,
198 std::optional<std::function<std::vector<array>(
199 const std::vector<array>&,
200 const std::vector<array>&,
201 const std::vector<int>&)>> fun_jvp = std::nullopt,
202 std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
203 const std::vector<array>&,
204 const std::vector<int>&)>> fun_vmap = std::nullopt);
205
210std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
211 std::function<std::vector<array>(const std::vector<array>&)> fun,
212 std::function<std::vector<array>(
213 const std::vector<array>&,
214 const std::vector<array>&,
215 const std::vector<array>&)> fun_vjp);
216
221std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
222 std::function<std::vector<array>(const std::vector<array>&)> fun);
223
224} // namespace mlx::core
Definition array.h:20
Definition allocator.h:7
void async_eval(std::vector< array > outputs)
std::pair< std::vector< array >, std::vector< array > > jvp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
Computes the output and Jacobian-vector product (JVP) of a function.
std::pair< std::vector< array >, std::vector< array > > vjp(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
Computes the output and vector-Jacobian product (VJP) of a function.
std::function< std::vector< array >(const std::vector< array > &)> checkpoint(std::function< std::vector< array >(const std::vector< array > &)> fun)
Checkpoint the gradient of a function.
std::function< std::pair< array, std::vector< array > >( const std::vector< array > &)> SimpleValueAndGradFn
Definition transforms.h:66
std::function< std::pair< array, array >(const array &)> value_and_grad(const std::function< array(const array &)> &fun)
Returns a function which computes the value and gradient of the unary input function.
Definition transforms.h:91
std::function< std::vector< array >(const std::vector< array > &)> custom_vjp(std::function< std::vector< array >(const std::vector< array > &)> fun, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp)
Return a function that behaves exactly like fun but if the vjp of the results is computed fun_vjp wil...
std::function< std::vector< array >(const std::vector< array > &)> custom_function(std::function< std::vector< array >(const std::vector< array > &)> fun, std::optional< std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> > fun_vjp=std::nullopt, std::optional< std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< int > &)> > fun_jvp=std::nullopt, std::optional< std::function< std::pair< std::vector< array >, std::vector< int > >(const std::vector< array > &, const std::vector< int > &)> > fun_vmap=std::nullopt)
Redefine the transformations of fun according to the provided functions.
void eval(std::vector< array > outputs)
std::function< array(const array &)> vmap(const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
Automatically vectorize a unary function over the requested axes.
std::function< std::vector< array >(const std::vector< array > &)> grad(const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
Returns a function which computes the gradient of the input function with respect to a vector of inpu...
Definition transforms.h:122
std::function< std::pair< std::vector< array >, std::vector< array > >( const std::vector< array > &)> ValueAndGradFn
Definition transforms.h:63
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:589