mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
26
mlx/utils.h
26
mlx/utils.h
@@ -64,30 +64,13 @@ Dtype result_type(const std::vector<array>& arrays);
|
||||
|
||||
Shape broadcast_shapes(const Shape& s1, const Shape& s2);
|
||||
|
||||
bool is_same_shape(const std::vector<array>& arrays);
|
||||
|
||||
/** Returns the shape dimension if it's within allowed range. */
|
||||
template <typename T>
|
||||
int check_shape_dim(const T dim) {
|
||||
constexpr bool is_signed = std::numeric_limits<T>::is_signed;
|
||||
using U = std::conditional_t<is_signed, int64_t, size_t>;
|
||||
constexpr U min = static_cast<U>(std::numeric_limits<int>::min());
|
||||
constexpr U max = static_cast<U>(std::numeric_limits<int>::max());
|
||||
|
||||
if ((is_signed && dim < min) || dim > max) {
|
||||
throw std::invalid_argument(
|
||||
"Shape dimension falls outside supported `int` range.");
|
||||
}
|
||||
|
||||
return static_cast<int>(dim);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the axis normalized to be in the range [0, ndim).
|
||||
* Based on numpy's normalize_axis_index. See
|
||||
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
||||
*/
|
||||
int normalize_axis(int axis, int ndim);
|
||||
int normalize_axis_index(
|
||||
int axis,
|
||||
int ndim,
|
||||
const std::string& msg_prefix = "");
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||
@@ -96,7 +79,6 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
||||
std::ostream& operator<<(std::ostream& os, array a);
|
||||
std::ostream& operator<<(std::ostream& os, const Shape& v);
|
||||
std::ostream& operator<<(std::ostream& os, const Strides& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
|
||||
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
||||
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user