diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 6cd851111..499cc0ce4 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -62,6 +62,7 @@ DEFAULT(Partition) DEFAULT_MULTI(QRF) DEFAULT(RandomBits) DEFAULT(Reshape) +DEFAULT(Remainder) DEFAULT(Round) DEFAULT(Scatter) DEFAULT(Sigmoid) @@ -292,45 +293,6 @@ void Divide::eval_cpu(const std::vector& inputs, array& out) { } } -// TODO: Avoid code duplication with the common backend. -struct RemainderFn { - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return std::fmod(numerator, denominator); - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } -}; - -void Remainder::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - - if (a.dtype() == float32) { - binary( - a, - b, - out, - RemainderFn{}, - UseDefaultBinaryOp(), - UseDefaultBinaryOp(), - [](const auto* a, const auto* b, auto* o, auto n) { - int num_el = n; - vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el); - }); - } else { - binary(a, b, out, RemainderFn{}); - } -} - void Exp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index a51d22d0f..855e8467b 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -140,16 +140,34 @@ void Divide::eval(const std::vector& inputs, array& out) { struct RemainderFn { template - std::enable_if_t, T> operator()( + std::enable_if_t & !std::is_signed_v, T> operator()( T numerator, T denominator) { - return std::fmod(numerator, denominator); + return numerator % denominator; } template - std::enable_if_t, T> operator()( + std::enable_if_t & std::is_signed_v, T> operator()( T numerator, T denominator) { + auto r = numerator % denominator; + if (r != 0 && (r < 0 != denominator < 0)) + r += denominator; + return r; + } + + template + std::enable_if_t, T> operator()( + T numerator, + T denominator) { + auto r = std::fmod(numerator, denominator); + if (r != 0 && (r < 0 != denominator < 0)) { + r += denominator; + } + return r; + } + + complex64_t operator()(complex64_t numerator, complex64_t denominator) { return numerator % denominator; } }; diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 8adb84c58..006f2ff0e 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -24,20 +24,30 @@ struct Divide { struct Remainder { template - T operator()(T x, T y) { + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { return x % y; } - template <> - float operator()(float x, float y) { - return fmod(x, y); + template + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; } template <> - half operator()(half x, half y) { - return fmod(x, y); - } - template <> - bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { - return fmod(x, y); + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; } }; diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 3e4dbc8f2..245ced024 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -14,10 +14,29 @@ struct FloorDivide { }; struct Remainder { - template T operator()(T x, T y) { return x % y; } - template <> float operator()(float x, float y) { return fmod(x, y); } - template <> half operator()(half x, half y) { return fmod(x, y); } - template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); } + template + metal::enable_if_t & !metal::is_signed_v, T> operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } }; template diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index ac966a293..9cb27c5a3 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) { constexpr complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real - (b.real * static_cast(a.real / b.real)); auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); + if (real != 0 && (real < 0 != b.real < 0)) { + real += b.real; + } + if (imag != 0 && (imag < 0 != b.imag < 0)) { + imag += b.imag; + } return {real, imag}; } diff --git a/mlx/types/complex.h b/mlx/types/complex.h index 55cbe447a..19ab1b542 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -35,6 +35,16 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) { return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); } +inline complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && (real < 0 != b.real() < 0)) + real += b.real(); + if (imag != 0 && (imag < 0 != b.imag() < 0)) + imag += b.imag(); + return {real, imag}; +} + inline bool operator<=(const complex64_t& a, const complex64_t& b) { return operator>=(b, a); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 3d84f4b02..9ad6d5a53 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -275,6 +275,20 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(z.dtype, dt) self.assertEqual(z.item(), 1) + z = -1 % x + self.assertEqual(z.dtype, dt) + self.assertEqual(z.item(), 1) + + z = -1 % -x + self.assertEqual(z.dtype, dt) + self.assertEqual(z.item(), -1) + + x = mx.arange(10).astype(dt) - 5 + y = x % 5 + z = x % -5 + self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) + self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1]) + def test_comparisons(self): a = mx.array([0.0, 1.0, 5.0]) b = mx.array([-1.0, 2.0, 5.0])