Fix slice data size (#1913)

* fix slice data size

* add test
This commit is contained in:
Awni Hannun 2025-03-02 21:50:42 -08:00 committed by GitHub
parent 5e6c130d93
commit 4e7cd31d12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 3 deletions

View File

@ -14,6 +14,10 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i]; data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i]; inp_strides[i] = in.strides()[i] * strides[i];
} }
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides); return std::make_tuple(data_offset, inp_strides);
} }
@ -54,9 +58,10 @@ void slice(
data_end += end_idx * in.strides()[i]; data_end += end_idx * in.strides()[i];
} }
} }
// data_end can be -1 if (data_end < 0) {
size_t data_size = data_end += in.data_size();
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset); }
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out); shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
} }

View File

@ -2846,6 +2846,11 @@ class TestOps(mlx_tests.MLXTestCase):
b[::2] = 0 b[::2] = 0
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1]))) self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
def test_slice_with_negative_stride(self):
a = mx.random.uniform(shape=(128, 4))
out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()