From 18aa9213886b0e4ed5d87a394fad0dbbe839d4e1 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 16:24:35 -0700 Subject: [PATCH] WIP --- mlx/fast.cpp | 2 ++ mlx/primitives.cpp | 67 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index e88527a8e..9b6a70d38 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -127,6 +127,7 @@ std::vector RMSNorm::vjp( assert(primals.size() == 2); assert(outputs.size() == 1); assert(cotangents.size() == 1); + (void)outputs; auto s = stream(); auto fallback = [eps = eps_, s](const std::vector& inputs) { @@ -269,6 +270,7 @@ std::vector LayerNorm::vjp( assert(primals.size() == 3); assert(outputs.size() == 1); assert(cotangents.size() == 1); + (void)outputs; auto s = stream(); auto fallback = [eps = eps_, s](const std::vector& inputs) { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0faa9f407..1c1ecdb48 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -230,6 +230,7 @@ std::vector Abs::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], sign(primals[0], stream()), stream())}; } @@ -383,6 +384,7 @@ std::vector ArcCos::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(one, square(primals[0], stream()), stream()); array denom = negative(rsqrt(t, stream()), stream()); @@ -411,6 +413,7 @@ std::vector ArcCosh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(square(primals[0], stream()), one, stream()); return {multiply(tangents[0], rsqrt(t, stream()), stream())}; @@ -438,6 +441,7 @@ std::vector ArcSin::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(one, square(primals[0], stream()), stream()); return {multiply(tangents[0], rsqrt(t, stream()), stream())}; @@ -465,6 +469,7 @@ std::vector ArcSinh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = add(square(primals[0], stream()), one, stream()); return {multiply(tangents[0], rsqrt(t, stream()), stream())}; @@ -492,6 +497,7 @@ std::vector ArcTan::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = add(one, square(primals[0], stream()), stream()); return {divide(tangents[0], t, stream())}; @@ -539,6 +545,7 @@ std::vector ArcTan2::jvp( const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() == 2); + (void)argnums; const auto& s = stream(); const array& x1 = primals[0]; @@ -575,6 +582,7 @@ std::vector ArcTanh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array one = array(1., primals[0].dtype()); array t = subtract(one, square(primals[0], stream()), stream()); return {divide(tangents[0], t, stream())}; @@ -725,6 +733,7 @@ std::vector AsStrided::vjp( const std::vector& argnums, const std::vector&) { assert(argnums.size() == 1); + (void)argnums; // Extract the sizes and cast them to ints int grad_size = primals[0].size(); @@ -754,6 +763,7 @@ std::vector AsStrided::jvp( const std::vector& tangents, const std::vector& /* argnums */) { assert(primals.size() == 1); + (void)primals; return {as_strided(tangents[0], shape_, strides_, offset_, stream())}; } @@ -787,6 +797,7 @@ std::vector BitwiseBinary::jvp( const std::vector& tangents, const std::vector& argnums) { assert(primals.size() == 2); + (void)primals; std::vector vjps = {zeros_like(tangents[0], stream())}; if (argnums.size() > 1) { vjps.push_back(vjps.back()); @@ -942,6 +953,7 @@ std::vector Ceil::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -1581,6 +1593,8 @@ std::vector Copy::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return cotangents; } @@ -1590,6 +1604,8 @@ std::vector Copy::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return tangents; } @@ -1615,6 +1631,7 @@ std::vector Cos::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply( tangents[0], negative(sin(primals[0], stream()), stream()), stream())}; } @@ -1641,6 +1658,7 @@ std::vector Cosh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], sinh(primals[0], stream()), stream())}; } @@ -1881,6 +1899,7 @@ std::vector Erf::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto dtype = primals[0].dtype(); auto scale = multiply(array(M_2_SQRTPI, dtype), tangents[0], stream()); return {multiply( @@ -1915,6 +1934,7 @@ std::vector ErfInv::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto dtype = primals[0].dtype(); auto scale = multiply(array(1.0 / M_2_SQRTPI, dtype), tangents[0], stream()); return {multiply( @@ -1945,6 +1965,7 @@ std::vector Exp::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], exp(primals[0], stream()), stream())}; } @@ -1973,6 +1994,7 @@ std::vector Expm1::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], exp(primals[0], stream()), stream())}; } @@ -2181,6 +2203,7 @@ std::vector FFT::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); @@ -2260,6 +2283,8 @@ std::vector FFT::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; auto& tan = tangents[0]; if (real_ & inverse_) { return {fft::irfftn(tan, stream())}; @@ -2286,6 +2311,7 @@ std::vector Floor::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -2304,6 +2330,7 @@ std::vector Full::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(cotangents[0], primals[0], stream())}; } @@ -2313,6 +2340,8 @@ std::vector Full::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return tangents; } @@ -2568,6 +2597,7 @@ std::vector Imag::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply( array(complex64_t{0.0f, 1.0f}, primals[0].dtype()), cotangents[0], @@ -2580,6 +2610,8 @@ std::vector Imag::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {imag(tangents[0], stream())}; } @@ -2659,6 +2691,7 @@ std::vector Log::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto out = divide(tangents[0], primals[0], stream()); if (base_ != Base::e) { auto scale = 1 / std::log(base_ == Base::ten ? 10.0f : 2.0f); @@ -2696,6 +2729,7 @@ std::vector Log1p::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto dtype = primals[0].dtype(); return {divide( tangents[0], add(array(1.0f, dtype), primals[0], stream()), stream())}; @@ -2723,6 +2757,8 @@ std::vector LogicalNot::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {zeros_like(tangents[0], stream())}; } @@ -2740,6 +2776,7 @@ std::vector LogicalAnd::vjp( const std::vector& argnums, const std::vector&) { assert(primals.size() == 2); + (void)primals; std::vector vjps = {zeros_like(cotangents[0], stream())}; if (argnums.size() > 1) { vjps.push_back(vjps.back()); @@ -2753,6 +2790,7 @@ std::vector LogicalAnd::jvp( const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -2772,6 +2810,7 @@ std::vector LogicalOr::vjp( const std::vector& argnums, const std::vector&) { assert(primals.size() == 2); + (void)primals; std::vector vjps = {zeros_like(cotangents[0], stream())}; if (argnums.size() > 1) { vjps.push_back(vjps.back()); @@ -2785,6 +2824,7 @@ std::vector LogicalOr::jvp( const std::vector& argnums) { assert(primals.size() == 2); assert(argnums.size() <= 2); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -3154,6 +3194,8 @@ std::vector Negative::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {negative(tangents[0], stream())}; } @@ -3198,6 +3240,7 @@ std::vector Pad::vjp( const std::vector& argnums, const std::vector&) { assert(argnums.size() == 1 && argnums[0] == 0); + (void)argnums; auto& cotan = cotangents[0]; Shape start(cotan.ndim(), 0); @@ -3218,6 +3261,7 @@ std::vector Pad::jvp( const std::vector& tangents, const std::vector& argnums) { assert(argnums.size() == 1 && argnums[0] == 0); + (void)argnums; return { pad(tangents[0], @@ -3639,6 +3683,7 @@ std::vector Real::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {astype(cotangents[0], primals[0].dtype(), stream())}; } @@ -3648,6 +3693,8 @@ std::vector Real::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {real(tangents[0], stream())}; } @@ -3688,6 +3735,7 @@ std::vector Reshape::vjp( assert(primals.size() == 1); assert(argnums.size() == 1); assert(argnums[0] == 0); + (void)argnums; return {reshape(cotangents[0], primals[0].shape(), stream())}; } @@ -3698,6 +3746,8 @@ std::vector Reshape::jvp( assert(primals.size() == 1); assert(argnums.size() == 1); assert(argnums[0] == 0); + (void)primals; + (void)argnums; return {reshape(tangents[0], shape_, stream())}; } @@ -3891,6 +3941,7 @@ std::vector Round::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros_like(primals[0], stream())}; } @@ -3926,6 +3977,7 @@ std::vector Scan::vjp( const std::vector& outputs) { assert(primals.size() == 1); assert(argnums[0] == 0); + (void)argnums; if (reduce_type_ == Scan::Sum) { return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; @@ -4027,6 +4079,7 @@ std::vector Scan::jvp( const std::vector& argnums) { assert(tangents.size() == 1); assert(argnums[0] == 0); + (void)argnums; if (reduce_type_ == Scan::Sum) { return {cumsum(tangents[0], axis_, reverse_, inclusive_, stream())}; @@ -4346,6 +4399,7 @@ std::vector Sigmoid::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; auto s = sigmoid(primals[0], stream()); auto sprime = multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream()); @@ -4374,6 +4428,7 @@ std::vector Sign::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {zeros(primals[0].shape(), primals[0].dtype(), stream())}; } @@ -4399,6 +4454,7 @@ std::vector Sin::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], cos(primals[0], stream()), stream())}; } @@ -4424,6 +4480,7 @@ std::vector Sinh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; return {multiply(tangents[0], cosh(primals[0], stream()), stream())}; } @@ -4469,6 +4526,7 @@ std::vector Slice::jvp( const std::vector& /* argnums */) { // Check inputs assert(primals.size() == 1); + (void)primals; return {slice(tangents[0], start_indices_, end_indices_, strides_, stream())}; } @@ -4566,6 +4624,7 @@ std::vector SliceUpdate::jvp( const std::vector& /* argnums */) { // Check inputs assert(primals.size() == 2); + (void)primals; return {slice_update( tangents[0], tangents[1], @@ -4748,6 +4807,7 @@ std::vector Softmax::vjp( const std::vector& outputs) { assert(primals.size() == 1); assert(cotangents.size() == 1); + (void)primals; auto& s = outputs[0]; auto sv = multiply(s, cotangents[0], stream()); return {subtract( @@ -5022,6 +5082,7 @@ std::vector Tan::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array cos_sq = square(cos(primals[0], stream()), stream()); return {divide(tangents[0], cos_sq, stream())}; } @@ -5048,6 +5109,7 @@ std::vector Tanh::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)argnums; array cosh_sq = square(cosh(primals[0], stream()), stream()); return {divide(tangents[0], cosh_sq, stream())}; } @@ -5409,6 +5471,8 @@ std::vector Transpose::vjp( const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; std::vector iaxes(axes_.size()); for (int i = 0; i < std::ssize(axes_); ++i) { iaxes[axes_[i]] = i; @@ -5422,6 +5486,7 @@ std::vector Transpose::jvp( const std::vector& /* argnums */) { assert(primals.size() == 1); assert(tangents.size() == 1); + (void)primals; return {transpose(tangents[0], axes_, stream())}; } @@ -5556,6 +5621,8 @@ std::vector Hadamard::jvp( const std::vector& argnums) { assert(primals.size() == 1); assert(argnums.size() == 1); + (void)primals; + (void)argnums; return {hadamard_transform(tangents[0], scale_, stream())}; }