Adds round op and primitive (#203)

This commit is contained in:
Angelos Katharopoulos
2023-12-18 11:32:48 -08:00
committed by GitHub
parent 477397bc98
commit 4d4af12c6f
17 changed files with 187 additions and 2 deletions

View File

@@ -372,7 +372,35 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(mx.ceil(x).tolist(), expected)
with self.assertRaises(ValueError):
mx.floor(mx.array([22 + 3j, 19 + 98j]))
mx.ceil(mx.array([22 + 3j, 19 + 98j]))
def test_round(self):
# float
x = mx.array(
[0.5, -0.5, 1.5, -1.5, -22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]
)
expected = [1, -1, 2, -2, -22, 20, -27, 9, 0, -np.inf, np.inf]
self.assertListEqual(mx.round(x).tolist(), expected)
# complex
y = mx.round(mx.array([22.2 + 3.6j, 19.5 + 98.2j]))
self.assertListEqual(y.tolist(), [22 + 4j, 20 + 98j])
# decimals
y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0)
y1 = mx.round(mx.array([15, 122], mx.int32), decimals=-1)
y2 = mx.round(mx.array([15, 122], mx.int32), decimals=-2)
self.assertEqual(y0.dtype, mx.int32)
self.assertEqual(y1.dtype, mx.int32)
self.assertEqual(y2.dtype, mx.int32)
self.assertListEqual(y0.tolist(), [15, 122])
self.assertListEqual(y1.tolist(), [20, 120])
self.assertListEqual(y2.tolist(), [0, 100])
y1 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=1)
y2 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=2)
self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5])))
self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47])))
def test_transpose_noargs(self):
x = mx.array([[0, 1, 1], [1, 0, 0]])