Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -64,30 +64,13 @@ Dtype result_type(const std::vector<array>& arrays);
Shape broadcast_shapes(const Shape& s1, const Shape& s2);
bool is_same_shape(const std::vector<array>& arrays);
/** Returns the shape dimension if it's within allowed range. */
template <typename T>
int check_shape_dim(const T dim) {
constexpr bool is_signed = std::numeric_limits<T>::is_signed;
using U = std::conditional_t<is_signed, int64_t, size_t>;
constexpr U min = static_cast<U>(std::numeric_limits<int>::min());
constexpr U max = static_cast<U>(std::numeric_limits<int>::max());
if ((is_signed && dim < min) || dim > max) {
throw std::invalid_argument(
"Shape dimension falls outside supported `int` range.");
}
return static_cast<int>(dim);
}
/**
* Returns the axis normalized to be in the range [0, ndim).
* Based on numpy's normalize_axis_index. See
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
*/
int normalize_axis(int axis, int ndim);
int normalize_axis_index(
int axis,
int ndim,
const std::string& msg_prefix = "");
std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);
@@ -96,7 +79,6 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const Shape& v);
std::ostream& operator<<(std::ostream& os, const Strides& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}