diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index d4051fbd3..3334fb830 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -322,6 +322,49 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+// TODO: Avoid code duplication with the common backend.
+// Element-wise remainder with C semantics (std::fmod for non-integral types,
+// operator% for integral types).
+struct RemainderFn {
+  template <typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    return std::fmod(numerator, denominator);
+  }
+
+  template <typename T>
+  std::enable_if_t<std::is_integral_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    return numerator % denominator;
+  }
+};
+
+void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+
+  if (a.dtype() == float32) {
+    binary(
+        a,
+        b,
+        out,
+        RemainderFn{},
+        UseDefaultBinaryOp(),
+        UseDefaultBinaryOp(),
+        [](const auto* a, const auto* b, auto* o, auto n) {
+          int num_el = n;
+          // vvfmodf mirrors std::fmod (truncated quotient); vvremainderf
+          // rounds the quotient to nearest and would disagree with it.
+          vvfmodf((float*)o, (const float*)a, (const float*)b, &num_el);
+        });
+  } else {
+    binary(a, b, out, RemainderFn{});
+  }
+}
+
 void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp
index 38f0461cc..4199e9181 100644
--- a/mlx/backend/common/binary.cpp
+++ b/mlx/backend/common/binary.cpp
@@ -82,6 +82,31 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
   binary(a, b, out, [](auto x, auto y) { return x / y; });
 }
 
+// Element-wise remainder with C semantics (std::fmod for non-integral types,
+// operator% for integral types).
+struct RemainderFn {
+  template <typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    return std::fmod(numerator, denominator);
+  }
+
+  template <typename T>
+  std::enable_if_t<std::is_integral_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    return numerator % denominator;
+  }
+};
+
+void Remainder::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  binary(a, b, out, RemainderFn{});
+}
+
 void Equal::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   if (equal_nan_) {
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index a73853ed9..db0ae57bb 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -35,6 +35,7 @@ DEFAULT(Copy)
 DEFAULT(Cos)
 DEFAULT(Cosh)
 DEFAULT(Divide)
+DEFAULT(Remainder)
 DEFAULT(Equal)
 DEFAULT(Erf)
 DEFAULT(ErfInv)
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index c22bfde39..e31ffcbc8 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -14,6 +14,15 @@ struct Divide {
   template <typename T> T operator()(T x, T y) { return x / y; }
 };
 
+// Remainder of x / y: fmod for float, half and bfloat; operator% for every
+// other type.
+struct Remainder {
+  template <typename T> T operator()(T x, T y) { return x % y; }
+  template <> float operator()(float x, float y) { return fmod(x, y); }
+  template <> half operator()(half x, half y) { return fmod(x, y); }
+  template <> bfloat operator()(bfloat x, bfloat y) { return fmod(x, y); }
+};
+
 struct Equal {
   template <typename T> bool operator()(T x, T y) { return x == y; }
 };
@@ -363,6 +372,7 @@ instantiate_binary_types(min, Minimum)
 instantiate_binary_types(mul, Multiply)
 instantiate_binary_types(sub, Subtract)
 instantiate_binary_types(pow, Power)
+instantiate_binary_types(rem, Remainder)
 
 // NaNEqual only needed for floating point types with boolean output
 instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h
index 0bfa707f9..097a7e3f8 100644
--- a/mlx/backend/metal/kernels/complex.h
+++ b/mlx/backend/metal/kernels/complex.h
@@ -110,3 +110,8 @@ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
 constexpr complex64_t operator*(complex64_t a, complex64_t b) {
   return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
 }
+
+// Component-wise fmod of the real and imaginary parts.
+constexpr complex64_t operator%(complex64_t a, complex64_t b) {
+  return {fmod(a.real, b.real), fmod(a.imag, b.imag)};
+}
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 7860d7de2..79eebf4f9 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -363,6 +363,10 @@ void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "div");
 }
 
+void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
+  binary_op(inputs, out, "rem");
+}
+
 void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
 }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 0d020c26a..ed4704465 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -30,6 +30,7 @@ NO_GPU(Copy)
 NO_GPU(Cos)
 NO_GPU(Cosh)
 NO_GPU(Divide)
+NO_GPU(Remainder)
 NO_GPU(Equal)
 NO_GPU(Erf)
 NO_GPU(ErfInv)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index c279b76f6..16f4bcd9a 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1438,6 +1438,20 @@ array operator/(const array& a, double b) {
   return divide(a, array(b));
 }
 
+array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  auto dtype = promote_types(a.dtype(), b.dtype());
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
+  return array(
+      inputs[0].shape(),
+      dtype,
+      std::make_unique<Remainder>(to_stream(s)),
+      inputs);
+}
+array operator%(const array& a, const array& b) {
+  return remainder(a, b);
+}
+
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
diff --git a/mlx/ops.h b/mlx/ops.h
index 552226b89..52d89ed73 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -636,6 +636,18 @@ array operator/(const array& a, const array& b);
 array operator/(double a, const array& b);
 array operator/(const array& a, double b);
 
