mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Propagate nans in binary ops (#579)
* propagate nans in binary ops * handle empty matmul * cpu minimum/maximum propagate nan * benchmark maximum * add min as well * throw on negative indices with full * verbose on linux * fix matmul for zero K
This commit is contained in:
@@ -386,6 +386,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = [0, -7, 3]
|
||||
self.assertListEqual(mx.minimum(x, y).tolist(), expected)
|
||||
|
||||
a = mx.array([float("nan")])
|
||||
b = mx.array([0.0])
|
||||
self.assertTrue(math.isnan(mx.minimum(a, b).item()))
|
||||
self.assertTrue(math.isnan(mx.minimum(b, a).item()))
|
||||
|
||||
def test_maximum(self):
|
||||
x = mx.array([0.0, -5, 10.0])
|
||||
y = mx.array([1.0, -7.0, 3.0])
|
||||
@@ -393,6 +398,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = [1, -5, 10]
|
||||
self.assertListEqual(mx.maximum(x, y).tolist(), expected)
|
||||
|
||||
a = mx.array([float("nan")])
|
||||
b = mx.array([0.0])
|
||||
self.assertTrue(math.isnan(mx.maximum(a, b).item()))
|
||||
self.assertTrue(math.isnan(mx.maximum(b, a).item()))
|
||||
|
||||
def test_floor(self):
|
||||
x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
|
||||
expected = [-23, 19, -27, 9, 0, -np.inf, np.inf]
|
||||
@@ -760,6 +770,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
a = mx.array([float("nan")])
|
||||
b = mx.array([0.0])
|
||||
self.assertTrue(math.isnan(mx.logaddexp(a, b).item()))
|
||||
|
||||
def test_log(self):
|
||||
a = mx.array([1, 0.5, 10, 100])
|
||||
result = mx.log(a)
|
||||
@@ -1761,6 +1775,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)
|
||||
|
||||
def test_empty_matmuls(self):
|
||||
a = mx.array([])
|
||||
b = mx.array([])
|
||||
self.assertEqual(mx.inner(a, b).item(), 0.0)
|
||||
|
||||
a = mx.zeros((10, 0))
|
||||
b = mx.zeros((0, 10))
|
||||
out = a @ b
|
||||
self.assertTrue(mx.array_equal(out, mx.zeros((10, 10))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user