diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 9876a2443..9cb85c67b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -51,29 +51,31 @@ std::tuple vmap_binary_op( } // namespace std::vector Primitive::jvp( - const std::vector& primals, - const std::vector& tangents, - const std::vector& argnums) { + const std::vector&, + const std::vector&, + const std::vector&) { throw std::invalid_argument("Primitive's jvp not implemented."); }; std::vector Primitive::vjp( - const std::vector& primals, - const std::vector& cotangents, - const std::vector& argnums) { + const std::vector&, + const std::vector&, + const std::vector&, + const std::vector&) { throw std::invalid_argument("Primitive's vjp not implemented."); }; std::pair, std::vector> Primitive::vmap( - const std::vector& inputs, - const std::vector& axes) { + const std::vector&, + const std::vector&) { throw std::invalid_argument("Primitive's vmap not implemented."); }; std::vector Abs::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -106,7 +108,8 @@ std::vector Add::jvp( std::vector Add::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { if (argnums.size() == 1) { return cotangents; } else { @@ -131,7 +134,8 @@ bool Arange::is_equivalent(const Primitive& other) const { std::vector ArcCos::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -158,7 +162,8 @@ std::pair, std::vector> ArcCos::vmap( std::vector ArcCosh::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -184,7 +189,8 @@ std::pair, std::vector> ArcCosh::vmap( std::vector ArcSin::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -210,7 +216,8 @@ std::pair, std::vector> ArcSin::vmap( std::vector ArcSinh::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -236,7 +243,8 @@ std::pair, std::vector> ArcSinh::vmap( std::vector ArcTan::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -262,7 +270,8 @@ std::pair, std::vector> ArcTan::vmap( std::vector ArcTanh::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -322,7 +331,8 @@ bool ArgSort::is_equivalent(const Primitive& other) const { std::vector AsType::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { if (cotangents[0].dtype() != dtype_) { throw std::invalid_argument( "[astype] Type of cotangentsgent does not much primal output type."); @@ -351,7 +361,8 @@ bool AsType::is_equivalent(const Primitive& other) const { std::vector AsStrided::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(argnums.size() == 1); // Extract the sizes and cast them to ints @@ -395,7 +406,8 @@ bool AsStrided::is_equivalent(const Primitive& other) const { std::vector Broadcast::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(argnums.size() == 1); // Reduce cotangents to the shape of the primal @@ -445,7 +457,8 @@ bool Broadcast::is_equivalent(const Primitive& other) const { std::vector Ceil::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -469,7 +482,8 @@ std::pair, std::vector> Ceil::vmap( std::vector Concatenate::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { auto& cotan = cotangents[0]; std::vector start(cotan.ndim(), 0); std::vector stop = cotan.shape(); @@ -544,7 +558,8 @@ bool Concatenate::is_equivalent(const Primitive& other) const { std::vector Convolution::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 2); std::vector grads; @@ -661,7 +676,8 @@ bool Convolution::is_equivalent(const Primitive& other) const { std::vector Copy::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); return cotangents; @@ -687,7 +703,8 @@ std::pair, std::vector> Copy::vmap( std::vector Cos::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return {jvp(primals, cotangents, argnums)}; } @@ -712,7 +729,8 @@ std::pair, std::vector> Cos::vmap( std::vector Cosh::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -736,7 +754,8 @@ std::pair, std::vector> Cosh::vmap( std::vector Divide::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { if (arg == 0) { @@ -756,7 +775,8 @@ std::vector Divide::vjp( std::vector DivMod::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(zeros_like(primals[arg], stream())); @@ -812,7 +832,8 @@ std::pair, std::vector> Divide::vmap( std::vector Remainder::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { if (arg == 0) { @@ -865,7 +886,8 @@ std::pair, std::vector> Equal::vmap( std::vector Equal::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(zeros_like(primals[arg], stream())); @@ -884,7 +906,8 @@ std::vector Equal::jvp( std::vector Erf::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -913,8 +936,13 @@ std::pair, std::vector> Erf::vmap( std::vector ErfInv::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { - return jvp(primals, cotangents, argnums); + const std::vector& argnums, + const std::vector& outputs) { + auto dtype = primals[0].dtype(); + auto scale = + multiply(array(1.0 / M_2_SQRTPI, dtype), cotangents[0], stream()); + return { + multiply(scale, exp(square(outputs[0], stream()), stream()), stream())}; } std::vector ErfInv::jvp( @@ -942,8 +970,9 @@ std::pair, std::vector> ErfInv::vmap( std::vector Exp::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { - return jvp(primals, cotangents, argnums); + const std::vector& argnums, + const std::vector& outputs) { + return {multiply(cotangents[0], outputs[0], stream())}; } std::vector Exp::jvp( @@ -997,7 +1026,8 @@ std::pair, std::vector> FFT::vmap( std::vector FFT::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); auto& in = primals[0]; @@ -1050,7 +1080,8 @@ std::vector FFT::jvp( std::vector Floor::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -1074,7 +1105,8 @@ std::pair, std::vector> Floor::vmap( std::vector Full::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); return {multiply(cotangents[0], primals[0], stream())}; @@ -1155,7 +1187,8 @@ std::pair, std::vector> Gather::vmap( std::vector Gather::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { if (argnums.size() > 1 || argnums[0] != 0) { throw std::invalid_argument( "[gather] Cannot calculate VJP with respect to indices."); @@ -1192,7 +1225,8 @@ std::pair, std::vector> Greater::vmap( std::vector Greater::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(zeros_like(primals[arg], stream())); @@ -1218,7 +1252,8 @@ std::pair, std::vector> GreaterEqual::vmap( std::vector GreaterEqual::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(zeros_like(primals[arg], stream())); @@ -1244,7 +1279,8 @@ std::pair, std::vector> Less::vmap( std::vector Less::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(zeros_like(primals[arg], stream())); @@ -1270,7 +1306,8 @@ std::pair, std::vector> LessEqual::vmap( std::vector LessEqual::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(zeros_like(primals[arg], stream())); @@ -1289,7 +1326,8 @@ std::vector LessEqual::jvp( std::vector Log::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -1325,7 +1363,8 @@ std::pair, std::vector> Log::vmap( std::vector Log1p::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -1351,7 +1390,8 @@ std::pair, std::vector> Log1p::vmap( std::vector LogicalNot::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -1375,7 +1415,8 @@ std::pair, std::vector> LogicalNot::vmap( std::vector LogicalAnd::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 2); std::vector vjps = {zeros_like(cotangents[0], stream())}; if (argnums.size() > 1) { @@ -1406,7 +1447,8 @@ std::pair, std::vector> LogicalAnd::vmap( std::vector LogicalOr::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 2); std::vector vjps = {zeros_like(cotangents[0], stream())}; if (argnums.size() > 1) { @@ -1438,7 +1480,8 @@ std::pair, std::vector> LogicalOr::vmap( std::vector LogAddExp::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { auto a = primals[0]; auto b = primals[1]; auto s = sigmoid(subtract(a, b, stream()), stream()); @@ -1483,7 +1526,8 @@ std::pair, std::vector> LogAddExp::vmap( std::vector Matmul::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; auto& cotan = cotangents[0]; std::vector reorder(cotan.ndim()); @@ -1506,7 +1550,8 @@ std::vector Matmul::vjp( std::vector Maximum::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { auto& a = primals[0]; auto& b = primals[1]; std::vector vjps; @@ -1547,7 +1592,8 @@ std::pair, std::vector> Maximum::vmap( std::vector Minimum::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { auto& a = primals[0]; auto& b = primals[1]; std::vector vjps; @@ -1601,7 +1647,8 @@ std::vector Multiply::jvp( std::vector Multiply::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); @@ -1619,7 +1666,8 @@ std::pair, std::vector> Multiply::vmap( std::vector Negative::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -1650,7 +1698,8 @@ std::pair, std::vector> NotEqual::vmap( std::vector NotEqual::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { vjps.push_back(zeros_like(primals[arg], stream())); @@ -1669,7 +1718,8 @@ std::vector NotEqual::jvp( std::vector Pad::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(argnums.size() == 1 && argnums[0] == 0); auto& cotan = cotangents[0]; @@ -1717,7 +1767,8 @@ bool Pad::is_equivalent(const Primitive& other) const { std::vector Partition::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -1749,22 +1800,15 @@ bool Partition::is_equivalent(const Primitive& other) const { std::vector Power::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector& outputs) { std::vector vjps; for (auto arg : argnums) { if (arg == 0) { vjps.push_back(multiply( - power( - primals[0], - subtract(primals[1], array(1, primals[0].dtype()), stream()), - stream()), - primals[1], - stream())); + outputs[0], divide(primals[1], primals[0], stream()), stream())); } else { - vjps.push_back(multiply( - log(primals[0], stream()), - power(primals[0], primals[1], stream()), - stream())); + vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream())); } vjps.back() = multiply(cotangents[0], vjps.back(), stream()); } @@ -1775,12 +1819,13 @@ std::vector Power::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { - auto jvp = vjp(primals, {tangents[0]}, {argnums[0]}); + auto output = power(primals[0], primals[1], stream()); + auto grads = vjp(primals, tangents, argnums, {output}); if (argnums.size() > 1) { - jvp[0] = - add(jvp[0], vjp(primals, {tangents[1]}, {argnums[1]})[0], stream()); + return {add(grads[0], grads[1], stream())}; + } else { + return grads; } - return jvp; } std::pair, std::vector> Power::vmap( @@ -1799,7 +1844,8 @@ std::pair, std::vector> QuantizedMatmul::vmap( std::vector QuantizedMatmul::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; // We rely on the fact that w is always 2D so transpose is simple @@ -1902,7 +1948,8 @@ std::pair, std::vector> Reshape::vmap( std::vector Reshape::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); assert(argnums[0] == 0); @@ -1927,7 +1974,8 @@ bool Reshape::is_equivalent(const Primitive& other) const { std::vector Reduce::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector& outputs) { auto in = primals[0]; std::vector shape = in.shape(); @@ -1997,15 +2045,10 @@ std::vector Reduce::vjp( } } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) { - array (*op)(const array&, const std::vector&, bool, StreamOrDevice); - - if (reduce_type_ == Reduce::Min) { - op = min; - } else { - op = max; + auto out = outputs[0]; + if (out.ndim() != in.ndim()) { + out = expand_dims(out, axes_, stream()); } - - auto out = op(in, axes_, true, stream()); auto mask = equal(in, out, stream()); auto normalizer = sum(mask, axes_, true, stream()); auto cotan_reshape = reshape(cotan, shape, stream()); @@ -2032,7 +2075,8 @@ bool Reduce::is_equivalent(const Primitive& other) const { std::vector Round::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2076,7 +2120,8 @@ std::pair, std::vector> Scan::vmap( std::vector Scan::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector& outputs) { assert(primals.size() == 1); assert(argnums[0] == 0); @@ -2084,7 +2129,7 @@ std::vector Scan::vjp( return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())}; } else if (reduce_type_ == Scan::Prod) { // TODO: Make it numerically stable when we introduce where() - auto prod = cumprod(primals[0], axis_, reverse_, inclusive_, stream()); + auto prod = outputs[0]; auto partial_grads = multiply(prod, cotangents[0], stream()); auto accum_grads = cumsum(partial_grads, axis_, !reverse_, inclusive_, stream()); @@ -2125,7 +2170,8 @@ bool Scatter::is_equivalent(const Primitive& other) const { std::vector Scatter::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector& outputs) { switch (reduce_type_) { case Scatter::None: case Scatter::Sum: @@ -2137,23 +2183,11 @@ std::vector Scatter::vjp( "[scatter] VJP not implemented for scatter_prod"); } + const array& result = outputs[0]; const array& values = primals[0]; const array& updates = primals.back(); const std::vector indices(primals.begin() + 1, primals.end() - 1); - // Store result of scatter if needed for reuse in vjp - auto get_result = [&]() { - switch (reduce_type_) { - case Scatter::Max: - return scatter_max(values, indices, updates, axes_, stream()); - case Scatter::Min: - return scatter_min(values, indices, updates, axes_, stream()); - default: - return array({}); - } - }; - array result = get_result(); - std::vector vjps; for (auto num : argnums) { // Gradient wrt to the input array @@ -2232,8 +2266,12 @@ std::vector Scatter::jvp( std::vector Sigmoid::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { - return jvp(primals, cotangents, argnums); + const std::vector& argnums, + const std::vector& outputs) { + auto& s = outputs[0]; + auto sprime = + multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream()); + return {multiply(cotangents[0], sprime, stream())}; } std::vector Sigmoid::jvp( @@ -2259,7 +2297,8 @@ std::pair, std::vector> Sigmoid::vmap( std::vector Sign::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2283,7 +2322,8 @@ std::pair, std::vector> Sign::vmap( std::vector Sin::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2307,7 +2347,8 @@ std::pair, std::vector> Sin::vmap( std::vector Sinh::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2345,7 +2386,8 @@ std::pair, std::vector> Slice::vmap( std::vector Slice::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { // Check inputs assert(primals.size() == 1); @@ -2444,8 +2486,15 @@ std::pair, std::vector> Softmax::vmap( std::vector Softmax::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { - return jvp(primals, cotangents, argnums); + const std::vector& argnums, + const std::vector& outputs) { + assert(primals.size() == 1); + assert(cotangents.size() == 1); + auto& s = outputs[0]; + auto sv = multiply(s, cotangents[0], stream()); + return {subtract( + sv, + multiply(s, sum(sv, std::vector{-1}, true, stream()), stream()))}; } std::vector Softmax::jvp( @@ -2473,7 +2522,8 @@ std::pair, std::vector> Sort::vmap( std::vector Sort::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2503,7 +2553,8 @@ std::pair, std::vector> Split::vmap( std::vector Split::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return {concatenate(cotangents, axis_, stream())}; } @@ -2522,7 +2573,8 @@ bool Split::is_equivalent(const Primitive& other) const { std::vector Square::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2549,29 +2601,34 @@ std::pair, std::vector> Square::vmap( std::vector Sqrt::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { - return jvp(primals, cotangents, argnums); + const std::vector& argnums, + const std::vector& outputs) { + assert(primals.size() == 1); + assert(cotangents.size() == 1); + auto dtype = primals[0].dtype(); + if (recip_) { + auto one_over_x_root_x = divide(outputs[0], primals[0], stream()); + return {multiply( + multiply(array(-0.5, dtype), cotangents[0], stream()), + one_over_x_root_x, + stream())}; + } else { + return {divide( + multiply(array(0.5, dtype), cotangents[0], stream()), + outputs[0], + stream())}; + } } std::vector Sqrt::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { - assert(primals.size() == 1); - assert(tangents.size() == 1); - auto dtype = primals[0].dtype(); if (recip_) { - auto one_over_x_root_x = - divide(rsqrt(primals[0], stream()), primals[0], stream()); - return {multiply( - multiply(array(-0.5, dtype), tangents[0], stream()), - one_over_x_root_x, - stream())}; + return vjp(primals, tangents, argnums, {rsqrt(primals[0], stream())}); + } else { + return vjp(primals, tangents, argnums, {sqrt(primals[0], stream())}); } - return {divide( - multiply(array(0.5, dtype), tangents[0], stream()), - sqrt(primals[0], stream()), - stream())}; } std::pair, std::vector> Sqrt::vmap( @@ -2599,7 +2656,8 @@ std::pair, std::vector> StopGradient::vmap( std::vector Subtract::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { std::vector vjps; for (auto arg : argnums) { auto vjp = cotangents[0]; @@ -2636,7 +2694,8 @@ std::pair, std::vector> Subtract::vmap( std::vector Tan::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2661,7 +2720,8 @@ std::pair, std::vector> Tan::vmap( std::vector Tanh::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { return jvp(primals, cotangents, argnums); } @@ -2686,7 +2746,8 @@ std::pair, std::vector> Tanh::vmap( std::vector Transpose::vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) { + const std::vector& argnums, + const std::vector&) { assert(primals.size() == 1); assert(argnums.size() == 1); std::vector iaxes(axes_.size()); diff --git a/mlx/primitives.h b/mlx/primitives.h index 1fb4cf8be..2ef30e4ad 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -21,7 +21,8 @@ std::vector vjp( \ const std::vector& primals, \ const std::vector& cotangents, \ - const std::vector& argnums) override; + const std::vector& argnums, \ + const std::vector& outputs) override; #define DEFINE_PRINT(PRIMITIVE) \ void print(std::ostream& os) override { \ @@ -78,7 +79,8 @@ class Primitive { virtual std::vector vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums); + const std::vector& argnums, + const std::vector& outputs); /** * The primitive must know how to vectorize itself across @@ -464,7 +466,8 @@ class Convolution : public UnaryPrimitive { std::vector vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) override; + const std::vector& argnums, + const std::vector& outputs) override; DEFINE_PRINT(Convolution) bool is_equivalent(const Primitive& other) const override; @@ -919,7 +922,8 @@ class Matmul : public UnaryPrimitive { std::vector vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) override; + const std::vector& argnums, + const std::vector& outputs) override; DEFINE_PRINT(Matmul) DEFINE_DEFAULT_IS_EQUIVALENT() @@ -1153,7 +1157,8 @@ class Reduce : public UnaryPrimitive { std::vector vjp( const std::vector& primals, const std::vector& cotangents, - const std::vector& argnums) override; + const std::vector& argnums, + const std::vector& outputs) override; void print(std::ostream& os) override { switch (reduce_type_) { diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 6135a54f7..7e77ee210 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -443,7 +443,7 @@ std::pair, std::vector> vjp( } } - auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums); + auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs); // Accumulate the vector-jacobian products for each input for (int i = 0; i < argnums.size(); ++i) { auto in_id = a.inputs()[argnums[i]].id(); diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index 414759b5e..3d1e36269 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -71,3 +71,5 @@ class MLXTestCase(unittest.TestCase): elif not isinstance(expected, mx.array): expected = mx.array(expected) self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) + else: + self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index de49d979f..776181e4f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1005,7 +1005,8 @@ class TestArray(mlx_tests.MLXTestCase): index_y = mx.array([3, 3, 1, 2]) u = mx.random.uniform(shape=(4,)) a = a.at[index_x, index_y].add(u) - self.assertEqual(a.sum().item(), u.sum().item()) + self.assertTrue(mx.allclose(a.sum(), u.sum())) + self.assertEqualArray(a.sum(), u.sum(), atol=1e-6, rtol=1e-5) self.assertEqual(a[index_x, index_y].tolist(), u.tolist()) # Test all array.at ops