mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Fix arange with inf step (#686)
* Fix case for step=inf in arange and add inf check for start/stop * Add test cases for arange * Update ops.cpp to include climits header * Fix arange * Fix formatting * Refactor * Add missing include
This commit is contained in:
19
mlx/ops.cpp
19
mlx/ops.cpp
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
@@ -73,10 +74,24 @@ array arange(
|
||||
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
|
||||
throw std::invalid_argument("[arange] Cannot compute length.");
|
||||
}
|
||||
double real_size = std::ceil((stop - start) / step);
|
||||
if (std::isnan(real_size)) {
|
||||
|
||||
if (std::isinf(start) || std::isinf(stop)) {
|
||||
throw std::invalid_argument("[arange] Cannot compute length.");
|
||||
}
|
||||
|
||||
// Check if start and stop specify a valid range because if not, we have to
|
||||
// return an empty array
|
||||
if (std::isinf(step) &&
|
||||
(step > 0 && start < stop || step < 0 && start > stop)) {
|
||||
return array({start}, dtype);
|
||||
}
|
||||
|
||||
double real_size = std::ceil((stop - start) / step);
|
||||
|
||||
if (real_size > INT_MAX) {
|
||||
throw std::invalid_argument("[arange] Maximum size exceeded.");
|
||||
}
|
||||
|
||||
int size = std::max(static_cast<int>(real_size), 0);
|
||||
return array(
|
||||
{size},
|
||||
|
Reference in New Issue
Block a user