Custom transforms (#1246)
This commit is contained in:
parent a3c287354f · commit 5c1fa64fb0
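This commit replaces the `CustomVJP` primitive with a more general `CustomTransforms` primitive and exposes it through a new `custom_function` transform (C++ `mlx::core::custom_function` and the Python decorator `mx.custom_function`), which lets a function override its vjp, jvp and vmap rules. A minimal sketch of the Python API, adapted from the docstring added further down in this diff (the function names are illustrative only):

    import mlx.core as mx

    @mx.custom_function
    def scaled_sin(x, y):
        return mx.sin(x) * y

    @scaled_sin.vjp
    def scaled_sin_vjp(primals, cotangents, outputs):
        # primals mirrors the positional arguments, cotangents the outputs.
        x, y = primals
        return cotangents * mx.cos(x) * y, cotangents * mx.sin(x)

    # grad now routes through scaled_sin_vjp; jvp and vmap fall back to the
    # defaults derived from the wrapped function.
    dy = mx.grad(lambda x, y: scaled_sin(x, y).sum())(mx.ones(3), mx.ones(3))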
@@ -10,6 +10,7 @@ Transforms
 
    eval
    compile
+   custom_function
    disable_compile
    enable_compile
    grad
@@ -17,6 +17,10 @@ bool in_tracing() {
   return detail::InTracing::in_tracing();
 }
 
+bool retain_graph() {
+  return detail::RetainGraph::retain_graph();
+}
+
 } // namespace
 
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -102,7 +106,7 @@ void array::eval() {
 }
 
 bool array::is_tracer() const {
-  return array_desc_->is_tracer && in_tracing();
+  return array_desc_->is_tracer && in_tracing() || retain_graph();
 }
 
 void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -36,7 +36,7 @@ DEFAULT(Ceil)
 DEFAULT(Concatenate)
 DEFAULT(Conjugate)
 DEFAULT(Copy)
-DEFAULT_MULTI(CustomVJP)
+DEFAULT_MULTI(CustomTransforms)
 DEFAULT_MULTI(Depends)
 DEFAULT_MULTI(DivMod)
 DEFAULT(NumberOfElements)
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
   out.copy_shared_buffer(inputs[0]);
 }
 
-void CustomVJP::eval(
+void CustomTransforms::eval(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());
@@ -52,7 +52,7 @@ DEFAULT(Convolution)
 DEFAULT(Copy)
 DEFAULT(Cos)
 DEFAULT(Cosh)
-DEFAULT_MULTI(CustomVJP)
+DEFAULT_MULTI(CustomTransforms)
 DEFAULT_MULTI(Depends)
 DEFAULT(Divide)
 DEFAULT(NumberOfElements)
|
@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CustomVJP::eval_gpu(
|
void CustomTransforms::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
eval(inputs, outputs);
|
eval(inputs, outputs);
|
||||||
|
@@ -42,7 +42,7 @@ NO_CPU(Convolution)
 NO_CPU(Copy)
 NO_CPU(Cos)
 NO_CPU(Cosh)
-NO_CPU_MULTI(CustomVJP)
+NO_CPU_MULTI(CustomTransforms)
 NO_CPU_MULTI(Depends)
 NO_CPU(Divide)
 NO_CPU_MULTI(DivMod)
|
@ -43,7 +43,7 @@ NO_GPU(Convolution)
|
|||||||
NO_GPU(Copy)
|
NO_GPU(Copy)
|
||||||
NO_GPU(Cos)
|
NO_GPU(Cos)
|
||||||
NO_GPU(Cosh)
|
NO_GPU(Cosh)
|
||||||
NO_GPU_MULTI(CustomVJP)
|
NO_GPU_MULTI(CustomTransforms)
|
||||||
NO_GPU_MULTI(Depends)
|
NO_GPU_MULTI(Depends)
|
||||||
NO_GPU(Divide)
|
NO_GPU(Divide)
|
||||||
NO_GPU_MULTI(DivMod)
|
NO_GPU_MULTI(DivMod)
|
||||||
mlx/fast.cpp (17 changed lines)
@@ -18,7 +18,7 @@ std::vector<array> Custom::vjp(
   auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
   std::vector<array> vjp_outs;
   for (int i = 0, j = 0; i < vjps.size(); ++i) {
-    if (i < argnums.size() && i == argnums[j]) {
+    if (j < argnums.size() && i == argnums[j]) {
       vjp_outs.push_back(vjps[i]);
       j++;
     }
@@ -30,15 +30,16 @@ std::vector<array> Custom::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
-  std::vector<array> jvp_outs;
-  for (int i = 0, j = 0; i < jvps.size(); ++i) {
-    if (i < argnums.size() && i == argnums[j]) {
-      jvp_outs.push_back(jvps[i]);
-      j++;
+  std::vector<array> all_tangents;
+  for (int i = 0, j = 0; i < primals.size(); i++) {
+    if (j < argnums.size() && i == argnums[j]) {
+      all_tangents.emplace_back(tangents[j++]);
+    } else {
+      all_tangents.emplace_back(zeros_like(primals[i]));
     }
   }
-  return jvp_outs;
+  auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);
+  return jvps;
 }
 
 std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
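The rewritten `Custom::jvp` no longer sub-selects results by `argnums`; instead it pads the tangent list with zeros for the arguments that carry no tangent, because the underlying `jvp` expects exactly one tangent per primal. A minimal Python sketch of that padding logic (the helper name is made up for illustration):

    import mlx.core as mx

    def pad_tangents(primals, tangents, argnums):
        # One tangent per primal: use the provided tangent for arguments in
        # argnums and a zero tangent for everything else.
        full, j = [], 0
        for i, p in enumerate(primals):
            if j < len(argnums) and i == argnums[j]:
                full.append(tangents[j])
                j += 1
            else:
                full.append(mx.zeros_like(p))
        return full

    primals = [mx.ones(3), mx.full(3, 2.0)]
    tangents = pad_tangents(primals, [mx.ones(3)], argnums=[0])
    outs, jvps = mx.jvp(lambda a, b: [a * b], primals, tangents)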
@@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
   return {{cosh(inputs[0], stream())}, axes};
 }
 
-std::vector<array> CustomVJP::vjp(
+std::vector<array> CustomTransforms::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
     const std::vector<array>& outputs) {
-  std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
+  // Extract the inputs to the VJP function
+  std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
+
+  // Compute all the vjps
   auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
   for (const auto& cot : cotangents) {
     all_vjps.emplace_back(cot);
   }
 
+  // Select the vjps requested
   std::vector<array> vjps;
   vjps.reserve(argnums.size());
   for (auto arg : argnums) {
@@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
   return vjps;
 }
 
+std::vector<array> CustomTransforms::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  // Extract the inputs to the JVP function
+  std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
+
+  // Compute the jvps
+  return jvp_fun_(inputs, tangents, argnums);
+}
+
+std::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(
+    const std::vector<array>& inputs_,
+    const std::vector<int>& axes_) {
+  // Extract the inputs to the vmap function
+  std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);
+  std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);
+  return vmap_fun_(inputs, axes);
+}
+
 std::vector<array> Depends::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
@@ -729,37 +729,56 @@ class Cosh : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
-class CustomVJP : public Primitive {
+class CustomTransforms : public Primitive {
  public:
-  explicit CustomVJP(
+  explicit CustomTransforms(
       Stream stream,
+      int num_outputs,
       std::function<std::vector<array>(
           const std::vector<array>&,
           const std::vector<array>&,
-          const std::vector<array>&)> fun)
-      : Primitive(stream), vjp_fun_(std::move(fun)) {}
+          const std::vector<array>&)> vjp,
+      std::function<std::vector<array>(
+          const std::vector<array>&,
+          const std::vector<array>&,
+          const std::vector<int>&)> jvp,
+      std::function<std::pair<std::vector<array>, std::vector<int>>(
+          const std::vector<array>&,
+          const std::vector<int>&)> vmap)
+      : Primitive(stream),
+        num_outputs_(num_outputs),
+        vjp_fun_(std::move(vjp)),
+        jvp_fun_(std::move(jvp)),
+        vmap_fun_(std::move(vmap)) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
 
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotan,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
-
-  DEFINE_PRINT(CustomVJP);
+  DEFINE_GRADS();
+  DEFINE_VMAP();
+  DEFINE_PRINT(CustomTransforms);
 
  private:
   void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 
+  int num_outputs_;
+
   std::function<std::vector<array>(
       const std::vector<array>&,
       const std::vector<array>&,
       const std::vector<array>&)>
       vjp_fun_;
+  std::function<std::vector<array>(
+      const std::vector<array>&,
+      const std::vector<array>&,
+      const std::vector<int>&)>
+      jvp_fun_;
+  std::function<std::pair<std::vector<array>, std::vector<int>>(
+      const std::vector<array>&,
+      const std::vector<int>&)>
+      vmap_fun_;
 };
 
 class Depends : public Primitive {
@@ -33,8 +33,11 @@ class Synchronizer : public Primitive {
 // Initialize the static tracing counter from transforms_impl.h .
 //
 // This is used to implement the in_tracing() function that returns true if we
-// are currently under a function transformation.
+// are currently under a function transformation and the retain_graph()
+// function which returns true if we are forced to retain the graph during
+// evaluation.
 int detail::InTracing::tracing_counter{0};
+int detail::RetainGraph::tracing_counter{0};
 
 array eval_impl(std::vector<array> outputs, bool async) {
   std::queue<array> tape;
@@ -331,7 +334,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
     }
   }
 
-  auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
+  std::vector<array> vjps;
+  {
+    detail::RetainGraph retain;
+    vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
+  }
   // Accumulate the vector-jacobian products for each input
   for (int i = 0; i < argnums.size(); ++i) {
     auto in_id = a.inputs()[argnums[i]].id();
@@ -778,14 +785,27 @@ std::function<array(const array&)> vmap(
   return [vfun](const array& a) { return vfun({a})[0]; };
 }
 
-std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
+std::function<std::vector<array>(const std::vector<array>&)> custom_function(
     std::function<std::vector<array>(const std::vector<array>&)> fun,
-    std::function<std::vector<array>(
+    std::optional<std::function<std::vector<array>(
        const std::vector<array>&,
        const std::vector<array>&,
-        const std::vector<array>&)> fun_vjp) {
+        const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
+    std::optional<std::function<std::vector<array>(
+        const std::vector<array>&,
+        const std::vector<array>&,
+        const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
+    std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
+        const std::vector<array>&,
+        const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
+  if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
+    return fun;
+  }
+
   return [fun = std::move(fun),
-          fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
+          fun_vjp = std::move(fun_vjp),
+          fun_jvp = std::move(fun_jvp),
+          fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
     // Compute the outputs
     auto outputs = fun(args);
     for (auto& out : outputs) {
@@ -814,11 +834,63 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
     return array::make_arrays(
         std::move(shapes),
         dtypes,
-        std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
+        std::make_shared<CustomTransforms>(
+            to_stream(s),
+            outputs.size(),
+
+            // We use the passed vjp function or compute it from the inputs and
+            // passed cotangents. Note that this may be less efficient than
+            // using `fun` directly because we may not be able to fully reuse
+            // the outputs of the forward pass.
+            fun_vjp.value_or(
+                [fun](auto primals, auto cotangents, auto outputs) {
+                  auto [__, vjps] = vjp(fun, primals, cotangents);
+                  return vjps;
+                }),
+
+            // We use the passed jvp function or compute it from the primals
+            // and tangents. Similarly we can't take full advantage of the
+            // argnums so it is best to use `fun` directly if we don't need a
+            // custom transform.
+            //
+            // TODO: Use stop_gradient to make full use of argnums and not
+            // waste computation.
+            fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
+              std::vector<array> all_tangents;
+              for (int i = 0, j = 0; i < primals.size(); i++) {
+                if (j < argnums.size() && i == argnums[j]) {
+                  all_tangents.emplace_back(tangents[j++]);
+                } else {
+                  all_tangents.emplace_back(zeros_like(primals[i]));
+                }
+              }
+              auto [__, jvps] = jvp(fun, primals, all_tangents);
+              return jvps;
+            }),
+
+            // Same as above, we use the passed vmap function or we compute it
+            // from `fun`. The output axes is selected to be all 0s which again
+            // may be suboptimal but the only thing we can do without any
+            // information for `fun`.
+            fun_vmap.value_or(
+                [fun, out_size = outputs.size()](auto inputs, auto in_axes)
+                    -> std::pair<std::vector<array>, std::vector<int>> {
+                  std::vector<int> out_axes(out_size, 0);
+                  return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
+                })),
         inputs);
   };
 }
+
+std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
+    std::function<std::vector<array>(const std::vector<array>&)> fun,
+    std::function<std::vector<array>(
+        const std::vector<array>&,
+        const std::vector<array>&,
+        const std::vector<array>&)> fun_vjp) {
+  return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);
+}
 
 std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
     std::function<std::vector<array>(const std::vector<array>&)> fun) {
   auto vjp_fun = [fun](
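Because each missing transform falls back to `vjp`, `jvp`, or `vmap` of the wrapped function, overriding only one of them keeps the others working. A small Python sketch in the spirit of the unit test added at the end of this diff (illustrative only):

    import mlx.core as mx

    @mx.custom_function
    def my_exp(x):
        return mx.exp(x)

    @my_exp.vjp
    def my_exp_vjp(x, cotangent, output):
        # Reuse the saved forward output instead of recomputing exp(x).
        return cotangent * output

    dy = mx.grad(my_exp)(mx.array(1.0))   # uses the custom vjp
    ex = mx.vmap(my_exp)(mx.ones(4))      # falls back to the default vmap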
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include <optional>
+
 #include "mlx/array.h"
 
 namespace mlx::core {
@@ -179,8 +181,31 @@ std::function<std::vector<array>(const std::vector<array>&)> vmap(
     const std::vector<int>& out_axes = {});
 
 /**
- * Return the results of calling fun with args but if their vjp is computed it
- * will be computed by fun_vjp.
+ * Redefine the transformations of `fun` according to the provided functions.
+ *
+ * Namely when calling the vjp of `fun` then `fun_vjp` will be called,
+ * `fun_jvp` for the jvp and `fun_vmap` for vmap.
+ *
+ * If any transformation is not provided, then a default one is created by
+ * calling `vjp`, `jvp` and `vmap` on the function directly.
+ */
+std::function<std::vector<array>(const std::vector<array>&)> custom_function(
+    std::function<std::vector<array>(const std::vector<array>&)> fun,
+    std::optional<std::function<std::vector<array>(
+        const std::vector<array>&,
+        const std::vector<array>&,
+        const std::vector<array>&)>> fun_vjp = std::nullopt,
+    std::optional<std::function<std::vector<array>(
+        const std::vector<array>&,
+        const std::vector<array>&,
+        const std::vector<int>&)>> fun_jvp = std::nullopt,
+    std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
+        const std::vector<array>&,
+        const std::vector<int>&)>> fun_vmap = std::nullopt);
+
+/**
+ * Return a function that behaves exactly like `fun` but if the vjp of the
+ * results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` .
  */
 std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
     std::function<std::vector<array>(const std::vector<array>&)> fun,
@@ -50,4 +50,20 @@ struct InTracing {
   static int tracing_counter;
 };
 
+struct RetainGraph {
+  RetainGraph() {
+    tracing_counter++;
+  }
+  ~RetainGraph() {
+    tracing_counter--;
+  }
+
+  static bool retain_graph() {
+    return tracing_counter > 0;
+  }
+
+ private:
+  static int tracing_counter;
+};
+
 } // namespace mlx::core::detail
@@ -593,7 +593,454 @@ class PyCheckpointedFun {
   nb::callable fun_;
 };
 
+/**
+ * PyCustomFunction is the class that implements the python decorator
+ * `mx.custom_function`.
+ *
+ * It implements a callable that instead of simply calling `fun` it creates a
+ * CustomTransforms primitive via the `custom_function` C++ op which allows us
+ * to redefine the vjp, jvp and vmap transformations.
+ *
+ * The implementation is verbose due to explicit handling of the destruction of
+ * various python objects to make sure that there is no double-free and that
+ * all of them are deleted while under GIL.
+ *
+ * Namely, for every one of the functions passed to the C++ `custom_function`
+ * we create a callable struct that holds the following python objects (when
+ * needed).
+ *
+ * - An nb::callable which holds the passed function or transform
+ * - An nb::object holding the input structure, namely the `(args, kwargs)`
+ *   passed to the function in order to be able to recreate the arguments
+ *   from the input arrays.
+ * - A std::shared_ptr<nb::object> holding the output structure, namely the
+ *   structure of the return value of `fun`. It is a shared_ptr so that it
+ *   can be set when the function is called and then used in the `vjp`
+ *   transform. We delete the object only when the shared_ptr is about to be
+ *   deleted (see `output_structure_.use_count() == 1`) to make sure that the
+ *   object is deleted under GIL.
+ */
+class PyCustomFunction {
+ public:
+  PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
+  ~PyCustomFunction() {
+    nb::gil_scoped_acquire gil;
+
+    fun_.release().dec_ref();
+    if (vjp_fun_.has_value()) {
+      (*vjp_fun_).release().dec_ref();
+    }
+    if (jvp_fun_.has_value()) {
+      (*jvp_fun_).release().dec_ref();
+    }
+    if (vmap_fun_.has_value()) {
+      (*vmap_fun_).release().dec_ref();
+    }
+  }
+
+  struct InnerFunction {
+    nb::callable fun_;
+    nb::object input_structure_;
+    std::shared_ptr<nb::object> output_structure_;
+
+    InnerFunction(
+        nb::callable fun,
+        nb::object input_structure,
+        std::shared_ptr<nb::object> output_structure)
+        : fun_(std::move(fun)),
+          input_structure_(std::move(input_structure)),
+          output_structure_(std::move(output_structure)) {}
+    ~InnerFunction() {
+      nb::gil_scoped_acquire gil;
+
+      fun_.release().dec_ref();
+      input_structure_.release().dec_ref();
+      if (output_structure_.use_count() == 1) {
+        output_structure_->release().dec_ref();
+      }
+    }
+
+    std::vector<array> operator()(const std::vector<array>& inputs) {
+      nb::gil_scoped_acquire gil;
+
+      auto new_inputs = nb::cast<nb::tuple>(
+          tree_unflatten_from_structure(input_structure_, inputs));
+      std::vector<array> outputs;
+      std::tie(outputs, *output_structure_) =
+          tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
+      return outputs;
+    }
+  };
+
+  struct InnerVJPFunction {
+    nb::callable vjp_fun_;
+    nb::object input_structure_;
+    std::shared_ptr<nb::object> output_structure_;
+
+    InnerVJPFunction(
+        nb::callable vjp_fun,
+        nb::object input_structure,
+        std::shared_ptr<nb::object> output_structure)
+        : vjp_fun_(std::move(vjp_fun)),
+          input_structure_(std::move(input_structure)),
+          output_structure_(std::move(output_structure)) {}
+    ~InnerVJPFunction() {
+      nb::gil_scoped_acquire gil;
+
+      vjp_fun_.release().dec_ref();
+      input_structure_.release().dec_ref();
+      if (output_structure_.use_count() == 1) {
+        output_structure_->release().dec_ref();
+      }
+    }
+
+    std::vector<array> operator()(
+        const std::vector<array>& primals,
+        const std::vector<array>& cotangents,
+        const std::vector<array>& outputs) {
+      nb::gil_scoped_acquire gil;
+
+      auto new_inputs = nb::cast<nb::tuple>(
+          tree_unflatten_from_structure(input_structure_, primals));
+      auto args = nb::cast<nb::tuple>(new_inputs[0]);
+      auto new_cotangents =
+          tree_unflatten_from_structure(*output_structure_, cotangents);
+      auto new_outputs =
+          tree_unflatten_from_structure(*output_structure_, outputs);
+
+      if (args.size() == 1) {
+        return tree_flatten(
+            vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]),
+            false);
+      } else {
+        return tree_flatten(
+            vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]),
+            false);
+      }
+    }
+  };
+
+  struct InnerJVPFunction {
+    nb::callable jvp_fun_;
+    nb::object input_structure_;
+
+    InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure)
+        : jvp_fun_(std::move(jvp_fun)),
+          input_structure_(std::move(input_structure)) {}
+    ~InnerJVPFunction() {
+      nb::gil_scoped_acquire gil;
+
+      jvp_fun_.release().dec_ref();
+      input_structure_.release().dec_ref();
+    }
+
+    std::vector<array> operator()(
+        const std::vector<array>& primals,
+        const std::vector<array>& tangents,
+        const std::vector<int>& argnums) {
+      nb::gil_scoped_acquire gil;
+
+      auto new_inputs = nb::cast<nb::tuple>(
+          tree_unflatten_from_structure(input_structure_, primals));
+      auto args = nb::cast<nb::tuple>(new_inputs[0]);
+      auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
+      if (kwargs.size() > 0) {
+        throw std::invalid_argument(
+            "[custom jvp] Function should only accept positional arguments");
+      }
+
+      // Make a new pytree which has tangents or None when a tangent is not
+      // available.
+      std::vector<bool> have_tangents(primals.size(), false);
+      for (auto arg : argnums) {
+        have_tangents[arg] = true;
+      }
+      int array_index = 0;
+      int tangent_index = 0;
+      auto new_tangents =
+          nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
+            if (nb::isinstance<array>(element) &&
+                have_tangents[array_index++]) {
+              return nb::cast(tangents[tangent_index++]);
+            } else {
+              return nb::none();
+            }
+          }));
+
+      if (args.size() == 1) {
+        return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false);
+      } else {
+        return tree_flatten(jvp_fun_(args, new_tangents), false);
+      }
+    }
+  };
+
+  struct InnerVmapFunction {
+    nb::callable vmap_fun_;
+    nb::object input_structure_;
+
+    InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure)
+        : vmap_fun_(std::move(vmap_fun)),
+          input_structure_(std::move(input_structure)) {}
+    ~InnerVmapFunction() {
+      nb::gil_scoped_acquire gil;
+
+      vmap_fun_.release().dec_ref();
+      input_structure_.release().dec_ref();
+    }
+
+    std::pair<std::vector<array>, std::vector<int>> operator()(
+        const std::vector<array>& inputs,
+        const std::vector<int>& axes) {
+      nb::gil_scoped_acquire gil;
+
+      auto new_inputs = nb::cast<nb::tuple>(
+          tree_unflatten_from_structure(input_structure_, inputs));
+      auto args = nb::cast<nb::tuple>(new_inputs[0]);
+      auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
+      if (kwargs.size() > 0) {
+        throw std::invalid_argument(
+            "[custom vmap] Function should only accept positional arguments");
+      }
+
+      int arr_index = 0;
+      auto new_axes =
+          nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
+            int axis = axes[arr_index++];
+            if (nb::isinstance<array>(element) && axis >= 0) {
+              return nb::cast(axis);
+            } else {
+              return nb::none();
+            }
+          }));
+
+      nb::object result;
+      if (args.size() == 1) {
+        result = vmap_fun_(args[0], new_axes[0]);
+      } else {
+        result = vmap_fun_(args, new_axes);
+      }
+
+      if (!nb::isinstance<nb::tuple>(result)) {
+        throw std::invalid_argument(
+            "[custom vmap] Vmap function should return a tuple with 2 items.");
+      }
+      nb::tuple result_tuple = nb::cast<nb::tuple>(result);
+      if (result_tuple.size() != 2) {
+        throw std::invalid_argument(
+            "[custom vmap] Vmap function should return a tuple with 2 items.");
+      }
+
+      std::vector<array> outputs;
+      std::vector<int> output_axes;
+      tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
+        if (nb::isinstance<array>(objects[0])) {
+          outputs.push_back(nb::cast<array>(objects[0]));
+          output_axes.push_back(
+              objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
+        }
+      });
+
+      return {outputs, output_axes};
+    }
+  };
+
+  nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
+    if (!vjp_fun_.has_value() && !jvp_fun_.has_value() &&
+        !vmap_fun_.has_value()) {
+      return fun_(*args, **kwargs);
+    }
+
+    // Extract the inputs and their structure in capturable vars
+    std::vector<array> input_arrays;
+    nb::object input_structure;
+    auto full_args = nb::make_tuple(args, kwargs);
+    std::tie(input_arrays, input_structure) =
+        tree_flatten_with_structure(full_args, false);
+
+    // The output structure will be stored here to be used in the custom vjp
+    // function
+    auto output_structure = std::make_shared<nb::object>();
+
+    // Make a function that calls fun_ in the forward pass and vjp_ in the
+    // backward pass. Then call it immediately and return the results.
+    auto f = custom_function(
+        InnerFunction(fun_, input_structure, output_structure),
+        make_vjp_function(input_structure, output_structure),
+        make_jvp_function(input_structure),
+        make_vmap_function(input_structure));
+
+    auto outputs = f(input_arrays);
+    return tree_unflatten_from_structure(*output_structure, outputs);
+  }
+
+  PyCustomFunction& set_vjp(nb::callable vjp_fun) {
+    vjp_fun_ = vjp_fun;
+    return *this;
+  }
+
+  PyCustomFunction& set_jvp(nb::callable jvp_fun) {
+    jvp_fun_ = jvp_fun;
+    return *this;
+  }
+
+  PyCustomFunction& set_vmap(nb::callable vmap_fun) {
+    vmap_fun_ = vmap_fun;
+    return *this;
+  }
+
+ private:
+  std::optional<InnerVJPFunction> make_vjp_function(
+      nb::object input_structure,
+      std::shared_ptr<nb::object> output_structure) {
+    if (!vjp_fun_.has_value()) {
+      return std::nullopt;
+    }
+
+    return InnerVJPFunction(*vjp_fun_, input_structure, output_structure);
+  }
+
+  std::optional<InnerJVPFunction> make_jvp_function(
+      nb::object input_structure) {
+    if (!jvp_fun_.has_value()) {
+      return std::nullopt;
+    }
+
+    return InnerJVPFunction(*jvp_fun_, input_structure);
+  }
+
+  std::optional<InnerVmapFunction> make_vmap_function(
+      nb::object input_structure) {
+    if (!vmap_fun_.has_value()) {
+      return std::nullopt;
+    }
+
+    return InnerVmapFunction(*vmap_fun_, input_structure);
+  }
+
+  nb::callable fun_;
+  std::optional<nb::callable> vjp_fun_;
+  std::optional<nb::callable> jvp_fun_;
+  std::optional<nb::callable> vmap_fun_;
+};
+
 void init_transforms(nb::module_& m) {
+  nb::class_<PyCustomFunction>(
+      m,
+      "custom_function",
+      R"pbdoc(
+      Set up a function for custom gradient and vmap definitions.
+
+      This class is meant to be used as a function decorator. Instances are
+      callables that behave identically to the wrapped function. However, when
+      a function transformation is used (e.g. computing gradients using
+      :func:`value_and_grad`) then the functions defined via :method:`vjp`,
+      :method:`jvp` and :method:`vmap` are used instead of the default
+      transformation.
+
+      Note, all custom transformations are optional. Undefined transformations
+      fall back to the default behaviour.
+
+      Example usage:
+
+      .. code-block:: python
+
+          import mlx.core as mx
+
+          @mx.custom_function
+          def f(x, y):
+              return mx.sin(x) * y
+
+          @f.vjp
+          def f_vjp(primals, cotangent, output):
+              x, y = primals
+              return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)
+
+          @f.jvp
+          def f_jvp(primals, tangents):
+              x, y = primals
+              dx, dy = tangents
+              return dx * mx.cos(x) * y + dy * mx.sin(x)
+
+          @f.vmap
+          def f_vmap(inputs, axes):
+              x, y = inputs
+              ax, ay = axes
+              if ay != ax and ax is not None:
+                  y = y.swapaxes(ay, ax)
+              return mx.sin(x) * y, (ax or ay)
+      )pbdoc")
+      .def(
+          nb::init<nb::callable>(),
+          "f"_a,
+          nb::sig("def __init__(self, f: callable)"))
+      .def("__call__", &PyCustomFunction::call_impl)
+      .def(
+          "vjp",
+          &PyCustomFunction::set_vjp,
+          "f"_a,
+          nb::sig("def vjp(self, f_vjp: callable)"),
+          R"pbdoc(
+            Define a custom vjp for the wrapped function.
+
+            The vjp function takes three arguments:
+
+            - *primals*: A pytree that contains all the positional arguments to
+              the function. It could be a single array, a tuple of arrays or a
+              full blown tuple of dicts of arrays etc.
+            - *cotangents*: A pytree that matches the structure of the output
+              but contains the cotangents (usually the gradients of the loss
+              function with respect to the outputs).
+            - *outputs*: The outputs of the function to be used to avoid
+              recomputing them for the gradient computation.
+
+            The vjp function should return the same pytree structure as the
+            primals but containing the corresponding computed cotangents.
+          )pbdoc")
+      .def(
+          "jvp",
+          &PyCustomFunction::set_jvp,
+          "f"_a,
+          nb::sig("def jvp(self, f_jvp: callable)"),
+          R"pbdoc(
+            Define a custom jvp for the wrapped function.
+
+            The jvp function takes two arguments:
+
+            - *primals*: A pytree that contains all the positional arguments to
+              the function. It could be a single array, a tuple of arrays or a
+              full blown tuple of dicts of arrays etc.
+            - *tangents*: A pytree that matches the structure of the inputs but
+              instead contains the gradients wrt to each input. Tangents could
+              be ``None`` if some inputs don't have an associated gradient.
+
+            The jvp function should return the same pytree structure as the
+            outputs of the function but containing the tangents.
+          )pbdoc")
+      .def(
+          "vmap",
+          &PyCustomFunction::set_vmap,
+          "f"_a,
+          nb::sig("def vmap(self, f_vmap: callable)"),
+          R"pbdoc(
+            Define a custom vectorization transformation for the wrapped function.
+
+            The vmap function takes two arguments:
+
+            - *inputs*: A pytree that contains all the positional arguments to
+              the function. It could be a single array, a tuple of arrays or a
+              full blown tuple of dicts of arrays etc.
+            - *axes*: A pytree that matches the structure of the inputs but
+              instead contains the vectorization axis for each input or
+              ``None`` if an input is not vectorized.
+
+            The vmap function should return the outputs of the original
+            function but vectorized over the provided axes. It should also
+            return a pytree with the vectorization axes of each output. If some
+            outputs are no longer vectorized, then their vectorization axis
+            should be ``None``.
+          )pbdoc");
+
   m.def(
       "eval",
       [](const nb::args& args) {
@@ -888,8 +1335,10 @@ void init_transforms(nb::module_& m) {
           const nb::object& outputs,
           bool shapeless) {
         // Try to get the name
-        auto n = fun.attr("__name__");
-        auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);
+        auto n =
+            nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
+        auto name = n.is_none() ? "compiled"
+                                : nb::cast<std::string>(fun.attr("__name__"));
 
         // Try to get the signature
         std::ostringstream sig;
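The name lookup in `compile` now tolerates callables without a `__name__` attribute (for example `functools.partial` objects) instead of raising. A rough Python equivalent of the new behaviour (the helper name is made up for illustration):

    import functools

    def compiled_name(fun):
        # Fall back to "compiled" when the callable has no __name__.
        n = getattr(fun, "__name__", None)
        return "compiled" if n is None else n

    compiled_name(functools.partial(pow, 2))  # -> "compiled"
    compiled_name(pow)                        # -> "pow"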
@@ -496,6 +496,90 @@ class TestAutograd(mlx_tests.MLXTestCase):
         expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
         self.assertTrue(mx.allclose(out, expected))
 
+    def test_custom_function(self):
+        # Make a custom function
+        my_exp = mx.custom_function(mx.exp)
+
+        # Ensure everything works
+        dy = mx.grad(my_exp)(mx.array(1.0))
+        self.assertTrue(mx.allclose(dy, mx.exp(mx.array(1.0))))
+        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
+        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
+        self.assertTrue(mx.allclose(ex, dex))
+        ex = mx.vmap(my_exp)(mx.ones(10))
+        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))
+
+        # Ensure that the vjp is being overridden but everything else still
+        # works.
+        @my_exp.vjp
+        def my_exp_vjp(x, dx, ex):
+            return mx.ones_like(x) * 42
+
+        dy = mx.grad(my_exp)(mx.array(1.0))
+        self.assertTrue(mx.allclose(dy, mx.array(42.0)))
+        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
+        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
+        self.assertTrue(mx.allclose(ex, dex))
+        ex = mx.vmap(my_exp)(mx.ones(10))
+        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))
+
+        # Ensure that setting the jvp and vmap also works.
+        @my_exp.jvp
+        def my_exp_jvp(x, dx):
+            return mx.ones_like(x) * 7 * dx
+
+        @my_exp.vmap
+        def my_exp_vmap(x, axis):
+            return mx.ones_like(x) * 3, axis
+
+        dy = mx.grad(my_exp)(mx.array(1.0))
+        self.assertTrue(mx.allclose(dy, mx.array(42.0)))
+        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
+        self.assertTrue(mx.allclose(dex, mx.array(7.0)))
+        self.assertTrue(mx.allclose(ex, mx.exp(mx.array(1.0))))
+        ex = mx.vmap(my_exp)(mx.ones(10))
+        self.assertTrue(mx.allclose(ex, 3 * mx.ones(10)))
+
+        # Test pytrees
+        @mx.custom_function
+        def my_double(params):
+            return {"out": 2 * params["x"] * params["y"]}
+
+        dy = mx.grad(lambda p: my_double(p)["out"].sum())(
+            {"x": mx.ones(2), "y": mx.ones(2)}
+        )
+        self.assertTrue(mx.allclose(dy["x"], mx.ones(2) * 2))
+        self.assertTrue(mx.allclose(dy["y"], mx.ones(2) * 2))
+
+        @my_double.vjp
+        def random_grads(primals, cotangents, outputs):
+            return {"x": mx.zeros_like(primals["x"]), "y": mx.ones_like(primals["y"])}
+
+        dy = mx.grad(lambda p: my_double(p)["out"].sum())(
+            {"x": mx.ones(2), "y": mx.ones(2)}
+        )
+        self.assertTrue(mx.allclose(dy["x"], mx.zeros(2)))
+        self.assertTrue(mx.allclose(dy["y"], mx.ones(2)))
+
+        def outer_f(a, b):
+            return my_double({"x": a, "y": b})["out"]
+
+        inputs = [mx.random.normal(shape=(2,)) for i in range(2)]
+        tans = [mx.random.normal(shape=(2,)) for i in range(2)]
+        out1, dout1 = mx.jvp(outer_f, inputs, tans)
+
+        @my_double.jvp
+        def random_grads(primals, tangents):
+            return {
+                "out": 2 * primals["x"] * tangents["y"]
+                + 2 * primals["y"] * tangents["x"]
+                + 1
+            }
+
+        out2, dout2 = mx.jvp(outer_f, inputs, tans)
+        self.assertTrue(mx.allclose(out1[0], out2[0]))
+        self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))
+
 
 if __name__ == "__main__":
     unittest.main()