mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP
This commit is contained in:
@@ -127,6 +127,7 @@ std::vector<array> RMSNorm::vjp(
|
||||
assert(primals.size() == 2);
|
||||
assert(outputs.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
(void)outputs;
|
||||
|
||||
auto s = stream();
|
||||
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
|
||||
@@ -269,6 +270,7 @@ std::vector<array> LayerNorm::vjp(
|
||||
assert(primals.size() == 3);
|
||||
assert(outputs.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
(void)outputs;
|
||||
|
||||
auto s = stream();
|
||||
auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
|
||||
|
||||
@@ -230,6 +230,7 @@ std::vector<array> Abs::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(tangents[0], sign(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
@@ -383,6 +384,7 @@ std::vector<array> ArcCos::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array one = array(1., primals[0].dtype());
|
||||
array t = subtract(one, square(primals[0], stream()), stream());
|
||||
array denom = negative(rsqrt(t, stream()), stream());
|
||||
@@ -411,6 +413,7 @@ std::vector<array> ArcCosh::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array one = array(1., primals[0].dtype());
|
||||
array t = subtract(square(primals[0], stream()), one, stream());
|
||||
return {multiply(tangents[0], rsqrt(t, stream()), stream())};
|
||||
@@ -438,6 +441,7 @@ std::vector<array> ArcSin::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array one = array(1., primals[0].dtype());
|
||||
array t = subtract(one, square(primals[0], stream()), stream());
|
||||
return {multiply(tangents[0], rsqrt(t, stream()), stream())};
|
||||
@@ -465,6 +469,7 @@ std::vector<array> ArcSinh::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array one = array(1., primals[0].dtype());
|
||||
array t = add(square(primals[0], stream()), one, stream());
|
||||
return {multiply(tangents[0], rsqrt(t, stream()), stream())};
|
||||
@@ -492,6 +497,7 @@ std::vector<array> ArcTan::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array one = array(1., primals[0].dtype());
|
||||
array t = add(one, square(primals[0], stream()), stream());
|
||||
return {divide(tangents[0], t, stream())};
|
||||
@@ -539,6 +545,7 @@ std::vector<array> ArcTan2::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() == 2);
|
||||
(void)argnums;
|
||||
|
||||
const auto& s = stream();
|
||||
const array& x1 = primals[0];
|
||||
@@ -575,6 +582,7 @@ std::vector<array> ArcTanh::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array one = array(1., primals[0].dtype());
|
||||
array t = subtract(one, square(primals[0], stream()), stream());
|
||||
return {divide(tangents[0], t, stream())};
|
||||
@@ -725,6 +733,7 @@ std::vector<array> AsStrided::vjp(
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
|
||||
// Extract the sizes and cast them to ints
|
||||
int grad_size = primals[0].size();
|
||||
@@ -754,6 +763,7 @@ std::vector<array> AsStrided::jvp(
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& /* argnums */) {
|
||||
assert(primals.size() == 1);
|
||||
(void)primals;
|
||||
|
||||
return {as_strided(tangents[0], shape_, strides_, offset_, stream())};
|
||||
}
|
||||
@@ -787,6 +797,7 @@ std::vector<array> BitwiseBinary::jvp(
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
(void)primals;
|
||||
std::vector<array> vjps = {zeros_like(tangents[0], stream())};
|
||||
if (argnums.size() > 1) {
|
||||
vjps.push_back(vjps.back());
|
||||
@@ -942,6 +953,7 @@ std::vector<array> Ceil::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {zeros_like(primals[0], stream())};
|
||||
}
|
||||
|
||||
@@ -1581,6 +1593,8 @@ std::vector<array> Copy::vjp(
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return cotangents;
|
||||
}
|
||||
|
||||
@@ -1590,6 +1604,8 @@ std::vector<array> Copy::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return tangents;
|
||||
}
|
||||
|
||||
@@ -1615,6 +1631,7 @@ std::vector<array> Cos::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(
|
||||
tangents[0], negative(sin(primals[0], stream()), stream()), stream())};
|
||||
}
|
||||
@@ -1641,6 +1658,7 @@ std::vector<array> Cosh::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(tangents[0], sinh(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
@@ -1881,6 +1899,7 @@ std::vector<array> Erf::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
auto dtype = primals[0].dtype();
|
||||
auto scale = multiply(array(M_2_SQRTPI, dtype), tangents[0], stream());
|
||||
return {multiply(
|
||||
@@ -1915,6 +1934,7 @@ std::vector<array> ErfInv::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
auto dtype = primals[0].dtype();
|
||||
auto scale = multiply(array(1.0 / M_2_SQRTPI, dtype), tangents[0], stream());
|
||||
return {multiply(
|
||||
@@ -1945,6 +1965,7 @@ std::vector<array> Exp::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(tangents[0], exp(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
@@ -1973,6 +1994,7 @@ std::vector<array> Expm1::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(tangents[0], exp(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
@@ -2181,6 +2203,7 @@ std::vector<array> FFT::vjp(
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
auto& in = primals[0];
|
||||
std::vector<int> axes(axes_.begin(), axes_.end());
|
||||
|
||||
@@ -2260,6 +2283,8 @@ std::vector<array> FFT::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
auto& tan = tangents[0];
|
||||
if (real_ & inverse_) {
|
||||
return {fft::irfftn(tan, stream())};
|
||||
@@ -2286,6 +2311,7 @@ std::vector<array> Floor::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {zeros_like(primals[0], stream())};
|
||||
}
|
||||
|
||||
@@ -2304,6 +2330,7 @@ std::vector<array> Full::vjp(
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(cotangents[0], primals[0], stream())};
|
||||
}
|
||||
|
||||
@@ -2313,6 +2340,8 @@ std::vector<array> Full::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return tangents;
|
||||
}
|
||||
|
||||
@@ -2568,6 +2597,7 @@ std::vector<array> Imag::vjp(
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(
|
||||
array(complex64_t{0.0f, 1.0f}, primals[0].dtype()),
|
||||
cotangents[0],
|
||||
@@ -2580,6 +2610,8 @@ std::vector<array> Imag::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return {imag(tangents[0], stream())};
|
||||
}
|
||||
|
||||
@@ -2659,6 +2691,7 @@ std::vector<array> Log::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
auto out = divide(tangents[0], primals[0], stream());
|
||||
if (base_ != Base::e) {
|
||||
auto scale = 1 / std::log(base_ == Base::ten ? 10.0f : 2.0f);
|
||||
@@ -2696,6 +2729,7 @@ std::vector<array> Log1p::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
auto dtype = primals[0].dtype();
|
||||
return {divide(
|
||||
tangents[0], add(array(1.0f, dtype), primals[0], stream()), stream())};
|
||||
@@ -2723,6 +2757,8 @@ std::vector<array> LogicalNot::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return {zeros_like(tangents[0], stream())};
|
||||
}
|
||||
|
||||
@@ -2740,6 +2776,7 @@ std::vector<array> LogicalAnd::vjp(
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 2);
|
||||
(void)primals;
|
||||
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
|
||||
if (argnums.size() > 1) {
|
||||
vjps.push_back(vjps.back());
|
||||
@@ -2753,6 +2790,7 @@ std::vector<array> LogicalAnd::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() <= 2);
|
||||
(void)argnums;
|
||||
return {zeros_like(primals[0], stream())};
|
||||
}
|
||||
|
||||
@@ -2772,6 +2810,7 @@ std::vector<array> LogicalOr::vjp(
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 2);
|
||||
(void)primals;
|
||||
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
|
||||
if (argnums.size() > 1) {
|
||||
vjps.push_back(vjps.back());
|
||||
@@ -2785,6 +2824,7 @@ std::vector<array> LogicalOr::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() <= 2);
|
||||
(void)argnums;
|
||||
|
||||
return {zeros_like(primals[0], stream())};
|
||||
}
|
||||
@@ -3154,6 +3194,8 @@ std::vector<array> Negative::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return {negative(tangents[0], stream())};
|
||||
}
|
||||
|
||||
@@ -3198,6 +3240,7 @@ std::vector<array> Pad::vjp(
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(argnums.size() == 1 && argnums[0] == 0);
|
||||
(void)argnums;
|
||||
|
||||
auto& cotan = cotangents[0];
|
||||
Shape start(cotan.ndim(), 0);
|
||||
@@ -3218,6 +3261,7 @@ std::vector<array> Pad::jvp(
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(argnums.size() == 1 && argnums[0] == 0);
|
||||
(void)argnums;
|
||||
|
||||
return {
|
||||
pad(tangents[0],
|
||||
@@ -3639,6 +3683,7 @@ std::vector<array> Real::vjp(
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {astype(cotangents[0], primals[0].dtype(), stream())};
|
||||
}
|
||||
|
||||
@@ -3648,6 +3693,8 @@ std::vector<array> Real::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return {real(tangents[0], stream())};
|
||||
}
|
||||
|
||||
@@ -3688,6 +3735,7 @@ std::vector<array> Reshape::vjp(
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
assert(argnums[0] == 0);
|
||||
(void)argnums;
|
||||
return {reshape(cotangents[0], primals[0].shape(), stream())};
|
||||
}
|
||||
|
||||
@@ -3698,6 +3746,8 @@ std::vector<array> Reshape::jvp(
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
assert(argnums[0] == 0);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return {reshape(tangents[0], shape_, stream())};
|
||||
}
|
||||
|
||||
@@ -3891,6 +3941,7 @@ std::vector<array> Round::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {zeros_like(primals[0], stream())};
|
||||
}
|
||||
|
||||
@@ -3926,6 +3977,7 @@ std::vector<array> Scan::vjp(
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums[0] == 0);
|
||||
(void)argnums;
|
||||
|
||||
if (reduce_type_ == Scan::Sum) {
|
||||
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
|
||||
@@ -4027,6 +4079,7 @@ std::vector<array> Scan::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(tangents.size() == 1);
|
||||
assert(argnums[0] == 0);
|
||||
(void)argnums;
|
||||
|
||||
if (reduce_type_ == Scan::Sum) {
|
||||
return {cumsum(tangents[0], axis_, reverse_, inclusive_, stream())};
|
||||
@@ -4346,6 +4399,7 @@ std::vector<array> Sigmoid::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
auto s = sigmoid(primals[0], stream());
|
||||
auto sprime =
|
||||
multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream());
|
||||
@@ -4374,6 +4428,7 @@ std::vector<array> Sign::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {zeros(primals[0].shape(), primals[0].dtype(), stream())};
|
||||
}
|
||||
|
||||
@@ -4399,6 +4454,7 @@ std::vector<array> Sin::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(tangents[0], cos(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
@@ -4424,6 +4480,7 @@ std::vector<array> Sinh::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
return {multiply(tangents[0], cosh(primals[0], stream()), stream())};
|
||||
}
|
||||
|
||||
@@ -4469,6 +4526,7 @@ std::vector<array> Slice::jvp(
|
||||
const std::vector<int>& /* argnums */) {
|
||||
// Check inputs
|
||||
assert(primals.size() == 1);
|
||||
(void)primals;
|
||||
return {slice(tangents[0], start_indices_, end_indices_, strides_, stream())};
|
||||
}
|
||||
|
||||
@@ -4566,6 +4624,7 @@ std::vector<array> SliceUpdate::jvp(
|
||||
const std::vector<int>& /* argnums */) {
|
||||
// Check inputs
|
||||
assert(primals.size() == 2);
|
||||
(void)primals;
|
||||
return {slice_update(
|
||||
tangents[0],
|
||||
tangents[1],
|
||||
@@ -4748,6 +4807,7 @@ std::vector<array> Softmax::vjp(
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
(void)primals;
|
||||
auto& s = outputs[0];
|
||||
auto sv = multiply(s, cotangents[0], stream());
|
||||
return {subtract(
|
||||
@@ -5022,6 +5082,7 @@ std::vector<array> Tan::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array cos_sq = square(cos(primals[0], stream()), stream());
|
||||
return {divide(tangents[0], cos_sq, stream())};
|
||||
}
|
||||
@@ -5048,6 +5109,7 @@ std::vector<array> Tanh::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)argnums;
|
||||
array cosh_sq = square(cosh(primals[0], stream()), stream());
|
||||
return {divide(tangents[0], cosh_sq, stream())};
|
||||
}
|
||||
@@ -5409,6 +5471,8 @@ std::vector<array> Transpose::vjp(
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
std::vector<int> iaxes(axes_.size());
|
||||
for (int i = 0; i < std::ssize(axes_); ++i) {
|
||||
iaxes[axes_[i]] = i;
|
||||
@@ -5422,6 +5486,7 @@ std::vector<array> Transpose::jvp(
|
||||
const std::vector<int>& /* argnums */) {
|
||||
assert(primals.size() == 1);
|
||||
assert(tangents.size() == 1);
|
||||
(void)primals;
|
||||
return {transpose(tangents[0], axes_, stream())};
|
||||
}
|
||||
|
||||
@@ -5556,6 +5621,8 @@ std::vector<array> Hadamard::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
(void)primals;
|
||||
(void)argnums;
|
||||
return {hadamard_transform(tangents[0], scale_, stream())};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user