This commit is contained in:
Awni Hannun
2025-11-13 11:30:02 -08:00
committed by GitHub
parent 8973550ff3
commit 66519fb348
6 changed files with 39 additions and 17 deletions

View File

@@ -167,7 +167,7 @@ void array::copy_shared_buffer(
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
int64_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides;
array_desc_->flags = flags;

View File

@@ -439,7 +439,7 @@ class array {
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
int64_t offset = 0);
void copy_shared_buffer(const array& other);

View File

@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
}
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
size_t data_offset,
int64_t data_offset,
size_t data_size,
array& out) {
// Compute row/col contiguity
@@ -51,17 +47,24 @@ void slice(
// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
// Find the lowest and highest element offsets (relative to the slice start)
// reachable with inp_strides over out.shape(); their span is the data size
int64_t low_idx = 0;
int64_t high_idx = 0;
for (int i = 0; i < inp_strides.size(); ++i) {
auto delta = inp_strides[i] * (out.shape()[i] - 1);
if (inp_strides[i] > 0) {
high_idx += delta;
} else {
low_idx += delta;
}
}
if (data_end < 0) {
data_end += in.data_size();
int64_t data_size = (high_idx - low_idx) + 1;
if (data_size < 0) {
std::ostringstream msg;
msg << "[slice] Computed invalid data size: " << data_size << ".";
throw std::runtime_error(msg.str());
}
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

View File

@@ -11,7 +11,7 @@ void slice_gpu(
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
const Stream&) {
slice(in, out, start_indices, strides);
}

View File

@@ -3058,6 +3058,11 @@ class TestOps(mlx_tests.MLXTestCase):
out = a[::-1]
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
a = mx.arange(8)
for _ in range(4):
a = a[::-1]
self.assertTrue(mx.array_equal(a, mx.arange(8)))
def test_complex_ops(self):
x = mx.array(
[

View File

@@ -292,7 +292,7 @@ TEST_CASE("test slice") {
out = slice(x, {0}, {4}, {2});
eval(out);
CHECK_EQ(out.data_size(), 4);
CHECK_EQ(out.data_size(), 3);
x = ones({4, 4});
out = slice(x, {0, 0}, {2, 4});
@@ -325,6 +325,20 @@ TEST_CASE("test slice") {
out = slice(x, {2, 2, 2}, {3, 4, 3});
eval(out);
CHECK_EQ(out.data_size(), 5);
x = ones({8});
out = slice(x, {7}, {-9}, {-1});
eval(out);
CHECK_EQ(out.data_size(), 8);
out = slice(x, {7}, {-9}, {-1});
eval(out);
CHECK_EQ(out.data_size(), 8);
x = ones({4, 2});
out = slice(x, {3, 0}, {-5, 2}, {-1, 1});
eval(out);
CHECK_EQ(out.data_size(), 8);
}
TEST_CASE("test slice update") {