Fix divide types + floor divide (//) (#138)

* divide types

* fix black + test
This commit is contained in:
Awni Hannun
2023-12-11 20:20:58 -08:00
committed by GitHub
parent 02de234ef0
commit 25f70d4ca4
4 changed files with 42 additions and 6 deletions

View File

@@ -236,6 +236,24 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 0.5)
x = x.astype(mx.float16)
z = x / 4.0
self.assertEqual(z.dtype, mx.float16)
x = x.astype(mx.float16)
z = 4.0 / x
self.assertEqual(z.dtype, mx.float16)
x = mx.array(5)
y = mx.array(2)
z = x / y
self.assertEqual(z.dtype, mx.float32)
self.assertEqual(z.item(), 2.5)
z = x // y
self.assertEqual(z.dtype, mx.int32)
self.assertEqual(z.item(), 2)
def test_remainder(self):
for dt in [mx.int32, mx.float32]:
x = mx.array(2, dtype=dt)