mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom transforms (#1246)
This commit is contained in:
committed by
GitHub
parent
a3c287354f
commit
5c1fa64fb0
@@ -33,8 +33,11 @@ class Synchronizer : public Primitive {
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
//
|
||||
// This is used to implement the in_tracing() function the returns true if we
|
||||
// are currently under a function transformation.
|
||||
// are currently under a function transformation and the retain_graph()
|
||||
// function which returns true if we are forced to retain the graph during
|
||||
// evaluation.
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
int detail::RetainGraph::tracing_counter{0};
|
||||
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
std::queue<array> tape;
|
||||
@@ -331,7 +334,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
}
|
||||
}
|
||||
|
||||
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
|
||||
std::vector<array> vjps;
|
||||
{
|
||||
detail::RetainGraph retain;
|
||||
vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
|
||||
}
|
||||
// Accumulate the vector-jacobian products for each input
|
||||
for (int i = 0; i < argnums.size(); ++i) {
|
||||
auto in_id = a.inputs()[argnums[i]].id();
|
||||
@@ -778,14 +785,27 @@ std::function<array(const array&)> vmap(
|
||||
return [vfun](const array& a) { return vfun({a})[0]; };
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::function<std::vector<array>(
|
||||
std::optional<std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)> fun_vjp) {
|
||||
const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
|
||||
std::optional<std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
|
||||
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
|
||||
if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
|
||||
return fun;
|
||||
}
|
||||
|
||||
return [fun = std::move(fun),
|
||||
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
|
||||
fun_vjp = std::move(fun_vjp),
|
||||
fun_jvp = std::move(fun_jvp),
|
||||
fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
|
||||
// Compute the outputs
|
||||
auto outputs = fun(args);
|
||||
for (auto& out : outputs) {
|
||||
@@ -814,11 +834,63 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
return array::make_arrays(
|
||||
std::move(shapes),
|
||||
dtypes,
|
||||
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
|
||||
std::make_shared<CustomTransforms>(
|
||||
to_stream(s),
|
||||
outputs.size(),
|
||||
|
||||
// We use the passed vjp function or compute it from the inputs and
|
||||
// passed cotangents. Note that this may be less efficient than
|
||||
// using `fun` directly because we may not be able to fully reuse
|
||||
// the outputs of the forward pass.
|
||||
fun_vjp.value_or(
|
||||
[fun](auto primals, auto cotangents, auto outputs) {
|
||||
auto [__, vjps] = vjp(fun, primals, cotangents);
|
||||
return vjps;
|
||||
}),
|
||||
|
||||
// We use the passed jvp function or compute it from the primals
|
||||
// and tangents. Similarly we can't take full advantage of the
|
||||
// argnums so it is best to use `fun` directly if we don't need a
|
||||
// custom transform.
|
||||
//
|
||||
// TODO: Use stop_gradient to make full use of argnums and not
|
||||
// waste computation.
|
||||
fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
|
||||
std::vector<array> all_tangents;
|
||||
for (int i = 0, j = 0; i < primals.size(); i++) {
|
||||
if (j < argnums.size() && i == argnums[j]) {
|
||||
all_tangents.emplace_back(tangents[j++]);
|
||||
} else {
|
||||
all_tangents.emplace_back(zeros_like(primals[i]));
|
||||
}
|
||||
}
|
||||
auto [__, jvps] = jvp(fun, primals, all_tangents);
|
||||
return jvps;
|
||||
}),
|
||||
|
||||
// Same as above, we use the passed vmap function or we compute it
|
||||
// from `fun`. The output axes is selected to be all 0s which again
|
||||
// may be suboptimal but the only thing we can do without any
|
||||
// information for `fun`.
|
||||
fun_vmap.value_or(
|
||||
[fun, out_size = outputs.size()](auto inputs, auto in_axes)
|
||||
-> std::pair<std::vector<array>, std::vector<int>> {
|
||||
std::vector<int> out_axes(out_size, 0);
|
||||
return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
|
||||
})),
|
||||
inputs);
|
||||
};
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)> fun_vjp) {
|
||||
return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun) {
|
||||
auto vjp_fun = [fun](
|
||||
|
||||
Reference in New Issue
Block a user