mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix round to round half-cases to even (#482)
This commit is contained in:
parent
135fd796d2
commit
90c234b7ac
@ -56,11 +56,11 @@ struct SignOp {
|
|||||||
struct RoundOp {
|
struct RoundOp {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
return std::round(x);
|
return std::rint(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
complex64_t operator()(complex64_t x) {
|
complex64_t operator()(complex64_t x) {
|
||||||
return {std::round(x.real()), std::round(x.imag())};
|
return {std::rint(x.real()), std::rint(x.imag())};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -134,8 +134,8 @@ struct Negative {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Round {
|
struct Round {
|
||||||
template <typename T> T operator()(T x) { return metal::round(x); };
|
template <typename T> T operator()(T x) { return metal::rint(x); };
|
||||||
template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
|
template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; };
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Sigmoid {
|
struct Sigmoid {
|
||||||
|
@ -434,14 +434,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
def test_round(self):
|
def test_round(self):
|
||||||
# float
|
# float
|
||||||
x = mx.array(
|
x = mx.array(
|
||||||
[0.5, -0.5, 1.5, -1.5, -22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]
|
[0.5, -0.5, 1.5, -1.5, -21.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]
|
||||||
)
|
)
|
||||||
expected = [1, -1, 2, -2, -22, 20, -27, 9, 0, -np.inf, np.inf]
|
expected = [0, -0, 2, -2, -21, 20, -27, 9, 0, -np.inf, np.inf]
|
||||||
self.assertListEqual(mx.round(x).tolist(), expected)
|
self.assertListEqual(mx.round(x).tolist(), expected)
|
||||||
|
|
||||||
# complex
|
# complex
|
||||||
y = mx.round(mx.array([22.2 + 3.6j, 19.5 + 98.2j]))
|
y = mx.round(mx.array([22.2 + 3.6j, 18.5 + 98.2j]))
|
||||||
self.assertListEqual(y.tolist(), [22 + 4j, 20 + 98j])
|
self.assertListEqual(y.tolist(), [22 + 4j, 18 + 98j])
|
||||||
|
|
||||||
# decimals
|
# decimals
|
||||||
y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0)
|
y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0)
|
||||||
@ -459,6 +459,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5])))
|
self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5])))
|
||||||
self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47])))
|
self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47])))
|
||||||
|
|
||||||
|
# check round to nearest for different types
|
||||||
|
dtypes = [mx.bfloat16, mx.float16, mx.float32]
|
||||||
|
for dtype in dtypes:
|
||||||
|
x = mx.arange(10, dtype=dtype) - 4.5
|
||||||
|
x = mx.round(x)
|
||||||
|
self.assertEqual(
|
||||||
|
x.astype(mx.float32).tolist(),
|
||||||
|
[-4.0, -4.0, -2.0, -2.0, -0.0, 0.0, 2.0, 2.0, 4.0, 4.0],
|
||||||
|
)
|
||||||
|
|
||||||
def test_transpose_noargs(self):
|
def test_transpose_noargs(self):
|
||||||
x = mx.array([[0, 1, 1], [1, 0, 0]])
|
x = mx.array([[0, 1, 1], [1, 0, 0]])
|
||||||
|
|
||||||
|
@ -957,7 +957,7 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
// Test round
|
// Test round
|
||||||
{
|
{
|
||||||
array x({0.5, -0.5, 1.5, -1.5, 2.3, 2.6});
|
array x({0.5, -0.5, 1.5, -1.5, 2.3, 2.6});
|
||||||
CHECK(array_equal(round(x), array({1, -1, 2, -2, 2, 3})).item<bool>());
|
CHECK(array_equal(round(x), array({0, -0, 2, -2, 2, 3})).item<bool>());
|
||||||
|
|
||||||
x = array({11, 222, 32});
|
x = array({11, 222, 32});
|
||||||
CHECK(array_equal(round(x, -1), array({10, 220, 30})).item<bool>());
|
CHECK(array_equal(round(x, -1), array({10, 220, 30})).item<bool>());
|
||||||
|
Loading…
Reference in New Issue
Block a user