mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
Add jacfwd
This commit is contained in:
parent
e8ac6bd2f5
commit
a7a6c49909
@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
@ -8,7 +9,9 @@
|
||||
#include <stack>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/eval.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/fence.h"
|
||||
@ -1040,4 +1043,85 @@ std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
||||
return custom_vjp(fun, vjp_fun);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> jacfwd(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& primals) {
|
||||
detail::InTracing in_tracing{false, true};
|
||||
auto outputs = fun(primals);
|
||||
|
||||
if (primals.empty()) {
|
||||
return {outputs, {}};
|
||||
}
|
||||
|
||||
size_t total_input_size = 0;
|
||||
std::vector<size_t> input_sizes;
|
||||
for (const auto& p : primals) {
|
||||
size_t size = p.size();
|
||||
total_input_size += size;
|
||||
input_sizes.push_back(size);
|
||||
}
|
||||
|
||||
size_t total_output_size = 0;
|
||||
std::vector<size_t> output_sizes;
|
||||
for (const auto& o : outputs) {
|
||||
size_t size = o.size();
|
||||
total_output_size += size;
|
||||
output_sizes.push_back(size);
|
||||
}
|
||||
|
||||
std::vector<array> jacobian;
|
||||
Stream s = default_stream(default_device());
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
std::vector<int> shape = {
|
||||
static_cast<int>(output_sizes[i]), static_cast<int>(total_input_size)};
|
||||
jacobian.push_back(zeros(shape, outputs[i].dtype(), s));
|
||||
}
|
||||
|
||||
size_t input_offset = 0;
|
||||
for (size_t i = 0; i < primals.size(); ++i) {
|
||||
for (size_t j = 0; j < input_sizes[i]; ++j) {
|
||||
std::vector<array> tangents;
|
||||
for (size_t k = 0; k < primals.size(); ++k) {
|
||||
array t = zeros_like(primals[k], s);
|
||||
if (k == i) {
|
||||
if (primals[i].size() == 1) {
|
||||
t = full({1}, 1.0f, s);
|
||||
} else {
|
||||
t.data<float>()[j] = 1.0f;
|
||||
}
|
||||
}
|
||||
tangents.push_back(t);
|
||||
}
|
||||
|
||||
auto [_, jvp_out] = jvp(fun, primals, tangents);
|
||||
|
||||
size_t output_offset = 0;
|
||||
for (size_t k = 0; k < jvp_out.size(); ++k) {
|
||||
array flat_jvp =
|
||||
reshape(jvp_out[k], {static_cast<int>(output_sizes[k])}, s);
|
||||
float* jvp_data = flat_jvp.data<float>();
|
||||
float* jac_data = jacobian[k].data<float>();
|
||||
|
||||
for (size_t m = 0; m < output_sizes[k]; ++m) {
|
||||
jac_data[m * total_input_size + input_offset] = jvp_data[m];
|
||||
}
|
||||
output_offset += output_sizes[k];
|
||||
}
|
||||
input_offset++;
|
||||
}
|
||||
}
|
||||
|
||||
return {outputs, jacobian};
|
||||
}
|
||||
|
||||
// For scalar functions
|
||||
std::pair<array, array> jacfwd(
|
||||
const std::function<array(const array&)>& fun,
|
||||
const array& primal) {
|
||||
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{fun(inputs[0])};
|
||||
};
|
||||
auto [outputs, jacobian] = jacfwd(vec_fun, {primal});
|
||||
return {outputs[0], jacobian[0]};
|
||||
}
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user