Allow no copy negative strides in as_strided and slice (#1688)

* allow no copy negative strides in as_strided and slice

* fix jit

* fix jit
This commit is contained in:
Awni Hannun
2024-12-12 08:59:45 -08:00
committed by GitHub
parent 4d595a2a39
commit 6bd28d246e
15 changed files with 133 additions and 163 deletions

View File

@@ -507,34 +507,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
auto copy_needed = std::any_of(
strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Strides ostrides{out.strides().begin(), out.strides().end()};
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {