mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -2953,16 +2953,16 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"as_strided",
|
||||
[](const array& a,
|
||||
std::optional<std::vector<int>> shape,
|
||||
std::optional<std::vector<size_t>> strides,
|
||||
std::optional<Shape> shape,
|
||||
std::optional<Strides> strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> a_shape = (shape) ? *shape : a.shape();
|
||||
std::vector<size_t> a_strides;
|
||||
auto a_shape = (shape) ? *shape : a.shape();
|
||||
Strides a_strides;
|
||||
if (strides) {
|
||||
a_strides = *strides;
|
||||
} else {
|
||||
a_strides = std::vector<size_t>(a_shape.size(), 1);
|
||||
a_strides = Strides(a_shape.size(), 1);
|
||||
for (int i = a_shape.size() - 1; i > 0; i--) {
|
||||
a_strides[i - 1] = a_shape[i] * a_strides[i];
|
||||
}
|
||||
|
Reference in New Issue
Block a user