Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -35,4 +35,29 @@ void shared_buffer_slice(
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
}
void slice(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
// data_end can be -1
size_t data_size =
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
} // namespace mlx::core

View File

@@ -11,11 +11,10 @@ std::tuple<int64_t, Strides> prepare_slice(
const Shape& start_indices,
const Shape& strides);
void shared_buffer_slice(
void slice(
const array& in,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out);
array& out,
const Shape& start_indices,
const Shape& strides);
} // namespace mlx::core