mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix slice (#2758)
This commit is contained in:
@@ -167,7 +167,7 @@ void array::copy_shared_buffer(
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
int64_t offset /* = 0 */) {
|
||||
array_desc_->data = other.array_desc_->data;
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
|
||||
@@ -439,7 +439,7 @@ class array {
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
int64_t offset = 0);
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
|
||||
@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
}
|
||||
// Normalize the offset
|
||||
if (data_offset < 0) {
|
||||
data_offset += in.data_size();
|
||||
}
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
int64_t data_offset,
|
||||
size_t data_size,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
@@ -51,17 +47,24 @@ void slice(
|
||||
|
||||
// Calculate out strides, initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
int64_t data_end = 1;
|
||||
for (int i = 0; i < start_indices.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
|
||||
// Get the location of the end based on the inp strides and out.shape()
|
||||
int64_t low_idx = 0;
|
||||
int64_t high_idx = 0;
|
||||
for (int i = 0; i < inp_strides.size(); ++i) {
|
||||
auto delta = inp_strides[i] * (out.shape()[i] - 1);
|
||||
if (inp_strides[i] > 0) {
|
||||
high_idx += delta;
|
||||
} else {
|
||||
low_idx += delta;
|
||||
}
|
||||
}
|
||||
if (data_end < 0) {
|
||||
data_end += in.data_size();
|
||||
int64_t data_size = (high_idx - low_idx) + 1;
|
||||
if (data_size < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[slice] Computed invalid data size: " << data_size << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
size_t data_size = (data_end - data_offset);
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ void slice_gpu(
|
||||
array& out,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
const Stream&) {
|
||||
slice(in, out, start_indices, strides);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user