From 90c234b7ac69d97c7049c99aa3949514b8fbafc4 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 17 Jan 2024 15:27:23 -0800 Subject: [PATCH] Fix round to round half-cases to even (#482) --- mlx/backend/common/unary.h | 4 ++-- mlx/backend/metal/kernels/unary.metal | 4 ++-- python/tests/test_ops.py | 18 ++++++++++++++---- tests/ops_tests.cpp | 2 +- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index bbf118b9b..918bae998 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -56,11 +56,11 @@ struct SignOp { struct RoundOp { template <typename T> T operator()(T x) { - return std::round(x); + return std::rint(x); } complex64_t operator()(complex64_t x) { - return {std::round(x.real()), std::round(x.imag())}; + return {std::rint(x.real()), std::rint(x.imag())}; } }; diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index 4de326f64..681d7707f 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -134,8 +134,8 @@ struct Negative { }; struct Round { - template <typename T> T operator()(T x) { return metal::round(x); }; - template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; }; + template <typename T> T operator()(T x) { return metal::rint(x); }; + template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; }; }; struct Sigmoid { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 24f82d40b..3206f1dcb 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -434,14 +434,14 @@ class TestOps(mlx_tests.MLXTestCase): def test_round(self): # float x = mx.array( - [0.5, -0.5, 1.5, -1.5, -22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf] + [0.5, -0.5, 1.5, -1.5, -21.03, 19.98, -27, 9, 0.0, -np.inf, np.inf] ) - expected = [1, -1, 2, -2, -22, 20, -27, 9, 0, -np.inf, np.inf] + expected = [0, -0, 2, -2, -21, 20, 
-27, 9, 0, -np.inf, np.inf] self.assertListEqual(mx.round(x).tolist(), expected) # complex - y = mx.round(mx.array([22.2 + 3.6j, 19.5 + 98.2j])) - self.assertListEqual(y.tolist(), [22 + 4j, 20 + 98j]) + y = mx.round(mx.array([22.2 + 3.6j, 18.5 + 98.2j])) + self.assertListEqual(y.tolist(), [22 + 4j, 18 + 98j]) # decimals y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0) @@ -459,6 +459,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5]))) self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47]))) + # check round to nearest for different types + dtypes = [mx.bfloat16, mx.float16, mx.float32] + for dtype in dtypes: + x = mx.arange(10, dtype=dtype) - 4.5 + x = mx.round(x) + self.assertEqual( + x.astype(mx.float32).tolist(), + [-4.0, -4.0, -2.0, -2.0, -0.0, 0.0, 2.0, 2.0, 4.0, 4.0], + ) + def test_transpose_noargs(self): x = mx.array([[0, 1, 1], [1, 0, 0]]) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 2c4348554..c17e25572 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -957,7 +957,7 @@ TEST_CASE("test arithmetic unary ops") { // Test round { array x({0.5, -0.5, 1.5, -1.5, 2.3, 2.6}); - CHECK(array_equal(round(x), array({1, -1, 2, -2, 2, 3})).item()); + CHECK(array_equal(round(x), array({0, -0, 2, -2, 2, 3})).item()); x = array({11, 222, 32}); CHECK(array_equal(round(x, -1), array({10, 220, 30})).item());