mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
* Ensure shape dimensions are within supported integer range (#566) * fix build * fix rebase bug --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
16
mlx/utils.h
16
mlx/utils.h
@@ -68,6 +68,22 @@ std::vector<int> broadcast_shapes(
|
||||
|
||||
bool is_same_shape(const std::vector<array>& arrays);
|
||||
|
||||
/** Returns the shape dimension if it's within allowed range. */
|
||||
template <typename T>
|
||||
int check_shape_dim(const T dim) {
|
||||
constexpr bool is_signed = std::numeric_limits<T>::is_signed;
|
||||
using U = std::conditional_t<is_signed, ssize_t, size_t>;
|
||||
constexpr U min = static_cast<U>(std::numeric_limits<int>::min());
|
||||
constexpr U max = static_cast<U>(std::numeric_limits<int>::max());
|
||||
|
||||
if ((is_signed && dim < min) || dim > max) {
|
||||
throw std::invalid_argument(
|
||||
"Shape dimension falls outside supported `int` range.");
|
||||
}
|
||||
|
||||
return static_cast<int>(dim);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the axis normalized to be in the range [0, ndim).
|
||||
* Based on numpy's normalize_axis_index. See
|
||||
|
Reference in New Issue
Block a user