Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -98,29 +98,17 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
return out_shape;
}
bool is_same_shape(const std::vector<array>& arrays) {
if (arrays.empty()) {
return true;
}
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
return (a.shape() == arrays[0].shape());
});
}
int normalize_axis(int axis, int ndim) {
if (ndim <= 0) {
throw std::invalid_argument("Number of dimensions must be positive.");
}
int normalize_axis_index(
int axis,
int ndim,
const std::string& msg_prefix /* = "" */) {
if (axis < -ndim || axis >= ndim) {
std::ostringstream msg;
msg << "Axis " << axis << " is out of bounds for array with " << ndim
<< " dimensions.";
msg << msg_prefix << "Axis " << axis << " is out of bounds for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (axis < 0) {
axis += ndim;
}
return axis;
return axis < 0 ? axis + ndim : axis;
}
std::ostream& operator<<(std::ostream& os, const Device& d) {
@@ -323,15 +311,6 @@ std::ostream& operator<<(std::ostream& os, const Strides& v) {
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
namespace env {
int get_var(const char* name, int default_value) {