mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Minor updates to address a few issues (#537)
* docs on arg indices return type * arange with nan * undo isort
This commit is contained in:
parent
4fe2fa2a64
commit
f30e63353a
@ -79,7 +79,14 @@ array arange(
|
|||||||
msg << bool_ << " not supported for arange.";
|
msg << bool_ << " not supported for arange.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
int size = std::max(static_cast<int>(std::ceil((stop - start) / step)), 0);
|
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
|
||||||
|
throw std::invalid_argument("[arange] Cannot compute length.");
|
||||||
|
}
|
||||||
|
double real_size = std::ceil((stop - start) / step);
|
||||||
|
if (std::isnan(real_size)) {
|
||||||
|
throw std::invalid_argument("[arange] Cannot compute length.");
|
||||||
|
}
|
||||||
|
int size = std::max(static_cast<int>(real_size), 0);
|
||||||
return array(
|
return array(
|
||||||
{size},
|
{size},
|
||||||
dtype,
|
dtype,
|
||||||
|
@ -2254,7 +2254,7 @@ void init_ops(py::module_& m) {
|
|||||||
singleton dimensions, defaults to `False`.
|
singleton dimensions, defaults to `False`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The output array with the indices of the minimum values.
|
array: The ``uint32`` array with the indices of the minimum values.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"argmax",
|
"argmax",
|
||||||
@ -2287,7 +2287,7 @@ void init_ops(py::module_& m) {
|
|||||||
singleton dimensions, defaults to `False`.
|
singleton dimensions, defaults to `False`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The output array with the indices of the maximum values.
|
array: The ``uint32`` array with the indices of the maximum values.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"sort",
|
"sort",
|
||||||
@ -2343,7 +2343,7 @@ void init_ops(py::module_& m) {
|
|||||||
If unspecified, it defaults to -1 (sorting over the last axis).
|
If unspecified, it defaults to -1 (sorting over the last axis).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The indices that sort the input array.
|
array: The ``uint32`` array containing indices that sort the input.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"partition",
|
"partition",
|
||||||
@ -2416,7 +2416,7 @@ void init_ops(py::module_& m) {
|
|||||||
If unspecified, it defaults to ``-1``.
|
If unspecified, it defaults to ``-1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The indices that partition the input array.
|
array: The `uint32`` array containing indices that partition the input.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"topk",
|
"topk",
|
||||||
|
@ -980,6 +980,17 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(z.tolist(), [5, 6, 7])
|
self.assertEqual(z.tolist(), [5, 6, 7])
|
||||||
|
|
||||||
def test_arange_overload_dispatch(self):
|
def test_arange_overload_dispatch(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a = mx.arange(float("nan"), 1, 5)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a = mx.arange(0, float("nan"), 5)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a = mx.arange(0, 2, float("nan"))
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a = mx.arange(0, float("inf"), float("inf"))
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a = mx.arange(float("inf"), 1, float("inf"))
|
||||||
|
|
||||||
a = mx.arange(5)
|
a = mx.arange(5)
|
||||||
expected = [0, 1, 2, 3, 4]
|
expected = [0, 1, 2, 3, 4]
|
||||||
self.assertListEqual(a.tolist(), expected)
|
self.assertListEqual(a.tolist(), expected)
|
||||||
|
Loading…
Reference in New Issue
Block a user