Fix slice data size (#1394)

* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
This commit is contained in:
Awni Hannun
2024-09-04 19:10:43 -07:00
committed by GitHub
parent 11371fe251
commit 7cca1727af
12 changed files with 129 additions and 39 deletions

View File

@@ -20,8 +20,8 @@ void RMSNorm::eval_gpu(
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
@@ -208,8 +208,8 @@ void LayerNorm::eval_gpu(
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}

View File

@@ -11,8 +11,8 @@ namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
std::vector<int> start_indices,
std::vector<int> strides,
const std::vector<int>& start_indices,
const std::vector<int>& strides,
const Stream& s) {
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] =
@@ -34,7 +34,15 @@ void slice_gpu(
/* const Stream& s = */ s);
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
}

View File

@@ -9,8 +9,8 @@ namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
std::vector<int> start_indices,
std::vector<int> strides,
const std::vector<int>& start_indices,
const std::vector<int>& strides,
const Stream& s);
void concatenate_gpu(

View File

@@ -24,8 +24,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}