mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 00:01:17 +08:00
Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
b57bd0488d
commit
b670485185
@@ -62,6 +62,7 @@ DEFAULT(Partition)
|
|||||||
DEFAULT_MULTI(QRF)
|
DEFAULT_MULTI(QRF)
|
||||||
DEFAULT(RandomBits)
|
DEFAULT(RandomBits)
|
||||||
DEFAULT(Reshape)
|
DEFAULT(Reshape)
|
||||||
|
DEFAULT(Remainder)
|
||||||
DEFAULT(Round)
|
DEFAULT(Round)
|
||||||
DEFAULT(Scatter)
|
DEFAULT(Scatter)
|
||||||
DEFAULT(Sigmoid)
|
DEFAULT(Sigmoid)
|
||||||
@@ -292,45 +293,6 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Avoid code duplication with the common backend.
|
|
||||||
struct RemainderFn {
|
|
||||||
template <typename T>
|
|
||||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
|
||||||
T numerator,
|
|
||||||
T denominator) {
|
|
||||||
return std::fmod(numerator, denominator);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
|
||||||
T numerator,
|
|
||||||
T denominator) {
|
|
||||||
return numerator % denominator;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
auto& a = inputs[0];
|
|
||||||
auto& b = inputs[1];
|
|
||||||
|
|
||||||
if (a.dtype() == float32) {
|
|
||||||
binary(
|
|
||||||
a,
|
|
||||||
b,
|
|
||||||
out,
|
|
||||||
RemainderFn{},
|
|
||||||
UseDefaultBinaryOp(),
|
|
||||||
UseDefaultBinaryOp(),
|
|
||||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
|
||||||
int num_el = n;
|
|
||||||
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
binary(a, b, out, RemainderFn{});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
|
@@ -140,16 +140,34 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
struct RemainderFn {
|
struct RemainderFn {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
|
||||||
T numerator,
|
T numerator,
|
||||||
T denominator) {
|
T denominator) {
|
||||||
return std::fmod(numerator, denominator);
|
return numerator % denominator;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
|
||||||
T numerator,
|
T numerator,
|
||||||
T denominator) {
|
T denominator) {
|
||||||
|
auto r = numerator % denominator;
|
||||||
|
if (r != 0 && (r < 0 != denominator < 0))
|
||||||
|
r += denominator;
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||||
|
T numerator,
|
||||||
|
T denominator) {
|
||||||
|
auto r = std::fmod(numerator, denominator);
|
||||||
|
if (r != 0 && (r < 0 != denominator < 0)) {
|
||||||
|
r += denominator;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
|
||||||
return numerator % denominator;
|
return numerator % denominator;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -24,20 +24,30 @@ struct Divide {
|
|||||||
|
|
||||||
struct Remainder {
|
struct Remainder {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x, T y) {
|
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||||
|
operator()(T x, T y) {
|
||||||
return x % y;
|
return x % y;
|
||||||
}
|
}
|
||||||
template <>
|
template <typename T>
|
||||||
float operator()(float x, float y) {
|
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||||
return fmod(x, y);
|
operator()(T x, T y) {
|
||||||
|
auto r = x % y;
|
||||||
|
if (r != 0 && (r < 0 != y < 0)) {
|
||||||
|
r += y;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
T r = fmod(x, y);
|
||||||
|
if (r != 0 && (r < 0 != y < 0)) {
|
||||||
|
r += y;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
half operator()(half x, half y) {
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
return fmod(x, y);
|
return x % y;
|
||||||
}
|
|
||||||
template <>
|
|
||||||
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
|
||||||
return fmod(x, y);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -14,10 +14,29 @@ struct FloorDivide {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Remainder {
|
struct Remainder {
|
||||||
template <typename T> T operator()(T x, T y) { return x % y; }
|
template <typename T>
|
||||||
template <> float operator()(float x, float y) { return fmod(x, y); }
|
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||||
template <> half operator()(half x, half y) { return fmod(x, y); }
|
return x % y;
|
||||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
|
}
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||||
|
auto r = x % y;
|
||||||
|
if (r != 0 && (r < 0 != y < 0)) {
|
||||||
|
r += y;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
T r = fmod(x, y);
|
||||||
|
if (r != 0 && (r < 0 != y < 0)) {
|
||||||
|
r += y;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
template <> complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x % y;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op1, typename Op2>
|
template <typename T, typename U, typename Op1, typename Op2>
|
||||||
|
@@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
|||||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||||
|
if (real != 0 && (real < 0 != b.real < 0)) {
|
||||||
|
real += b.real;
|
||||||
|
}
|
||||||
|
if (imag != 0 && (imag < 0 != b.imag < 0)) {
|
||||||
|
imag += b.imag;
|
||||||
|
}
|
||||||
return {real, imag};
|
return {real, imag};
|
||||||
}
|
}
|
||||||
|
@@ -35,6 +35,16 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) {
|
|||||||
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
|
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline complex64_t operator%(complex64_t a, complex64_t b) {
|
||||||
|
auto real = a.real() - (b.real() * static_cast<int64_t>(a.real() / b.real()));
|
||||||
|
auto imag = a.imag() - (b.imag() * static_cast<int64_t>(a.imag() / b.imag()));
|
||||||
|
if (real != 0 && (real < 0 != b.real() < 0))
|
||||||
|
real += b.real();
|
||||||
|
if (imag != 0 && (imag < 0 != b.imag() < 0))
|
||||||
|
imag += b.imag();
|
||||||
|
return {real, imag};
|
||||||
|
}
|
||||||
|
|
||||||
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
|
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
|
||||||
return operator>=(b, a);
|
return operator>=(b, a);
|
||||||
}
|
}
|
||||||
|
@@ -275,6 +275,20 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(z.dtype, dt)
|
self.assertEqual(z.dtype, dt)
|
||||||
self.assertEqual(z.item(), 1)
|
self.assertEqual(z.item(), 1)
|
||||||
|
|
||||||
|
z = -1 % x
|
||||||
|
self.assertEqual(z.dtype, dt)
|
||||||
|
self.assertEqual(z.item(), 1)
|
||||||
|
|
||||||
|
z = -1 % -x
|
||||||
|
self.assertEqual(z.dtype, dt)
|
||||||
|
self.assertEqual(z.item(), -1)
|
||||||
|
|
||||||
|
x = mx.arange(10).astype(dt) - 5
|
||||||
|
y = x % 5
|
||||||
|
z = x % -5
|
||||||
|
self.assertEqual(y.tolist(), [0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
|
||||||
|
self.assertEqual(z.tolist(), [0, -4, -3, -2, -1, 0, -4, -3, -2, -1])
|
||||||
|
|
||||||
def test_comparisons(self):
|
def test_comparisons(self):
|
||||||
a = mx.array([0.0, 1.0, 5.0])
|
a = mx.array([0.0, 1.0, 5.0])
|
||||||
b = mx.array([-1.0, 2.0, 5.0])
|
b = mx.array([-1.0, 2.0, 5.0])
|
||||||
|
Loading…
Reference in New Issue
Block a user