mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix arange with inf step (#686)
* Fix case for step=inf in arange and add inf check for start/stop * Add test cases for arange * Update ops.cpp to include climits header * Fix arange * Fix formatting * Refactor * Add missing include
This commit is contained in:
parent
126c9869c8
commit
d729a1991b
19
mlx/ops.cpp
19
mlx/ops.cpp
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <climits>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
@ -73,10 +74,24 @@ array arange(
|
|||||||
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
|
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
|
||||||
throw std::invalid_argument("[arange] Cannot compute length.");
|
throw std::invalid_argument("[arange] Cannot compute length.");
|
||||||
}
|
}
|
||||||
double real_size = std::ceil((stop - start) / step);
|
|
||||||
if (std::isnan(real_size)) {
|
if (std::isinf(start) || std::isinf(stop)) {
|
||||||
throw std::invalid_argument("[arange] Cannot compute length.");
|
throw std::invalid_argument("[arange] Cannot compute length.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if start and stop specify a valid range because if not, we have to
|
||||||
|
// return an empty array
|
||||||
|
if (std::isinf(step) &&
|
||||||
|
(step > 0 && start < stop || step < 0 && start > stop)) {
|
||||||
|
return array({start}, dtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
double real_size = std::ceil((stop - start) / step);
|
||||||
|
|
||||||
|
if (real_size > INT_MAX) {
|
||||||
|
throw std::invalid_argument("[arange] Maximum size exceeded.");
|
||||||
|
}
|
||||||
|
|
||||||
int size = std::max(static_cast<int>(real_size), 0);
|
int size = std::max(static_cast<int>(real_size), 0);
|
||||||
return array(
|
return array(
|
||||||
{size},
|
{size},
|
||||||
|
@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
a = mx.arange(0, float("inf"), float("inf"))
|
a = mx.arange(0, float("inf"), float("inf"))
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
a = mx.arange(float("inf"), 1, float("inf"))
|
a = mx.arange(float("inf"), 1, float("inf"))
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a = mx.arange(float("inf"), 1, 5)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
INT_MAX = 2147483647
|
||||||
|
a = mx.arange(0, INT_MAX + 1, 1)
|
||||||
|
|
||||||
a = mx.arange(5)
|
a = mx.arange(5)
|
||||||
expected = [0, 1, 2, 3, 4]
|
expected = [0, 1, 2, 3, 4]
|
||||||
@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(a.tolist(), expected)
|
self.assertListEqual(a.tolist(), expected)
|
||||||
self.assertEqual(a.dtype, mx.int32)
|
self.assertEqual(a.dtype, mx.int32)
|
||||||
|
|
||||||
|
a = mx.arange(0, 10, 100)
|
||||||
|
expected = [0]
|
||||||
|
self.assertListEqual(a.tolist(), expected)
|
||||||
|
self.assertEqual(a.dtype, mx.int32)
|
||||||
|
|
||||||
|
a = mx.arange(10, 0, 1)
|
||||||
|
expected = []
|
||||||
|
self.assertListEqual(a.tolist(), expected)
|
||||||
|
|
||||||
|
a = mx.arange(10, 0, float("inf"))
|
||||||
|
expected = []
|
||||||
|
self.assertListEqual(a.tolist(), expected)
|
||||||
|
|
||||||
|
a = mx.arange(0, 10, float("inf"))
|
||||||
|
expected = [0]
|
||||||
|
self.assertListEqual(a.tolist(), expected)
|
||||||
|
|
||||||
|
a = mx.arange(0, -10, float("-inf"))
|
||||||
|
expected = [0]
|
||||||
|
self.assertListEqual(a.tolist(), expected)
|
||||||
|
|
||||||
def test_unary_ops(self):
|
def test_unary_ops(self):
|
||||||
def test_ops(npop, mlxop, x, y, atol):
|
def test_ops(npop, mlxop, x, y, atol):
|
||||||
r_np = npop(x)
|
r_np = npop(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user