mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Fix arange with inf step (#686)
* Fix case for step=inf in arange and add inf check for start/stop * Add test cases for arange * Update ops.cpp to include climits header * Fix arange * Fix formatting * Refactor * Add missing include
This commit is contained in:
@@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.arange(0, float("inf"), float("inf"))
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.arange(float("inf"), 1, float("inf"))
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.arange(float("inf"), 1, 5)
|
||||
with self.assertRaises(ValueError):
|
||||
INT_MAX = 2147483647
|
||||
a = mx.arange(0, INT_MAX + 1, 1)
|
||||
|
||||
a = mx.arange(5)
|
||||
expected = [0, 1, 2, 3, 4]
|
||||
@@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
self.assertEqual(a.dtype, mx.int32)
|
||||
|
||||
a = mx.arange(0, 10, 100)
|
||||
expected = [0]
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
self.assertEqual(a.dtype, mx.int32)
|
||||
|
||||
a = mx.arange(10, 0, 1)
|
||||
expected = []
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
a = mx.arange(10, 0, float("inf"))
|
||||
expected = []
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
a = mx.arange(0, 10, float("inf"))
|
||||
expected = [0]
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
a = mx.arange(0, -10, float("-inf"))
|
||||
expected = [0]
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
def test_unary_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
r_np = npop(x)
|
||||
|
Reference in New Issue
Block a user