mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix reshape copy bug (#1253)
This commit is contained in:
committed by
GitHub
parent
bdb36c9a63
commit
03cf033f82
@@ -29,6 +29,15 @@ inline size_t elem_to_loc(int elem, const array& a) {
|
||||
return elem_to_loc(elem, a.shape(), a.strides());
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
|
||||
std::vector<stride_t> strides(shape.size(), 1);
|
||||
for (int i = shape.size() - 1; i > 0; i--) {
|
||||
strides[i - 1] = strides[i] * shape[i];
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
// Collapse dims that are contiguous to possibly route to a better kernel
|
||||
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
||||
// should return {{2, 4}, {{1, 2}}}.
|
||||
|
||||
Reference in New Issue
Block a user