mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add the remainder op (#85)
* Add remainder in the C++ backend * Add the python binding and test
This commit is contained in:

committed by
GitHub

parent
69a24e6a1e
commit
2b714714e1
@@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"subtract",
|
||||
"multiply",
|
||||
"divide",
|
||||
"remainder",
|
||||
"equal",
|
||||
"not_equal",
|
||||
"less",
|
||||
@@ -235,6 +236,25 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(z.dtype, mx.float32)
|
||||
self.assertEqual(z.item(), 0.5)
|
||||
|
||||
def test_remainder(self):
|
||||
for dt in [mx.int32, mx.float32]:
|
||||
x = mx.array(2, dtype=dt)
|
||||
y = mx.array(4, dtype=dt)
|
||||
|
||||
z1 = mx.remainder(x, y)
|
||||
z2 = mx.remainder(y, x)
|
||||
self.assertEqual(z1.dtype, dt)
|
||||
self.assertEqual(z1.item(), 2)
|
||||
self.assertEqual(z2.item(), 0)
|
||||
|
||||
z = x % 4
|
||||
self.assertEqual(z.dtype, dt)
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = 1 % x
|
||||
self.assertEqual(z.dtype, dt)
|
||||
self.assertEqual(z.item(), 1)
|
||||
|
||||
def test_comparisons(self):
|
||||
a = mx.array([0.0, 1.0, 5.0])
|
||||
b = mx.array([-1.0, 2.0, 5.0])
|
||||
|
Reference in New Issue
Block a user