Primitive's VJP takes outputs as input (#475)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun 2024-01-16 19:03:53 -08:00 committed by GitHub
parent d8fabaa12b
commit a2bf7693dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 205 additions and 136 deletions

View File

@ -51,29 +51,31 @@ std::tuple<array, array, int> vmap_binary_op(
} // namespace
// Base-class fallback for jvp: unconditionally throws std::invalid_argument.
// This diff rendering shows two versions of the parameter list back to back:
// the named parameters, then the unnamed ones. The body reads no argument.
std::vector<array> Primitive::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's jvp not implemented.");
};
// Base-class fallback for vjp: unconditionally throws std::invalid_argument.
// Two versions of the parameter list appear back to back in this diff
// rendering: three named parameters, then four unnamed ones (the extra
// trailing std::vector<array> is the one this change introduces).
std::vector<array> Primitive::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
throw std::invalid_argument("Primitive's vjp not implemented.");
};
// Base-class fallback for vmap: unconditionally throws std::invalid_argument.
// Both the named and the unnamed form of the two parameters are shown in
// this diff rendering; the body uses neither.
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's vmap not implemented.");
};
// VJP for Abs: hands primals, cotangents and argnums straight to jvp() and
// returns its result. The added trailing std::vector<array> parameter is
// unnamed and never read.
std::vector<array> Abs::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -106,7 +108,8 @@ std::vector<array> Add::jvp(
std::vector<array> Add::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
if (argnums.size() == 1) {
return cotangents;
} else {
@ -131,7 +134,8 @@ bool Arange::is_equivalent(const Primitive& other) const {
// VJP for ArcCos: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> ArcCos::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -158,7 +162,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcCos::vmap(
// VJP for ArcCosh: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> ArcCosh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -184,7 +189,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcCosh::vmap(
// VJP for ArcSin: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> ArcSin::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -210,7 +216,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcSin::vmap(
// VJP for ArcSinh: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> ArcSinh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -236,7 +243,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcSinh::vmap(
// VJP for ArcTan: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> ArcTan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -262,7 +270,8 @@ std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
// VJP for ArcTanh: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> ArcTanh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -322,7 +331,8 @@ bool ArgSort::is_equivalent(const Primitive& other) const {
std::vector<array> AsType::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
if (cotangents[0].dtype() != dtype_) {
throw std::invalid_argument(
"[astype] Type of cotangentsgent does not much primal output type.");
@ -351,7 +361,8 @@ bool AsType::is_equivalent(const Primitive& other) const {
std::vector<array> AsStrided::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(argnums.size() == 1);
// Extract the sizes and cast them to ints
@ -395,7 +406,8 @@ bool AsStrided::is_equivalent(const Primitive& other) const {
std::vector<array> Broadcast::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(argnums.size() == 1);
// Reduce cotangents to the shape of the primal
@ -445,7 +457,8 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
// VJP for Ceil: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Ceil::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -469,7 +482,8 @@ std::pair<std::vector<array>, std::vector<int>> Ceil::vmap(
std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape();
@ -544,7 +558,8 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 2);
std::vector<array> grads;
@ -661,7 +676,8 @@ bool Convolution::is_equivalent(const Primitive& other) const {
std::vector<array> Copy::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return cotangents;
@ -687,7 +703,8 @@ std::pair<std::vector<array>, std::vector<int>> Copy::vmap(
// VJP for Cos: calls jvp(primals, cotangents, argnums) and returns the
// result via a braced initializer (unlike the sibling forwarders, which
// return the call directly). The added trailing std::vector<array>
// parameter is unnamed and unused.
std::vector<array> Cos::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return {jvp(primals, cotangents, argnums)};
}
@ -712,7 +729,8 @@ std::pair<std::vector<array>, std::vector<int>> Cos::vmap(
// VJP for Cosh: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Cosh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -736,7 +754,8 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
std::vector<array> Divide::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
@ -756,7 +775,8 @@ std::vector<array> Divide::vjp(
std::vector<array> DivMod::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@ -812,7 +832,8 @@ std::pair<std::vector<array>, std::vector<int>> Divide::vmap(
std::vector<array> Remainder::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
@ -865,7 +886,8 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
std::vector<array> Equal::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@ -884,7 +906,8 @@ std::vector<array> Equal::jvp(
// VJP for Erf: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Erf::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -913,8 +936,13 @@ std::pair<std::vector<array>, std::vector<int>> Erf::vmap(
// VJP for ErfInv. This diff rendering shows the earlier one-line body
// (forwarding to jvp) followed by its replacement. The replacement names the
// fourth parameter `outputs` and returns
//   cotangents[0] * (1.0 / M_2_SQRTPI) * exp(outputs[0]^2),
// with the constant created in the dtype of primals[0] and every op issued
// on stream(). M_2_SQRTPI is the <math.h> constant 2/sqrt(pi).
std::vector<array> ErfInv::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto dtype = primals[0].dtype();
auto scale =
multiply(array(1.0 / M_2_SQRTPI, dtype), cotangents[0], stream());
return {
multiply(scale, exp(square(outputs[0], stream()), stream()), stream())};
}
std::vector<array> ErfInv::jvp(
@ -942,8 +970,9 @@ std::pair<std::vector<array>, std::vector<int>> ErfInv::vmap(
// VJP for Exp. This diff rendering shows the earlier body (forwarding to
// jvp) followed by its replacement, which names the fourth parameter
// `outputs` and returns cotangents[0] * outputs[0] on stream().
std::vector<array> Exp::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
return {multiply(cotangents[0], outputs[0], stream())};
}
std::vector<array> Exp::jvp(
@ -997,7 +1026,8 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
std::vector<array> FFT::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
auto& in = primals[0];
@ -1050,7 +1080,8 @@ std::vector<array> FFT::jvp(
// VJP for Floor: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Floor::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -1074,7 +1105,8 @@ std::pair<std::vector<array>, std::vector<int>> Floor::vmap(
std::vector<array> Full::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {multiply(cotangents[0], primals[0], stream())};
@ -1155,7 +1187,8 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
std::vector<array> Gather::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
if (argnums.size() > 1 || argnums[0] != 0) {
throw std::invalid_argument(
"[gather] Cannot calculate VJP with respect to indices.");
@ -1192,7 +1225,8 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
std::vector<array> Greater::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@ -1218,7 +1252,8 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
std::vector<array> GreaterEqual::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@ -1244,7 +1279,8 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
std::vector<array> Less::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@ -1270,7 +1306,8 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
std::vector<array> LessEqual::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@ -1289,7 +1326,8 @@ std::vector<array> LessEqual::jvp(
// VJP for Log: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Log::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -1325,7 +1363,8 @@ std::pair<std::vector<array>, std::vector<int>> Log::vmap(
// VJP for Log1p: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Log1p::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -1351,7 +1390,8 @@ std::pair<std::vector<array>, std::vector<int>> Log1p::vmap(
// VJP for LogicalNot: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> LogicalNot::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -1375,7 +1415,8 @@ std::pair<std::vector<array>, std::vector<int>> LogicalNot::vmap(
std::vector<array> LogicalAnd::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 2);
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
if (argnums.size() > 1) {
@ -1406,7 +1447,8 @@ std::pair<std::vector<array>, std::vector<int>> LogicalAnd::vmap(
std::vector<array> LogicalOr::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 2);
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
if (argnums.size() > 1) {
@ -1438,7 +1480,8 @@ std::pair<std::vector<array>, std::vector<int>> LogicalOr::vmap(
std::vector<array> LogAddExp::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
auto a = primals[0];
auto b = primals[1];
auto s = sigmoid(subtract(a, b, stream()), stream());
@ -1483,7 +1526,8 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
std::vector<array> Matmul::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
auto& cotan = cotangents[0];
std::vector<int> reorder(cotan.ndim());
@ -1506,7 +1550,8 @@ std::vector<array> Matmul::vjp(
std::vector<array> Maximum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
auto& a = primals[0];
auto& b = primals[1];
std::vector<array> vjps;
@ -1547,7 +1592,8 @@ std::pair<std::vector<array>, std::vector<int>> Maximum::vmap(
std::vector<array> Minimum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
auto& a = primals[0];
auto& b = primals[1];
std::vector<array> vjps;
@ -1601,7 +1647,8 @@ std::vector<array> Multiply::jvp(
std::vector<array> Multiply::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
@ -1619,7 +1666,8 @@ std::pair<std::vector<array>, std::vector<int>> Multiply::vmap(
// VJP for Negative: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Negative::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -1650,7 +1698,8 @@ std::pair<std::vector<array>, std::vector<int>> NotEqual::vmap(
std::vector<array> NotEqual::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@ -1669,7 +1718,8 @@ std::vector<array> NotEqual::jvp(
std::vector<array> Pad::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(argnums.size() == 1 && argnums[0] == 0);
auto& cotan = cotangents[0];
@ -1717,7 +1767,8 @@ bool Pad::is_equivalent(const Primitive& other) const {
// VJP for Partition: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Partition::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -1749,22 +1800,15 @@ bool Partition::is_equivalent(const Primitive& other) const {
std::vector<array> Power::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(multiply(
power(
primals[0],
subtract(primals[1], array(1, primals[0].dtype()), stream()),
stream()),
primals[1],
stream()));
outputs[0], divide(primals[1], primals[0], stream()), stream()));
} else {
vjps.push_back(multiply(
log(primals[0], stream()),
power(primals[0], primals[1], stream()),
stream()));
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
}
vjps.back() = multiply(cotangents[0], vjps.back(), stream());
}
@ -1775,12 +1819,13 @@ std::vector<array> Power::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto jvp = vjp(primals, {tangents[0]}, {argnums[0]});
auto output = power(primals[0], primals[1], stream());
auto grads = vjp(primals, tangents, argnums, {output});
if (argnums.size() > 1) {
jvp[0] =
add(jvp[0], vjp(primals, {tangents[1]}, {argnums[1]})[0], stream());
return {add(grads[0], grads[1], stream())};
} else {
return grads;
}
return jvp;
}
std::pair<std::vector<array>, std::vector<int>> Power::vmap(
@ -1799,7 +1844,8 @@ std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
std::vector<array> QuantizedMatmul::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
// We rely on the fact that w is always 2D so transpose is simple
@ -1902,7 +1948,8 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
std::vector<array> Reshape::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
assert(argnums[0] == 0);
@ -1927,7 +1974,8 @@ bool Reshape::is_equivalent(const Primitive& other) const {
std::vector<array> Reduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto in = primals[0];
std::vector<int> shape = in.shape();
@ -1997,15 +2045,10 @@ std::vector<array> Reduce::vjp(
}
} else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
array (*op)(const array&, const std::vector<int>&, bool, StreamOrDevice);
if (reduce_type_ == Reduce::Min) {
op = min;
} else {
op = max;
auto out = outputs[0];
if (out.ndim() != in.ndim()) {
out = expand_dims(out, axes_, stream());
}
auto out = op(in, axes_, true, stream());
auto mask = equal(in, out, stream());
auto normalizer = sum(mask, axes_, true, stream());
auto cotan_reshape = reshape(cotan, shape, stream());
@ -2032,7 +2075,8 @@ bool Reduce::is_equivalent(const Primitive& other) const {
// VJP for Round: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Round::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2076,7 +2120,8 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
std::vector<array> Scan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 1);
assert(argnums[0] == 0);
@ -2084,7 +2129,7 @@ std::vector<array> Scan::vjp(
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
} else if (reduce_type_ == Scan::Prod) {
// TODO: Make it numerically stable when we introduce where()
auto prod = cumprod(primals[0], axis_, reverse_, inclusive_, stream());
auto prod = outputs[0];
auto partial_grads = multiply(prod, cotangents[0], stream());
auto accum_grads =
cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
@ -2125,7 +2170,8 @@ bool Scatter::is_equivalent(const Primitive& other) const {
std::vector<array> Scatter::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
switch (reduce_type_) {
case Scatter::None:
case Scatter::Sum:
@ -2137,23 +2183,11 @@ std::vector<array> Scatter::vjp(
"[scatter] VJP not implemented for scatter_prod");
}
const array& result = outputs[0];
const array& values = primals[0];
const array& updates = primals.back();
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
// Store result of scatter if needed for reuse in vjp
auto get_result = [&]() {
switch (reduce_type_) {
case Scatter::Max:
return scatter_max(values, indices, updates, axes_, stream());
case Scatter::Min:
return scatter_min(values, indices, updates, axes_, stream());
default:
return array({});
}
};
array result = get_result();
std::vector<array> vjps;
for (auto num : argnums) {
// Gradient wrt the input array
@ -2232,8 +2266,12 @@ std::vector<array> Scatter::jvp(
// VJP for Sigmoid. This diff rendering shows the earlier body (forwarding to
// jvp) followed by its replacement. The replacement binds s = outputs[0],
// forms sprime = s * (1 - s) with the constant 1 created in s's dtype, and
// returns cotangents[0] * sprime; all ops are issued on stream().
std::vector<array> Sigmoid::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto& s = outputs[0];
auto sprime =
multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream());
return {multiply(cotangents[0], sprime, stream())};
}
std::vector<array> Sigmoid::jvp(
@ -2259,7 +2297,8 @@ std::pair<std::vector<array>, std::vector<int>> Sigmoid::vmap(
// VJP for Sign: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Sign::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2283,7 +2322,8 @@ std::pair<std::vector<array>, std::vector<int>> Sign::vmap(
// VJP for Sin: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Sin::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2307,7 +2347,8 @@ std::pair<std::vector<array>, std::vector<int>> Sin::vmap(
// VJP for Sinh: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Sinh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2345,7 +2386,8 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
std::vector<array> Slice::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
// Check inputs
assert(primals.size() == 1);
@ -2444,8 +2486,15 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
// VJP for Softmax. This diff rendering shows the earlier body (forwarding to
// jvp) followed by its replacement. The replacement asserts exactly one
// primal and one cotangent, binds s = outputs[0], forms
// sv = s * cotangents[0], and returns
//   sv - s * sum(sv, {-1}, true, stream()).
// NOTE(review): the outer subtract() is the only op here that is not passed
// stream(); confirm whether relying on its default is intended.
std::vector<array> Softmax::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 1);
assert(cotangents.size() == 1);
auto& s = outputs[0];
auto sv = multiply(s, cotangents[0], stream());
return {subtract(
sv,
multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
}
std::vector<array> Softmax::jvp(
@ -2473,7 +2522,8 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
// VJP for Sort: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Sort::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2503,7 +2553,8 @@ std::pair<std::vector<array>, std::vector<int>> Split::vmap(
// VJP for Split: returns a single array built by concatenating all
// cotangents along axis_ on stream(). primals, argnums and the added
// unnamed trailing parameter are not read.
std::vector<array> Split::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return {concatenate(cotangents, axis_, stream())};
}
@ -2522,7 +2573,8 @@ bool Split::is_equivalent(const Primitive& other) const {
// VJP for Square: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Square::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2549,29 +2601,34 @@ std::pair<std::vector<array>, std::vector<int>> Square::vmap(
// VJP for Sqrt (also covers the reciprocal variant selected by recip_).
// This diff rendering shows the earlier body (forwarding to jvp) followed by
// its replacement. The replacement asserts one primal and one cotangent and
// then, using constants created in the dtype of primals[0]:
//   recip_ set:   returns -0.5 * cotangents[0] * (outputs[0] / primals[0])
//   otherwise:    returns  0.5 * cotangents[0] / outputs[0]
// All ops are issued on stream().
std::vector<array> Sqrt::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 1);
assert(cotangents.size() == 1);
auto dtype = primals[0].dtype();
if (recip_) {
auto one_over_x_root_x = divide(outputs[0], primals[0], stream());
return {multiply(
multiply(array(-0.5, dtype), cotangents[0], stream()),
one_over_x_root_x,
stream())};
} else {
return {divide(
multiply(array(0.5, dtype), cotangents[0], stream()),
outputs[0],
stream())};
}
}
// JVP for Sqrt. Two implementations are interleaved in this diff rendering:
// one computes the result inline from rsqrt/sqrt of primals[0], the other
// forwards to vjp(), passing a freshly computed rsqrt(primals[0]) or
// sqrt(primals[0]) (chosen by recip_) as the outputs argument.
// NOTE(review): presumably the vjp-forwarding lines are the post-change
// version and the inline formulas are the removed one; the rendering has no
// +/- markers, so confirm against the repository.
std::vector<array> Sqrt::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(tangents.size() == 1);
auto dtype = primals[0].dtype();
if (recip_) {
auto one_over_x_root_x =
divide(rsqrt(primals[0], stream()), primals[0], stream());
return {multiply(
multiply(array(-0.5, dtype), tangents[0], stream()),
one_over_x_root_x,
stream())};
return vjp(primals, tangents, argnums, {rsqrt(primals[0], stream())});
} else {
return vjp(primals, tangents, argnums, {sqrt(primals[0], stream())});
}
return {divide(
multiply(array(0.5, dtype), tangents[0], stream()),
sqrt(primals[0], stream()),
stream())};
}
std::pair<std::vector<array>, std::vector<int>> Sqrt::vmap(
@ -2599,7 +2656,8 @@ std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
std::vector<array> Subtract::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
auto vjp = cotangents[0];
@ -2636,7 +2694,8 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
// VJP for Tan: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Tan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2661,7 +2720,8 @@ std::pair<std::vector<array>, std::vector<int>> Tan::vmap(
// VJP for Tanh: returns jvp(primals, cotangents, argnums) unchanged.
// The added trailing std::vector<array> parameter is unnamed and unused.
std::vector<array> Tanh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
@ -2686,7 +2746,8 @@ std::pair<std::vector<array>, std::vector<int>> Tanh::vmap(
std::vector<array> Transpose::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
std::vector<int> iaxes(axes_.size());

View File

@ -21,7 +21,8 @@
std::vector<array> vjp( \
const std::vector<array>& primals, \
const std::vector<array>& cotangents, \
const std::vector<int>& argnums) override;
const std::vector<int>& argnums, \
const std::vector<array>& outputs) override;
#define DEFINE_PRINT(PRIMITIVE) \
void print(std::ostream& os) override { \
@ -78,7 +79,8 @@ class Primitive {
virtual std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums);
const std::vector<int>& argnums,
const std::vector<array>& outputs);
/**
* The primitive must know how to vectorize itself across
@ -464,7 +466,8 @@ class Convolution : public UnaryPrimitive {
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) override;
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override;
@ -919,7 +922,8 @@ class Matmul : public UnaryPrimitive {
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) override;
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()
@ -1153,7 +1157,8 @@ class Reduce : public UnaryPrimitive {
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) override;
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
void print(std::ostream& os) override {
switch (reduce_type_) {

View File

@ -443,7 +443,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums);
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
auto in_id = a.inputs()[argnums[i]].id();

View File

@ -71,3 +71,5 @@ class MLXTestCase(unittest.TestCase):
elif not isinstance(expected, mx.array):
expected = mx.array(expected)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
else:
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))

View File

@ -1005,7 +1005,8 @@ class TestArray(mlx_tests.MLXTestCase):
index_y = mx.array([3, 3, 1, 2])
u = mx.random.uniform(shape=(4,))
a = a.at[index_x, index_y].add(u)
self.assertEqual(a.sum().item(), u.sum().item())
self.assertTrue(mx.allclose(a.sum(), u.sum()))
self.assertEqualArray(a.sum(), u.sum(), atol=1e-6, rtol=1e-5)
self.assertEqual(a[index_x, index_y].tolist(), u.tolist())
# Test all array.at ops