mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Primitive's VJP takes outputs as input (#475)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
d8fabaa12b
commit
a2bf7693dd
@ -51,29 +51,31 @@ std::tuple<array, array, int> vmap_binary_op(
|
||||
} // namespace
|
||||
|
||||
std::vector<array> Primitive::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&) {
|
||||
throw std::invalid_argument("Primitive's jvp not implemented.");
|
||||
};
|
||||
|
||||
std::vector<array> Primitive::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
throw std::invalid_argument("Primitive's vjp not implemented.");
|
||||
};
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&) {
|
||||
throw std::invalid_argument("Primitive's vmap not implemented.");
|
||||
};
|
||||
|
||||
std::vector<array> Abs::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -106,7 +108,8 @@ std::vector<array> Add::jvp(
|
||||
std::vector<array> Add::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
if (argnums.size() == 1) {
|
||||
return cotangents;
|
||||
} else {
|
||||
@ -131,7 +134,8 @@ bool Arange::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> ArcCos::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -158,7 +162,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcCos::vmap(
|
||||
std::vector<array> ArcCosh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -184,7 +189,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcCosh::vmap(
|
||||
std::vector<array> ArcSin::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -210,7 +216,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcSin::vmap(
|
||||
std::vector<array> ArcSinh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -236,7 +243,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcSinh::vmap(
|
||||
std::vector<array> ArcTan::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -262,7 +270,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
|
||||
std::vector<array> ArcTanh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -322,7 +331,8 @@ bool ArgSort::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> AsType::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
if (cotangents[0].dtype() != dtype_) {
|
||||
throw std::invalid_argument(
|
||||
"[astype] Type of cotangentsgent does not much primal output type.");
|
||||
@ -351,7 +361,8 @@ bool AsType::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> AsStrided::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(argnums.size() == 1);
|
||||
|
||||
// Extract the sizes and cast them to ints
|
||||
@ -395,7 +406,8 @@ bool AsStrided::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Broadcast::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(argnums.size() == 1);
|
||||
|
||||
// Reduce cotangents to the shape of the primal
|
||||
@ -445,7 +457,8 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Ceil::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -469,7 +482,8 @@ std::pair<std::vector<array>, std::vector<int>> Ceil::vmap(
|
||||
std::vector<array> Concatenate::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
auto& cotan = cotangents[0];
|
||||
std::vector<int> start(cotan.ndim(), 0);
|
||||
std::vector<int> stop = cotan.shape();
|
||||
@ -544,7 +558,8 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Convolution::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 2);
|
||||
std::vector<array> grads;
|
||||
|
||||
@ -661,7 +676,8 @@ bool Convolution::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Copy::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return cotangents;
|
||||
@ -687,7 +703,8 @@ std::pair<std::vector<array>, std::vector<int>> Copy::vmap(
|
||||
std::vector<array> Cos::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return {jvp(primals, cotangents, argnums)};
|
||||
}
|
||||
|
||||
@ -712,7 +729,8 @@ std::pair<std::vector<array>, std::vector<int>> Cos::vmap(
|
||||
std::vector<array> Cosh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -736,7 +754,8 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
|
||||
std::vector<array> Divide::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
@ -756,7 +775,8 @@ std::vector<array> Divide::vjp(
|
||||
std::vector<array> DivMod::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
@ -812,7 +832,8 @@ std::pair<std::vector<array>, std::vector<int>> Divide::vmap(
|
||||
std::vector<array> Remainder::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
@ -865,7 +886,8 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
|
||||
std::vector<array> Equal::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
@ -884,7 +906,8 @@ std::vector<array> Equal::jvp(
|
||||
std::vector<array> Erf::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -913,8 +936,13 @@ std::pair<std::vector<array>, std::vector<int>> Erf::vmap(
|
||||
std::vector<array> ErfInv::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto dtype = primals[0].dtype();
|
||||
auto scale =
|
||||
multiply(array(1.0 / M_2_SQRTPI, dtype), cotangents[0], stream());
|
||||
return {
|
||||
multiply(scale, exp(square(outputs[0], stream()), stream()), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> ErfInv::jvp(
|
||||
@ -942,8 +970,9 @@ std::pair<std::vector<array>, std::vector<int>> ErfInv::vmap(
|
||||
std::vector<array> Exp::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
return {multiply(cotangents[0], outputs[0], stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Exp::jvp(
|
||||
@ -997,7 +1026,8 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
|
||||
std::vector<array> FFT::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
auto& in = primals[0];
|
||||
@ -1050,7 +1080,8 @@ std::vector<array> FFT::jvp(
|
||||
std::vector<array> Floor::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -1074,7 +1105,8 @@ std::pair<std::vector<array>, std::vector<int>> Floor::vmap(
|
||||
std::vector<array> Full::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return {multiply(cotangents[0], primals[0], stream())};
|
||||
@ -1155,7 +1187,8 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
|
||||
std::vector<array> Gather::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
if (argnums.size() > 1 || argnums[0] != 0) {
|
||||
throw std::invalid_argument(
|
||||
"[gather] Cannot calculate VJP with respect to indices.");
|
||||
@ -1192,7 +1225,8 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
|
||||
std::vector<array> Greater::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
@ -1218,7 +1252,8 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
|
||||
std::vector<array> GreaterEqual::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
@ -1244,7 +1279,8 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
|
||||
std::vector<array> Less::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
@ -1270,7 +1306,8 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
|
||||
std::vector<array> LessEqual::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
@ -1289,7 +1326,8 @@ std::vector<array> LessEqual::jvp(
|
||||
std::vector<array> Log::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -1325,7 +1363,8 @@ std::pair<std::vector<array>, std::vector<int>> Log::vmap(
|
||||
std::vector<array> Log1p::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -1351,7 +1390,8 @@ std::pair<std::vector<array>, std::vector<int>> Log1p::vmap(
|
||||
std::vector<array> LogicalNot::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -1375,7 +1415,8 @@ std::pair<std::vector<array>, std::vector<int>> LogicalNot::vmap(
|
||||
std::vector<array> LogicalAnd::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 2);
|
||||
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
|
||||
if (argnums.size() > 1) {
|
||||
@ -1406,7 +1447,8 @@ std::pair<std::vector<array>, std::vector<int>> LogicalAnd::vmap(
|
||||
std::vector<array> LogicalOr::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 2);
|
||||
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
|
||||
if (argnums.size() > 1) {
|
||||
@ -1438,7 +1480,8 @@ std::pair<std::vector<array>, std::vector<int>> LogicalOr::vmap(
|
||||
std::vector<array> LogAddExp::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
auto a = primals[0];
|
||||
auto b = primals[1];
|
||||
auto s = sigmoid(subtract(a, b, stream()), stream());
|
||||
@ -1483,7 +1526,8 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
|
||||
std::vector<array> Matmul::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
auto& cotan = cotangents[0];
|
||||
std::vector<int> reorder(cotan.ndim());
|
||||
@ -1506,7 +1550,8 @@ std::vector<array> Matmul::vjp(
|
||||
std::vector<array> Maximum::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
auto& a = primals[0];
|
||||
auto& b = primals[1];
|
||||
std::vector<array> vjps;
|
||||
@ -1547,7 +1592,8 @@ std::pair<std::vector<array>, std::vector<int>> Maximum::vmap(
|
||||
std::vector<array> Minimum::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
auto& a = primals[0];
|
||||
auto& b = primals[1];
|
||||
std::vector<array> vjps;
|
||||
@ -1601,7 +1647,8 @@ std::vector<array> Multiply::jvp(
|
||||
std::vector<array> Multiply::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
|
||||
@ -1619,7 +1666,8 @@ std::pair<std::vector<array>, std::vector<int>> Multiply::vmap(
|
||||
std::vector<array> Negative::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -1650,7 +1698,8 @@ std::pair<std::vector<array>, std::vector<int>> NotEqual::vmap(
|
||||
std::vector<array> NotEqual::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
@ -1669,7 +1718,8 @@ std::vector<array> NotEqual::jvp(
|
||||
std::vector<array> Pad::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(argnums.size() == 1 && argnums[0] == 0);
|
||||
|
||||
auto& cotan = cotangents[0];
|
||||
@ -1717,7 +1767,8 @@ bool Pad::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Partition::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -1749,22 +1800,15 @@ bool Partition::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Power::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
vjps.push_back(multiply(
|
||||
power(
|
||||
primals[0],
|
||||
subtract(primals[1], array(1, primals[0].dtype()), stream()),
|
||||
stream()),
|
||||
primals[1],
|
||||
stream()));
|
||||
outputs[0], divide(primals[1], primals[0], stream()), stream()));
|
||||
} else {
|
||||
vjps.push_back(multiply(
|
||||
log(primals[0], stream()),
|
||||
power(primals[0], primals[1], stream()),
|
||||
stream()));
|
||||
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
|
||||
}
|
||||
vjps.back() = multiply(cotangents[0], vjps.back(), stream());
|
||||
}
|
||||
@ -1775,12 +1819,13 @@ std::vector<array> Power::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto jvp = vjp(primals, {tangents[0]}, {argnums[0]});
|
||||
auto output = power(primals[0], primals[1], stream());
|
||||
auto grads = vjp(primals, tangents, argnums, {output});
|
||||
if (argnums.size() > 1) {
|
||||
jvp[0] =
|
||||
add(jvp[0], vjp(primals, {tangents[1]}, {argnums[1]})[0], stream());
|
||||
return {add(grads[0], grads[1], stream())};
|
||||
} else {
|
||||
return grads;
|
||||
}
|
||||
return jvp;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Power::vmap(
|
||||
@ -1799,7 +1844,8 @@ std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
|
||||
std::vector<array> QuantizedMatmul::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
|
||||
// We rely on the fact that w is always 2D so transpose is simple
|
||||
@ -1902,7 +1948,8 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
|
||||
std::vector<array> Reshape::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
assert(argnums[0] == 0);
|
||||
@ -1927,7 +1974,8 @@ bool Reshape::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Reduce::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto in = primals[0];
|
||||
|
||||
std::vector<int> shape = in.shape();
|
||||
@ -1997,15 +2045,10 @@ std::vector<array> Reduce::vjp(
|
||||
}
|
||||
|
||||
} else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
|
||||
array (*op)(const array&, const std::vector<int>&, bool, StreamOrDevice);
|
||||
|
||||
if (reduce_type_ == Reduce::Min) {
|
||||
op = min;
|
||||
} else {
|
||||
op = max;
|
||||
auto out = outputs[0];
|
||||
if (out.ndim() != in.ndim()) {
|
||||
out = expand_dims(out, axes_, stream());
|
||||
}
|
||||
|
||||
auto out = op(in, axes_, true, stream());
|
||||
auto mask = equal(in, out, stream());
|
||||
auto normalizer = sum(mask, axes_, true, stream());
|
||||
auto cotan_reshape = reshape(cotan, shape, stream());
|
||||
@ -2032,7 +2075,8 @@ bool Reduce::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Round::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2076,7 +2120,8 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
|
||||
std::vector<array> Scan::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums[0] == 0);
|
||||
|
||||
@ -2084,7 +2129,7 @@ std::vector<array> Scan::vjp(
|
||||
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
|
||||
} else if (reduce_type_ == Scan::Prod) {
|
||||
// TODO: Make it numerically stable when we introduce where()
|
||||
auto prod = cumprod(primals[0], axis_, reverse_, inclusive_, stream());
|
||||
auto prod = outputs[0];
|
||||
auto partial_grads = multiply(prod, cotangents[0], stream());
|
||||
auto accum_grads =
|
||||
cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
|
||||
@ -2125,7 +2170,8 @@ bool Scatter::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Scatter::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
case Scatter::Sum:
|
||||
@ -2137,23 +2183,11 @@ std::vector<array> Scatter::vjp(
|
||||
"[scatter] VJP not implemented for scatter_prod");
|
||||
}
|
||||
|
||||
const array& result = outputs[0];
|
||||
const array& values = primals[0];
|
||||
const array& updates = primals.back();
|
||||
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
|
||||
|
||||
// Store result of scatter if needed for reuse in vjp
|
||||
auto get_result = [&]() {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::Max:
|
||||
return scatter_max(values, indices, updates, axes_, stream());
|
||||
case Scatter::Min:
|
||||
return scatter_min(values, indices, updates, axes_, stream());
|
||||
default:
|
||||
return array({});
|
||||
}
|
||||
};
|
||||
array result = get_result();
|
||||
|
||||
std::vector<array> vjps;
|
||||
for (auto num : argnums) {
|
||||
// Gradient wrt to the input array
|
||||
@ -2232,8 +2266,12 @@ std::vector<array> Scatter::jvp(
|
||||
std::vector<array> Sigmoid::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto& s = outputs[0];
|
||||
auto sprime =
|
||||
multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream());
|
||||
return {multiply(cotangents[0], sprime, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Sigmoid::jvp(
|
||||
@ -2259,7 +2297,8 @@ std::pair<std::vector<array>, std::vector<int>> Sigmoid::vmap(
|
||||
std::vector<array> Sign::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2283,7 +2322,8 @@ std::pair<std::vector<array>, std::vector<int>> Sign::vmap(
|
||||
std::vector<array> Sin::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2307,7 +2347,8 @@ std::pair<std::vector<array>, std::vector<int>> Sin::vmap(
|
||||
std::vector<array> Sinh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2345,7 +2386,8 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
|
||||
std::vector<array> Slice::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
// Check inputs
|
||||
assert(primals.size() == 1);
|
||||
|
||||
@ -2444,8 +2486,15 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
|
||||
std::vector<array> Softmax::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
auto& s = outputs[0];
|
||||
auto sv = multiply(s, cotangents[0], stream());
|
||||
return {subtract(
|
||||
sv,
|
||||
multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
|
||||
}
|
||||
|
||||
std::vector<array> Softmax::jvp(
|
||||
@ -2473,7 +2522,8 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
|
||||
std::vector<array> Sort::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2503,7 +2553,8 @@ std::pair<std::vector<array>, std::vector<int>> Split::vmap(
|
||||
std::vector<array> Split::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return {concatenate(cotangents, axis_, stream())};
|
||||
}
|
||||
|
||||
@ -2522,7 +2573,8 @@ bool Split::is_equivalent(const Primitive& other) const {
|
||||
std::vector<array> Square::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2549,29 +2601,34 @@ std::pair<std::vector<array>, std::vector<int>> Square::vmap(
|
||||
std::vector<array> Sqrt::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
auto dtype = primals[0].dtype();
|
||||
if (recip_) {
|
||||
auto one_over_x_root_x = divide(outputs[0], primals[0], stream());
|
||||
return {multiply(
|
||||
multiply(array(-0.5, dtype), cotangents[0], stream()),
|
||||
one_over_x_root_x,
|
||||
stream())};
|
||||
} else {
|
||||
return {divide(
|
||||
multiply(array(0.5, dtype), cotangents[0], stream()),
|
||||
outputs[0],
|
||||
stream())};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> Sqrt::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(tangents.size() == 1);
|
||||
auto dtype = primals[0].dtype();
|
||||
if (recip_) {
|
||||
auto one_over_x_root_x =
|
||||
divide(rsqrt(primals[0], stream()), primals[0], stream());
|
||||
return {multiply(
|
||||
multiply(array(-0.5, dtype), tangents[0], stream()),
|
||||
one_over_x_root_x,
|
||||
stream())};
|
||||
return vjp(primals, tangents, argnums, {rsqrt(primals[0], stream())});
|
||||
} else {
|
||||
return vjp(primals, tangents, argnums, {sqrt(primals[0], stream())});
|
||||
}
|
||||
return {divide(
|
||||
multiply(array(0.5, dtype), tangents[0], stream()),
|
||||
sqrt(primals[0], stream()),
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Sqrt::vmap(
|
||||
@ -2599,7 +2656,8 @@ std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
|
||||
std::vector<array> Subtract::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto vjp = cotangents[0];
|
||||
@ -2636,7 +2694,8 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
|
||||
std::vector<array> Tan::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2661,7 +2720,8 @@ std::pair<std::vector<array>, std::vector<int>> Tan::vmap(
|
||||
std::vector<array> Tanh::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
}
|
||||
|
||||
@ -2686,7 +2746,8 @@ std::pair<std::vector<array>, std::vector<int>> Tanh::vmap(
|
||||
std::vector<array> Transpose::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
std::vector<int> iaxes(axes_.size());
|
||||
|
@ -21,7 +21,8 @@
|
||||
std::vector<array> vjp( \
|
||||
const std::vector<array>& primals, \
|
||||
const std::vector<array>& cotangents, \
|
||||
const std::vector<int>& argnums) override;
|
||||
const std::vector<int>& argnums, \
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
#define DEFINE_PRINT(PRIMITIVE) \
|
||||
void print(std::ostream& os) override { \
|
||||
@ -78,7 +79,8 @@ class Primitive {
|
||||
virtual std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums);
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs);
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@ -464,7 +466,8 @@ class Convolution : public UnaryPrimitive {
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(Convolution)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
@ -919,7 +922,8 @@ class Matmul : public UnaryPrimitive {
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(Matmul)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
@ -1153,7 +1157,8 @@ class Reduce : public UnaryPrimitive {
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
switch (reduce_type_) {
|
||||
|
@ -443,7 +443,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
}
|
||||
}
|
||||
|
||||
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums);
|
||||
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
|
||||
// Accumulate the vector-jacobian products for each input
|
||||
for (int i = 0; i < argnums.size(); ++i) {
|
||||
auto in_id = a.inputs()[argnums[i]].id();
|
||||
|
@ -71,3 +71,5 @@ class MLXTestCase(unittest.TestCase):
|
||||
elif not isinstance(expected, mx.array):
|
||||
expected = mx.array(expected)
|
||||
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||
else:
|
||||
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||
|
@ -1005,7 +1005,8 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
index_y = mx.array([3, 3, 1, 2])
|
||||
u = mx.random.uniform(shape=(4,))
|
||||
a = a.at[index_x, index_y].add(u)
|
||||
self.assertEqual(a.sum().item(), u.sum().item())
|
||||
self.assertTrue(mx.allclose(a.sum(), u.sum()))
|
||||
self.assertEqualArray(a.sum(), u.sum(), atol=1e-6, rtol=1e-5)
|
||||
self.assertEqual(a[index_x, index_y].tolist(), u.tolist())
|
||||
|
||||
# Test all array.at ops
|
||||
|
Loading…
Reference in New Issue
Block a user