diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index b305257f0..57105a74c 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include #include +#include #include #include #include @@ -8,7 +9,9 @@ #include #include #include +#include +#include "mlx/array.h" #include "mlx/backend/cpu/eval.h" #include "mlx/backend/metal/metal_impl.h" #include "mlx/fence.h" @@ -1040,4 +1043,85 @@ std::function(const std::vector&)> checkpoint( return custom_vjp(fun, vjp_fun); } +std::pair, std::vector> jacfwd( + const std::function(const std::vector&)>& fun, + const std::vector& primals) { + detail::InTracing in_tracing{false, true}; + auto outputs = fun(primals); + + if (primals.empty()) { + return {outputs, {}}; + } + + size_t total_input_size = 0; + std::vector input_sizes; + for (const auto& p : primals) { + size_t size = p.size(); + total_input_size += size; + input_sizes.push_back(size); + } + + size_t total_output_size = 0; + std::vector output_sizes; + for (const auto& o : outputs) { + size_t size = o.size(); + total_output_size += size; + output_sizes.push_back(size); + } + + std::vector jacobian; + Stream s = default_stream(default_device()); + for (size_t i = 0; i < outputs.size(); ++i) { + std::vector shape = { + static_cast(output_sizes[i]), static_cast(total_input_size)}; + jacobian.push_back(zeros(shape, outputs[i].dtype(), s)); + } + + size_t input_offset = 0; + for (size_t i = 0; i < primals.size(); ++i) { + for (size_t j = 0; j < input_sizes[i]; ++j) { + std::vector tangents; + for (size_t k = 0; k < primals.size(); ++k) { + array t = zeros_like(primals[k], s); + if (k == i) { + if (primals[i].size() == 1) { + t = full({1}, 1.0f, s); + } else { + t.data()[j] = 1.0f; + } + } + tangents.push_back(t); + } + + auto [_, jvp_out] = jvp(fun, primals, tangents); + + size_t output_offset = 0; + for (size_t k = 0; k < jvp_out.size(); ++k) { + array flat_jvp = + reshape(jvp_out[k], {static_cast(output_sizes[k])}, s); + float* jvp_data = flat_jvp.data(); + float* jac_data = jacobian[k].data(); + + for (size_t m = 0; m < output_sizes[k]; ++m) { + jac_data[m * total_input_size + input_offset] = jvp_data[m]; + } + output_offset += output_sizes[k]; + } + input_offset++; + } + } + + return {outputs, jacobian}; +} + +// For scalar functions +std::pair jacfwd( + const std::function& fun, + const array& primal) { + auto vec_fun = [fun](const std::vector& inputs) { + return std::vector{fun(inputs[0])}; + }; + auto [outputs, jacobian] = jacfwd(vec_fun, {primal}); + return {outputs[0], jacobian[0]}; +} } // namespace mlx::core