mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix a couple of slicing bugs (#1827)
* fix a few bugs * fix conv grad * speedup test * comment
This commit is contained in:
@@ -35,4 +35,29 @@ void shared_buffer_slice(
|
||||
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
void slice(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate out strides, initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
int64_t data_end = 1;
|
||||
for (int i = 0; i < start_indices.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
// data_end can be -1
|
||||
size_t data_size =
|
||||
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -11,11 +11,10 @@ std::tuple<int64_t, Strides> prepare_slice(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
void slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out);
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user