mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
fix slice (#2758)
This commit is contained in:
@@ -167,7 +167,7 @@ void array::copy_shared_buffer(
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
int64_t offset /* = 0 */) {
|
||||
array_desc_->data = other.array_desc_->data;
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
|
||||
@@ -439,7 +439,7 @@ class array {
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
int64_t offset = 0);
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
|
||||
@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
}
|
||||
// Normalize the offset
|
||||
if (data_offset < 0) {
|
||||
data_offset += in.data_size();
|
||||
}
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
int64_t data_offset,
|
||||
size_t data_size,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
@@ -51,17 +47,24 @@ void slice(
|
||||
|
||||
// Calculate out strides, initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
int64_t data_end = 1;
|
||||
for (int i = 0; i < start_indices.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
|
||||
// Get the location of the end based on the inp strides and out.shape()
|
||||
int64_t low_idx = 0;
|
||||
int64_t high_idx = 0;
|
||||
for (int i = 0; i < inp_strides.size(); ++i) {
|
||||
auto delta = inp_strides[i] * (out.shape()[i] - 1);
|
||||
if (inp_strides[i] > 0) {
|
||||
high_idx += delta;
|
||||
} else {
|
||||
low_idx += delta;
|
||||
}
|
||||
}
|
||||
if (data_end < 0) {
|
||||
data_end += in.data_size();
|
||||
int64_t data_size = (high_idx - low_idx) + 1;
|
||||
if (data_size < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[slice] Computed invalid data size: " << data_size << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
size_t data_size = (data_end - data_offset);
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ void slice_gpu(
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
const Stream&) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
|
||||
@@ -3058,6 +3058,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = a[::-1]
|
||||
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
||||
|
||||
a = mx.arange(8)
|
||||
for _ in range(4):
|
||||
a = a[::-1]
|
||||
self.assertTrue(mx.array_equal(a, mx.arange(8)))
|
||||
|
||||
def test_complex_ops(self):
|
||||
x = mx.array(
|
||||
[
|
||||
|
||||
@@ -292,7 +292,7 @@ TEST_CASE("test slice") {
|
||||
|
||||
out = slice(x, {0}, {4}, {2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 4);
|
||||
CHECK_EQ(out.data_size(), 3);
|
||||
|
||||
x = ones({4, 4});
|
||||
out = slice(x, {0, 0}, {2, 4});
|
||||
@@ -325,6 +325,20 @@ TEST_CASE("test slice") {
|
||||
out = slice(x, {2, 2, 2}, {3, 4, 3});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 5);
|
||||
|
||||
x = ones({8});
|
||||
out = slice(x, {7}, {-9}, {-1});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
|
||||
out = slice(x, {7}, {-9}, {-1});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
|
||||
x = ones({4, 2});
|
||||
out = slice(x, {3, 0}, {-5, 2}, {-1, 1});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
}
|
||||
|
||||
TEST_CASE("test slice update") {
|
||||
|
||||
Reference in New Issue
Block a user