mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -381,7 +381,7 @@ array batch_tensordot(
|
||||
size2 *= x.shape(s);
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
Shape shape;
|
||||
for (auto ax : i) {
|
||||
shape.push_back(x.shape(ax));
|
||||
}
|
||||
@@ -391,7 +391,7 @@ array batch_tensordot(
|
||||
return reshape(transpose(x, reorder, s), std::move(shape), s);
|
||||
};
|
||||
|
||||
std::vector<int> out_shape;
|
||||
Shape out_shape;
|
||||
for (auto ax : a_batch) {
|
||||
out_shape.push_back(a.shape(ax));
|
||||
}
|
||||
@@ -455,7 +455,7 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
|
||||
axes.push_back(i);
|
||||
}
|
||||
}
|
||||
std::vector<int> idx_shape(n_expand--, 1);
|
||||
Shape idx_shape(n_expand--, 1);
|
||||
idx_shape[0] = in.shape(axes.back());
|
||||
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
|
||||
for (int i = 0; i < v; ++i) {
|
||||
|
Reference in New Issue
Block a user