mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
Add jacfwd
This commit is contained in:
parent
e8ac6bd2f5
commit
a7a6c49909
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
|
#include <functional>
|
||||||
#include <future>
|
#include <future>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
@ -8,7 +9,9 @@
|
|||||||
#include <stack>
|
#include <stack>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/cpu/eval.h"
|
#include "mlx/backend/cpu/eval.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
@ -1040,4 +1043,85 @@ std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
|||||||
return custom_vjp(fun, vjp_fun);
|
return custom_vjp(fun, vjp_fun);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> jacfwd(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& primals) {
|
||||||
|
detail::InTracing in_tracing{false, true};
|
||||||
|
auto outputs = fun(primals);
|
||||||
|
|
||||||
|
if (primals.empty()) {
|
||||||
|
return {outputs, {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t total_input_size = 0;
|
||||||
|
std::vector<size_t> input_sizes;
|
||||||
|
for (const auto& p : primals) {
|
||||||
|
size_t size = p.size();
|
||||||
|
total_input_size += size;
|
||||||
|
input_sizes.push_back(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t total_output_size = 0;
|
||||||
|
std::vector<size_t> output_sizes;
|
||||||
|
for (const auto& o : outputs) {
|
||||||
|
size_t size = o.size();
|
||||||
|
total_output_size += size;
|
||||||
|
output_sizes.push_back(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> jacobian;
|
||||||
|
Stream s = default_stream(default_device());
|
||||||
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
std::vector<int> shape = {
|
||||||
|
static_cast<int>(output_sizes[i]), static_cast<int>(total_input_size)};
|
||||||
|
jacobian.push_back(zeros(shape, outputs[i].dtype(), s));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t input_offset = 0;
|
||||||
|
for (size_t i = 0; i < primals.size(); ++i) {
|
||||||
|
for (size_t j = 0; j < input_sizes[i]; ++j) {
|
||||||
|
std::vector<array> tangents;
|
||||||
|
for (size_t k = 0; k < primals.size(); ++k) {
|
||||||
|
array t = zeros_like(primals[k], s);
|
||||||
|
if (k == i) {
|
||||||
|
if (primals[i].size() == 1) {
|
||||||
|
t = full({1}, 1.0f, s);
|
||||||
|
} else {
|
||||||
|
t.data<float>()[j] = 1.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tangents.push_back(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [_, jvp_out] = jvp(fun, primals, tangents);
|
||||||
|
|
||||||
|
size_t output_offset = 0;
|
||||||
|
for (size_t k = 0; k < jvp_out.size(); ++k) {
|
||||||
|
array flat_jvp =
|
||||||
|
reshape(jvp_out[k], {static_cast<int>(output_sizes[k])}, s);
|
||||||
|
float* jvp_data = flat_jvp.data<float>();
|
||||||
|
float* jac_data = jacobian[k].data<float>();
|
||||||
|
|
||||||
|
for (size_t m = 0; m < output_sizes[k]; ++m) {
|
||||||
|
jac_data[m * total_input_size + input_offset] = jvp_data[m];
|
||||||
|
}
|
||||||
|
output_offset += output_sizes[k];
|
||||||
|
}
|
||||||
|
input_offset++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {outputs, jacobian};
|
||||||
|
}
|
||||||
|
|
||||||
|
// For scalar functions
|
||||||
|
std::pair<array, array> jacfwd(
|
||||||
|
const std::function<array(const array&)>& fun,
|
||||||
|
const array& primal) {
|
||||||
|
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{fun(inputs[0])};
|
||||||
|
};
|
||||||
|
auto [outputs, jacobian] = jacfwd(vec_fun, {primal});
|
||||||
|
return {outputs[0], jacobian[0]};
|
||||||
|
}
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user