mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix a couple of slicing bugs (#1827)
* fix a few bugs * fix conv grad * speedup test * comment
This commit is contained in:
16
mlx/ops.cpp
16
mlx/ops.cpp
@@ -599,7 +599,13 @@ array expand_dims(
|
||||
namespace {
|
||||
|
||||
inline auto
|
||||
normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
normalize_slice(const Shape& shape, Shape& start, Shape stop, Shape& strides) {
|
||||
// - Start indices are normalized
|
||||
// - End indices are unchanged as -1 means something different
|
||||
// pre-normalization (the end of the axis) versus post normalization (the
|
||||
// position left of 0).
|
||||
// - Any strides corresponding to singleton dimension are set to 1
|
||||
|
||||
Shape out_shape(shape.size());
|
||||
bool has_neg_strides = false;
|
||||
|
||||
@@ -624,10 +630,10 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
auto ed = e > -1 ? e : -1;
|
||||
|
||||
start[i] = st;
|
||||
stop[i] = ed > st ? st : ed;
|
||||
ed = ed > st ? st : ed;
|
||||
|
||||
auto str = -strides[i];
|
||||
out_shape[i] = (start[i] - stop[i] + str - 1) / str;
|
||||
out_shape[i] = (start[i] - ed + str - 1) / str;
|
||||
|
||||
} else {
|
||||
// Clamp to bounds
|
||||
@@ -635,9 +641,9 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
|
||||
|
||||
start[i] = st;
|
||||
stop[i] = ed < st ? st : ed;
|
||||
ed = ed < st ? st : ed;
|
||||
|
||||
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
|
||||
out_shape[i] = (ed - start[i] + strides[i] - 1) / strides[i];
|
||||
}
|
||||
// Simplify the stride if it's unused
|
||||
if (out_shape[i] == 1) {
|
||||
|
||||
Reference in New Issue
Block a user