+/** Compute the element-wise remainder of division */
+array remainder(const array& a, const array& b, StreamOrDevice s = {});
+array operator%(const array& a, const array& b);
+template <typename T>
+array operator%(T a, const array& b) {
+  return remainder(array(a), b);
+}
+template <typename T>
+array operator%(const array& a, T b) {
+  return remainder(a, array(b));
+}
+
 /** Element-wise maximum between two arrays. */
 array maximum(const array& a, const array& b, StreamOrDevice s = {});
 
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 4c07e14e0..58e539e7a 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -738,6 +738,53 @@ std::pair<array, int> Divide::vmap(
   return {divide(a, b, stream()), to_ax};
 }
 
+std::vector<array> Remainder::vjp(
+    const std::vector<array>& primals,
+    const array& cotan,
+    const std::vector<int>& argnums) {
+  std::vector<array> vjps;
+  for (auto arg : argnums) {
+    if (arg == 0) {
+      vjps.push_back(cotan);
+    } else {
+      auto x_over_y = divide(primals[0], primals[1], stream());
+      // TODO: Use a trunc op; the int32 cast truncates like fmod's quotient
+      x_over_y = astype(x_over_y, int32, stream());
+      vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
+    }
+  }
+  return vjps;
+}
+
+array Remainder::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto jvp_fun = [&](int i) {
+    int arg = argnums[i];
+    if (arg == 0) {
+      return tangents[i];
+    } else {
+      auto x_over_y = divide(primals[0], primals[1], stream());
+      // TODO: Use a trunc op; the int32 cast truncates like fmod's quotient
+      x_over_y = astype(x_over_y, int32, stream());
+      return negative(multiply(x_over_y, tangents[i], stream()), stream());
+    }
+  };
+  auto out = jvp_fun(0);
+  if (argnums.size() > 1) {
+    out = add(out, jvp_fun(1), stream());
+  }
+  return out;
+}
+
+std::pair<array, int> Remainder::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
+  return {remainder(a, b, stream()), to_ax};
+}
+
 std::pair<array, int> Equal::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 3497bc262..c6ea72b3f 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -536,6 +536,25 @@ class Divide : public Primitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Remainder : public Primitive {
+ public:
+  explicit Remainder(Stream stream) : Primitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  std::pair<array, int> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_GRADS()
+  DEFINE_PRINT(Remainder)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Equal : public Primitive {
  public:
   explicit Equal(Stream stream, bool equal_nan = false)
diff --git a/python/src/array.cpp b/python/src/array.cpp
index ea580c78f..9d6c93f13 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -624,6 +624,18 @@ void init_array(py::module_& m) {
             return divide(to_array(v, float32), a);
           },
           "other"_a)
+      .def(
+          "__mod__",
+          [](const array& a, const ScalarOrArray v) {
+            return remainder(a, to_array(v, a.dtype()));
+          },
+          "other"_a)
+      .def(
+          "__rmod__",
+          [](const array& a, const ScalarOrArray v) {
+            return remainder(to_array(v, a.dtype()), a);
+          },
+          "other"_a)
       .def(
           "__eq__",
           [](const array& a, const ScalarOrArray v) {
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index e32c01bcf..6103186c4 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -253,6 +253,31 @@ void init_ops(py::module_& m) {
         Returns:
             array: The quotient ``a / b``.
       )pbdoc");
+  m.def(
+      "remainder",
+      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
+        auto [a, b] = to_arrays(a_, b_);
+        return remainder(a, b, s);
+      },
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        Element-wise remainder of division.
+
+        Computes the remainder of dividing ``a`` by ``b`` with numpy-style
+        broadcasting semantics. Either or both input arrays can also be
+        scalars.
+
+        Args:
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+
+        Returns:
+            array: The remainder of dividing ``a`` by ``b``.
+      )pbdoc");
   m.def(
       "equal",
       [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 4761bee3e..d7ae6228c 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
             "subtract",
             "multiply",
             "divide",
+            "remainder",
             "equal",
             "not_equal",
             "less",
@@ -235,6 +236,30 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(z.dtype, mx.float32)
         self.assertEqual(z.item(), 0.5)
 
+    def test_remainder(self):
+        for dt in [mx.int32, mx.float32]:
+            x = mx.array(2, dtype=dt)
+            y = mx.array(4, dtype=dt)
+
+            z1 = mx.remainder(x, y)
+            z2 = mx.remainder(y, x)
+            self.assertEqual(z1.dtype, dt)
+            self.assertEqual(z1.item(), 2)
+            self.assertEqual(z2.item(), 0)
+
+            # Non-scalar inputs whose quotient rounds up: the result must
+            # follow fmod (2 and 3), not an IEEE remainder (-1 and -1).
+            z3 = mx.remainder(mx.array([5, 7], dtype=dt), mx.array([3, 4], dtype=dt))
+            self.assertEqual(z3.tolist(), [2, 3])
+
+            z = x % 4
+            self.assertEqual(z.dtype, dt)
+            self.assertEqual(z.item(), 2)
+
+            z = 1 % x
+            self.assertEqual(z.dtype, dt)
+            self.assertEqual(z.item(), 1)
+
     def test_comparisons(self):
         a = mx.array([0.0, 1.0, 5.0])
         b = mx.array([-1.0, 2.0, 5.0])