mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Fix divide types + floor divide (//) (#138)
* divide types * fix black + test
This commit is contained in:
@@ -236,6 +236,24 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(z.dtype, mx.float32)
|
||||
self.assertEqual(z.item(), 0.5)
|
||||
|
||||
x = x.astype(mx.float16)
|
||||
z = x / 4.0
|
||||
self.assertEqual(z.dtype, mx.float16)
|
||||
|
||||
x = x.astype(mx.float16)
|
||||
z = 4.0 / x
|
||||
self.assertEqual(z.dtype, mx.float16)
|
||||
|
||||
x = mx.array(5)
|
||||
y = mx.array(2)
|
||||
z = x / y
|
||||
self.assertEqual(z.dtype, mx.float32)
|
||||
self.assertEqual(z.item(), 2.5)
|
||||
|
||||
z = x // y
|
||||
self.assertEqual(z.dtype, mx.int32)
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
def test_remainder(self):
|
||||
for dt in [mx.int32, mx.float32]:
|
||||
x = mx.array(2, dtype=dt)
|
||||
|
Reference in New Issue
Block a user