mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom transforms (#1246)
This commit is contained in:
committed by
GitHub
parent
a3c287354f
commit
5c1fa64fb0
@@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
|
||||
return {{cosh(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> CustomVJP::vjp(
|
||||
std::vector<array> CustomTransforms::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
|
||||
// Extract the inputs to the VJP function
|
||||
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
|
||||
|
||||
// Compute all the vjps
|
||||
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
|
||||
for (const auto& cot : cotangents) {
|
||||
all_vjps.emplace_back(cot);
|
||||
}
|
||||
|
||||
// Select the vjps requested
|
||||
std::vector<array> vjps;
|
||||
vjps.reserve(argnums.size());
|
||||
for (auto arg : argnums) {
|
||||
@@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> CustomTransforms::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Extract the inputs to the JVP function
|
||||
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
|
||||
|
||||
// Compute the jvps
|
||||
return jvp_fun_(inputs, tangents, argnums);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(
|
||||
const std::vector<array>& inputs_,
|
||||
const std::vector<int>& axes_) {
|
||||
// Extract the inputs to the vmap function
|
||||
std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);
|
||||
std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);
|
||||
return vmap_fun_(inputs, axes);
|
||||
}
|
||||
|
||||
std::vector<array> Depends::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
||||
Reference in New Issue
Block a user