Remainder negative numerator bug fixed (#641)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Abdussamet Türker 2024-02-10 03:49:14 +03:00 committed by GitHub
parent b57bd0488d
commit b670485185
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 95 additions and 56 deletions

View File

@ -62,6 +62,7 @@ DEFAULT(Partition)
DEFAULT_MULTI(QRF) DEFAULT_MULTI(QRF)
DEFAULT(RandomBits) DEFAULT(RandomBits)
DEFAULT(Reshape) DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round) DEFAULT(Round)
DEFAULT(Scatter) DEFAULT(Scatter)
DEFAULT(Sigmoid) DEFAULT(Sigmoid)
@ -292,45 +293,6 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
} }
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) { void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];

View File

@ -140,16 +140,34 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
struct RemainderFn { struct RemainderFn {
template <typename T> template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()( std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator, T numerator,
T denominator) { T denominator) {
return std::fmod(numerator, denominator); return numerator % denominator;
} }
template <typename T> template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()( std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator, T numerator,
T denominator) { T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator; return numerator % denominator;
} }
}; };

View File

@ -24,20 +24,30 @@ struct Divide {
struct Remainder { struct Remainder {
template <typename T> template <typename T>
T operator()(T x, T y) { metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y; return x % y;
} }
template <> template <typename T>
float operator()(float x, float y) { metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
return fmod(x, y); operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
} }
template <> template <>
half operator()(half x, half y) { complex64_t operator()(complex64_t x, complex64_t y) {
return fmod(x, y); return x % y;
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return fmod(x, y);
} }
}; };

View File

@ -14,10 +14,29 @@ struct FloorDivide {
}; };
struct Remainder { struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; } template <typename T>
template <> float operator()(float x, float y) { return fmod(x, y); } metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
template <> half operator()(half x, half y) { return fmod(x, y); } return x % y;
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); } }
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <> complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
}; };
template <typename T, typename U, typename Op1, typename Op2> template <typename T, typename U, typename Op1, typename Op2>

View File

@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
constexpr complex64_t operator%(complex64_t a, complex64_t b) { constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real)); auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag)); auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
if (real != 0 && (real < 0 != b.real < 0)) {
real += b.real;
}
if (imag != 0 && (imag < 0 != b.imag < 0)) {
imag += b.imag;
}
return {real, imag}; return {real, imag};
} }

View File

@ -35,6 +35,16 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) {
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
} }
inline complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real() - (b.real() * static_cast<int64_t>(a.real() / b.real()));
auto imag = a.imag() - (b.imag() * static_cast<int64_t>(a.imag() / b.imag()));
if (real != 0 && (real < 0 != b.real() < 0))
real += b.real();
if (imag != 0 && (imag < 0 != b.imag() < 0))
imag += b.imag();
return {real, imag};
}
inline bool operator<=(const complex64_t& a, const complex64_t& b) { inline bool operator<=(const complex64_t& a, const complex64_t& b) {
return operator>=(b, a); return operator>=(b, a);
} }

View File

@ -275,6 +275,20 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, dt) self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 1) self.assertEqual(z.item(), 1)
z = -1 % x
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), 1)
z = -1 % -x
self.assertEqual(z.dtype, dt)
self.assertEqual(z.item(), -1)
x = mx.arange(10).astype(dt) - 5
y = x % 5
z = x % -5
self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])
def test_comparisons(self): def test_comparisons(self):
a = mx.array([0.0, 1.0, 5.0]) a = mx.array([0.0, 1.0, 5.0])
b = mx.array([-1.0, 2.0, 5.0]) b = mx.array([-1.0, 2.0, 5.0])