mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix arange with inf step (#686)
* Fix case for step=inf in arange and add inf check for start/stop * Add test cases for arange * Update ops.cpp to include climits header * Fix arange * Fix formatting * Refactor * Add missing include
This commit is contained in:
		| @@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             a = mx.arange(0, float("inf"), float("inf")) | ||||
|         with self.assertRaises(ValueError): | ||||
|             a = mx.arange(float("inf"), 1, float("inf")) | ||||
|         with self.assertRaises(ValueError): | ||||
|             a = mx.arange(float("inf"), 1, 5) | ||||
|         with self.assertRaises(ValueError): | ||||
|             INT_MAX = 2147483647 | ||||
|             a = mx.arange(0, INT_MAX + 1, 1) | ||||
|  | ||||
|         a = mx.arange(5) | ||||
|         expected = [0, 1, 2, 3, 4] | ||||
| @@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         self.assertListEqual(a.tolist(), expected) | ||||
|         self.assertEqual(a.dtype, mx.int32) | ||||
|  | ||||
|         a = mx.arange(0, 10, 100) | ||||
|         expected = [0] | ||||
|         self.assertListEqual(a.tolist(), expected) | ||||
|         self.assertEqual(a.dtype, mx.int32) | ||||
|  | ||||
|         a = mx.arange(10, 0, 1) | ||||
|         expected = [] | ||||
|         self.assertListEqual(a.tolist(), expected) | ||||
|  | ||||
|         a = mx.arange(10, 0, float("inf")) | ||||
|         expected = [] | ||||
|         self.assertListEqual(a.tolist(), expected) | ||||
|  | ||||
|         a = mx.arange(0, 10, float("inf")) | ||||
|         expected = [0] | ||||
|         self.assertListEqual(a.tolist(), expected) | ||||
|  | ||||
|         a = mx.arange(0, -10, float("-inf")) | ||||
|         expected = [0] | ||||
|         self.assertListEqual(a.tolist(), expected) | ||||
|  | ||||
|     def test_unary_ops(self): | ||||
|         def test_ops(npop, mlxop, x, y, atol): | ||||
|             r_np = npop(x) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Noah Farr
					Noah Farr