Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
This commit is contained in:
Awni Hannun
2025-02-05 19:50:08 -08:00
committed by GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions

View File

@@ -599,7 +599,13 @@ array expand_dims(
namespace {
inline auto
normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
normalize_slice(const Shape& shape, Shape& start, Shape stop, Shape& strides) {
// - Start indices are normalized
// - End indices are unchanged as -1 means something different
// pre-normalization (the end of the axis) versus post normalization (the
// position left of 0).
// - Any strides corresponding to singleton dimension are set to 1
Shape out_shape(shape.size());
bool has_neg_strides = false;
@@ -624,10 +630,10 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
auto ed = e > -1 ? e : -1;
start[i] = st;
stop[i] = ed > st ? st : ed;
ed = ed > st ? st : ed;
auto str = -strides[i];
out_shape[i] = (start[i] - stop[i] + str - 1) / str;
out_shape[i] = (start[i] - ed + str - 1) / str;
} else {
// Clamp to bounds
@@ -635,9 +641,9 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
start[i] = st;
stop[i] = ed < st ? st : ed;
ed = ed < st ? st : ed;
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
out_shape[i] = (ed - start[i] + strides[i] - 1) / strides[i];
}
// Simplify the stride if it's unused
if (out_shape[i] == 1) {