Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -86,7 +86,7 @@ void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
slice(inputs[0], out, start_indices_, strides_);
}
void Split::eval_cpu(
const std::vector<array>& inputs,
@@ -262,29 +262,6 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
@@ -355,7 +332,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_inplace(