mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
5e6c130d93
commit
4e7cd31d12
@ -14,6 +14,10 @@ std::tuple<int64_t, Strides> prepare_slice(
|
|||||||
data_offset += start_indices[i] * in.strides()[i];
|
data_offset += start_indices[i] * in.strides()[i];
|
||||||
inp_strides[i] = in.strides()[i] * strides[i];
|
inp_strides[i] = in.strides()[i] * strides[i];
|
||||||
}
|
}
|
||||||
|
// Normalize the offset
|
||||||
|
if (data_offset < 0) {
|
||||||
|
data_offset += in.data_size();
|
||||||
|
}
|
||||||
return std::make_tuple(data_offset, inp_strides);
|
return std::make_tuple(data_offset, inp_strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,9 +58,10 @@ void slice(
|
|||||||
data_end += end_idx * in.strides()[i];
|
data_end += end_idx * in.strides()[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// data_end can be -1
|
if (data_end < 0) {
|
||||||
size_t data_size =
|
data_end += in.data_size();
|
||||||
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
|
}
|
||||||
|
size_t data_size = (data_end - data_offset);
|
||||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2846,6 +2846,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
b[::2] = 0
|
b[::2] = 0
|
||||||
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
|
self.assertTrue(mx.array_equal(b, mx.array([0, 3, 0, 1])))
|
||||||
|
|
||||||
|
def test_slice_with_negative_stride(self):
|
||||||
|
a = mx.random.uniform(shape=(128, 4))
|
||||||
|
out = a[::-1]
|
||||||
|
self.assertTrue(mx.array_equal(out[-1, :], a[0, :]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user