Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -4,24 +4,22 @@
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
std::tuple<int64_t, Strides> prepare_slice(
const array& in,
const std::vector<int>& start_indices,
const std::vector<int>& strides) {
const Shape& start_indices,
const Shape& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
Strides inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
return std::make_tuple(data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out) {