Fix round to round half-cases to even (#482)

This commit is contained in:
Angelos Katharopoulos
2024-01-17 15:27:23 -08:00
committed by GitHub
parent 135fd796d2
commit 90c234b7ac
4 changed files with 19 additions and 9 deletions

View File

@@ -434,14 +434,14 @@ class TestOps(mlx_tests.MLXTestCase):
def test_round(self):
# float
x = mx.array(
[0.5, -0.5, 1.5, -1.5, -22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]
[0.5, -0.5, 1.5, -1.5, -21.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]
)
expected = [1, -1, 2, -2, -22, 20, -27, 9, 0, -np.inf, np.inf]
expected = [0, -0, 2, -2, -21, 20, -27, 9, 0, -np.inf, np.inf]
self.assertListEqual(mx.round(x).tolist(), expected)
# complex
y = mx.round(mx.array([22.2 + 3.6j, 19.5 + 98.2j]))
self.assertListEqual(y.tolist(), [22 + 4j, 20 + 98j])
y = mx.round(mx.array([22.2 + 3.6j, 18.5 + 98.2j]))
self.assertListEqual(y.tolist(), [22 + 4j, 18 + 98j])
# decimals
y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0)
@@ -459,6 +459,16 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5])))
self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47])))
# check round to nearest for different types
dtypes = [mx.bfloat16, mx.float16, mx.float32]
for dtype in dtypes:
x = mx.arange(10, dtype=dtype) - 4.5
x = mx.round(x)
self.assertEqual(
x.astype(mx.float32).tolist(),
[-4.0, -4.0, -2.0, -2.0, -0.0, 0.0, 2.0, 2.0, 4.0, 4.0],
)
def test_transpose_noargs(self):
x = mx.array([[0, 1, 1], [1, 0, 0]])