mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -98,29 +98,17 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
|
||||
return out_shape;
|
||||
}
|
||||
|
||||
bool is_same_shape(const std::vector<array>& arrays) {
|
||||
if (arrays.empty()) {
|
||||
return true;
|
||||
}
|
||||
return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
|
||||
return (a.shape() == arrays[0].shape());
|
||||
});
|
||||
}
|
||||
|
||||
int normalize_axis(int axis, int ndim) {
|
||||
if (ndim <= 0) {
|
||||
throw std::invalid_argument("Number of dimensions must be positive.");
|
||||
}
|
||||
int normalize_axis_index(
|
||||
int axis,
|
||||
int ndim,
|
||||
const std::string& msg_prefix /* = "" */) {
|
||||
if (axis < -ndim || axis >= ndim) {
|
||||
std::ostringstream msg;
|
||||
msg << "Axis " << axis << " is out of bounds for array with " << ndim
|
||||
<< " dimensions.";
|
||||
msg << msg_prefix << "Axis " << axis << " is out of bounds for array with "
|
||||
<< ndim << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += ndim;
|
||||
}
|
||||
return axis;
|
||||
return axis < 0 ? axis + ndim : axis;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||
@@ -323,15 +311,6 @@ std::ostream& operator<<(std::ostream& os, const Strides& v) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
namespace env {
|
||||
|
||||
int get_var(const char* name, int default_value) {
|
||||
|
||||
Reference in New Issue
Block a user