Add jacfwd

paramthakkar123 2025-04-23 18:47:19 +05:30
parent e8ac6bd2f5
commit a7a6c49909


@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <deque>
#include <functional>
#include <future>
#include <numeric>
#include <set>
@@ -8,7 +9,9 @@
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "mlx/array.h"
#include "mlx/backend/cpu/eval.h" #include "mlx/backend/cpu/eval.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/fence.h" #include "mlx/fence.h"
@@ -1040,4 +1043,85 @@ std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
return custom_vjp(fun, vjp_fun);
}
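// Dense forward-mode Jacobian: for each scalar entry of each input, run a
// JVP against a one-hot tangent and store the result as one column of the
// Jacobian. Returns the outputs and one Jacobian matrix per output.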
std::pair<std::vector<array>, std::vector<array>> jacfwd(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals) {
detail::InTracing in_tracing{false, true};
auto outputs = fun(primals);
if (primals.empty()) {
return {outputs, {}};
}
// Flattened element counts per input and output; the Jacobian is indexed
// in these flat units.
size_t total_input_size = 0;
std::vector<size_t> input_sizes;
for (const auto& p : primals) {
size_t size = p.size();
total_input_size += size;
input_sizes.push_back(size);
}
std::vector<size_t> output_sizes;
for (const auto& o : outputs) {
output_sizes.push_back(o.size());
}
std::vector<array> jacobian;
Stream s = default_stream(default_device());
// One [output_size, total_input_size] zero matrix per output.
for (size_t i = 0; i < outputs.size(); ++i) {
std::vector<int> shape = {
static_cast<int>(output_sizes[i]), static_cast<int>(total_input_size)};
jacobian.push_back(zeros(shape, outputs[i].dtype(), s));
}
// Materialize the zero matrices so their buffers can be written into below.
eval(jacobian);
size_t input_offset = 0;
// Loop over every scalar entry of every input; each one-hot tangent
// yields one column of the Jacobian via a JVP.
for (size_t i = 0; i < primals.size(); ++i) {
for (size_t j = 0; j < input_sizes[i]; ++j) {
std::vector<array> tangents;
for (size_t k = 0; k < primals.size(); ++k) {
array t = zeros_like(primals[k], s);
if (k == i) {
if (primals[i].size() == 1) {
// Keep the primal's shape and dtype for a single-element input.
t = ones_like(primals[i], s);
} else {
// Materialize the zero buffer before writing into it directly.
// This assumes a contiguous float32 primal.
t.eval();
t.data<float>()[j] = 1.0f;
}
}
tangents.push_back(t);
}
// The JVP against this one-hot tangent is one Jacobian column.
auto [_, jvp_out] = jvp(fun, primals, tangents);
for (size_t k = 0; k < jvp_out.size(); ++k) {
array flat_jvp =
reshape(jvp_out[k], {static_cast<int>(output_sizes[k])}, s);
flat_jvp.eval();
// Scatter the column into the row-major [output_size, total_input_size]
// matrix; assumes float32 outputs.
float* jvp_data = flat_jvp.data<float>();
float* jac_data = jacobian[k].data<float>();
for (size_t m = 0; m < output_sizes[k]; ++m) {
jac_data[m * total_input_size + input_offset] = jvp_data[m];
}
}
input_offset++;
}
}
return {outputs, jacobian};
}
// Convenience overload for single-input, single-output functions.
std::pair<array, array> jacfwd(
const std::function<array(const array&)>& fun,
const array& primal) {
auto vec_fun = [fun](const std::vector<array>& inputs) {
return std::vector<array>{fun(inputs[0])};
};
auto [outputs, jacobian] = jacfwd(vec_fun, {primal});
return {outputs[0], jacobian[0]};
}
} // namespace mlx::core
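
A minimal usage sketch, not part of this commit: it assumes the new jacfwd overloads are declared in mlx/transforms.h (pulled in via mlx/mlx.h) and that the program links against MLX.

#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // f(x) = x * x elementwise on a 3-vector; the Jacobian is diag(2x).
  auto f = [](const array& x) { return multiply(x, x); };
  array x = array({1.0f, 2.0f, 3.0f});
  auto [y, jac] = jacfwd(f, x); // jac has shape [3, 3]
  std::cout << jac << std::endl; // expect 2, 4, 6 on the diagonal
  return 0;
}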