Fix slice data size (#1394)

* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
Author: Awni Hannun, 2024-09-04 19:10:43 -07:00 (committed by GitHub)
Parent: 11371fe251
Commit: 7cca1727af
12 changed files with 129 additions and 39 deletions


@@ -219,11 +219,23 @@ class array {
};
struct Flags {
// True if there are no gaps in the underlying data. Each item
// True iff there are no gaps in the underlying data. Each item
// in the underlying data buffer belongs to at least one index.
//
// True iff:
// prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
bool contiguous : 1;
// True iff:
// strides[-1] == 1 and
// all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
// range(ndim - 1))
bool row_contiguous : 1;
// True iff:
// strides[0] == 1 and
// all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
// range(1, ndim))
bool col_contiguous : 1;
};
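
Worked on concrete numbers, the three predicates above are easy to check by hand. The sketch below is a hypothetical standalone restatement of those conditions (plain C++, not the MLX implementation; ``data_size`` is passed in explicitly, using values produced by the new end-offset logic that appears later in this commit), evaluated for two slices of a row-major 4x4 buffer.

```cpp
// Hypothetical restatement of the Flags predicates documented above;
// not the MLX implementation, just the same arithmetic.
#include <cassert>
#include <cstddef>
#include <vector>

struct FlagsSketch {
  bool contiguous, row_contiguous, col_contiguous;
};

FlagsSketch compute_flags(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    size_t data_size) {
  int ndim = (int)shape.size();
  // contiguous: prod(shape[i] for strides[i] > 0) == data_size
  size_t no_broadcast_size = 1;
  for (int i = 0; i < ndim; ++i) {
    if (strides[i] > 0) {
      no_broadcast_size *= shape[i];
    }
  }
  // row_contiguous: strides[-1] == 1 and
  //   strides[i] == shape[i+1] * strides[i+1] (or shape[i] == 1)
  bool row = ndim == 0 || strides[ndim - 1] == 1;
  for (int i = 0; i + 1 < ndim && row; ++i) {
    row = (strides[i] == shape[i + 1] * strides[i + 1]) || shape[i] == 1;
  }
  // col_contiguous: strides[0] == 1 and
  //   strides[i] == shape[i-1] * strides[i-1] (or shape[i] == 1)
  bool col = ndim == 0 || strides[0] == 1;
  for (int i = 1; i < ndim && col; ++i) {
    col = (strides[i] == shape[i - 1] * strides[i - 1]) || shape[i] == 1;
  }
  return {no_broadcast_size == data_size, row, col};
}

int main() {
  // A row-major 4x4 buffer sliced with step 2 along the last axis:
  // shape {4, 2}, strides {4, 2}, and (per the new Slice logic) data_size 16.
  auto f = compute_flags({4, 2}, {4, 2}, 16);
  assert(!f.contiguous && !f.row_contiguous && !f.col_contiguous);

  // A plain prefix slice keeps dense rows: shape {2, 4}, strides {4, 1},
  // data_size 8 (the same value the new tests check below).
  f = compute_flags({2, 4}, {4, 1}, 8);
  assert(f.contiguous && f.row_contiguous && !f.col_contiguous);
  return 0;
}
```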
@@ -291,7 +303,16 @@ class array {
return array_desc_->flags;
}
/** The size (in elements) of the underlying buffer the array points to. */
/** The size (in elements) of the underlying buffer the array points to.
*
* This can be different from the actual size of the array if the array has
* been broadcast or irregularly strided. If ``first`` is the offset into
* the data buffer of the first element of the array (i.e. the offset
* corresponding to ``arr[0, 0, ...]``) and ``last`` is the offset into the
* data buffer of the last element of the array (i.e. the offset
* corresponding to ``arr[-1, -1, ...]``), then ``data_size = last - first + 1``.
* Note that ``data_size`` is in units of ``item_size`` (not bytes).
**/
size_t data_size() const {
return array_desc_->data_size;
}
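
As a usage sketch of the distinction (assuming the umbrella header ``mlx/mlx.h`` and the C++ API exercised by the new tests below), slicing off the first column of a 4x4 array keeps 12 elements while the view still spans 15 slots of the shared buffer:

```cpp
// Sketch only: assumes the MLX C++ API as used in the tests added below.
#include <cassert>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = ones({4, 4});
  // Drop the first column: shape {4, 3}, but the view still covers offsets
  // 1 through 15 of the original buffer, so data_size() is 15, not 12.
  auto y = slice(x, {0, 1}, {4, 4});
  eval(y);
  assert(y.size() == 12);
  assert(y.data_size() == 15);
  return 0;
}
```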
@@ -412,8 +433,6 @@ class array {
void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses
// This can be different than the actual size of the array if it
// has been broadcast or irregularly strided.
size_t data_size;
// Contains useful meta data about the array


@@ -505,8 +505,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
}
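
For one of the new test cases, a {1, 2}..{3, 4} slice of a 4x4 array, the loop above reduces to a short calculation. The sketch below restates that arithmetic standalone; it is not the MLX code, just the same formula applied to concrete numbers.

```cpp
// Standalone restatement of the data_end / data_size arithmetic above,
// applied to slice(ones({4, 4}), {1, 2}, {3, 4}) from the new tests.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<int> in_shape = {4, 4};
  std::vector<size_t> in_strides = {4, 1}; // row-major 4x4
  std::vector<int> start = {1, 2};
  std::vector<int> out_shape = {2, 2};     // (3 - 1, 4 - 2)
  std::vector<int> step = {1, 1};          // slice strides

  // data_offset: buffer offset of out[0, 0] (as in prepare_slice).
  size_t data_offset = 0;
  for (int i = 0; i < 2; ++i) {
    data_offset += start[i] * in_strides[i];
  }

  // data_end: one past the buffer offset of out[-1, -1]
  // (exact here because the slice steps are 1).
  size_t data_end = 1;
  for (int i = 0; i < 2; ++i) {
    if (in_shape[i] > 1) {
      auto end_idx = start[i] + out_shape[i] * step[i] - 1;
      data_end += end_idx * in_strides[i];
    }
  }

  assert(data_offset == 6);             // 1*4 + 2*1
  assert(data_end == 12);               // 1 + 2*4 + 3*1
  assert(data_end - data_offset == 6);  // matches CHECK_EQ(out.data_size(), 6)
  return 0;
}
```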


@@ -6,18 +6,16 @@ namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides) {
const std::vector<int>& start_indices,
const std::vector<int>& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
@@ -25,26 +23,16 @@ void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
size_t data_size,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
flags.contiguous = (no_bsx_size == data_size);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
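
The net effect of the simplified rule is that a shared-buffer slice is marked contiguous exactly when the elements it can reach leave no gaps in its data buffer. A usage sketch (assuming ``mlx/mlx.h`` and the C++ API used by the tests in this commit):

```cpp
// Sketch only: assumes the MLX C++ API as used in the tests in this commit.
#include <cassert>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = ones({4, 4});

  // Prefix slice: dense rows, no gaps -> contiguous and row-contiguous.
  auto dense = slice(x, {0, 0}, {2, 4});
  eval(dense);
  assert(dense.flags().contiguous);
  assert(dense.flags().row_contiguous);

  // Keep every other row: the skipped rows leave gaps between the kept
  // elements in the buffer -> not contiguous.
  auto holes = slice(x, {0, 0}, {4, 4}, {2, 1});
  eval(holes);
  assert(!holes.flags().contiguous);
  assert(!holes.flags().row_contiguous);
  return 0;
}
```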


@@ -8,13 +8,14 @@ namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides);
const std::vector<int>& start_indices,
const std::vector<int>& strides);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
size_t data_size,
array& out);
} // namespace mlx::core


@@ -135,7 +135,7 @@ template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
size_t data_size = 1;
size_t no_broadcast_data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
bool is_row_contiguous = true;
@@ -147,11 +147,12 @@ inline auto check_contiguity(
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
no_broadcast_data_size *= shape[i];
}
}
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
return std::make_tuple(
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
}
} // namespace mlx::core


@@ -20,8 +20,8 @@ void RMSNorm::eval_gpu(
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
@@ -208,8 +208,8 @@ void LayerNorm::eval_gpu(
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
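
The added ``flags().contiguous`` requirement is what catches views like the one in the new ``test_slice_into_layer_norm`` test below: a trailing slice can have unit stride in the last dimension and a matching second-to-last stride while its rows still sit far apart in the underlying buffer. A sketch of that situation (assuming ``mlx/mlx.h``; this mirrors the Python test's ``x[:, 99:]`` in C++):

```cpp
// Sketch only: reproduces the x[:, 99:] pattern from the new Python test
// through the assumed MLX C++ API.
#include <cassert>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = ones({8, 100, 128});
  auto y = slice(x, {0, 99, 0}, {8, 100, 128}); // shape {8, 1, 128}
  eval(y);

  // The old check would have accepted y without a copy:
  bool old_no_copy = y.strides()[y.ndim() - 1] == 1 &&
      (y.strides()[y.ndim() - 2] == 0 ||
       y.strides()[y.ndim() - 2] == y.shape().back());
  assert(old_no_copy);

  // ...but the 8 length-128 rows of y start 12800 elements apart in the
  // buffer, so the data is not gap-free and the fast kernels must copy.
  assert(!y.flags().contiguous);
  return 0;
}
```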


@@ -11,8 +11,8 @@ namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
std::vector<int> start_indices,
std::vector<int> strides,
const std::vector<int>& start_indices,
const std::vector<int>& strides,
const Stream& s) {
// Calculate output strides, the initial offset, and whether a copy needs to be made
auto [copy_needed, data_offset, inp_strides] =
@@ -34,7 +34,15 @@ void slice_gpu(
/* const Stream& s = */ s);
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
}


@@ -9,8 +9,8 @@ namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
std::vector<int> start_indices,
std::vector<int> strides,
const std::vector<int>& start_indices,
const std::vector<int>& strides,
const Stream& s);
void concatenate_gpu(


@@ -24,8 +24,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}


@@ -615,6 +615,10 @@ inline auto normalize_slice(
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
}
// Simplify the stride if it's unused
if (out_shape[i] == 1) {
strides[i] = 1;
}
}
return std::make_pair(has_neg_strides, out_shape);
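
Since a size-1 output dimension is only ever indexed at zero, its slice stride cannot change which element is read, but left unsimplified it would inflate the output strides and the data-size bound computed above. A usage sketch (assuming ``mlx/mlx.h``) of a single-element slice taken with step 3:

```cpp
// Sketch only: assumes the MLX C++ API used by the tests in this commit.
#include <cassert>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = ones({10});
  // x[4:5:3] selects exactly one element; with the unused stride simplified
  // to 1 the view is a tight, contiguous window over a single element.
  auto y = slice(x, {4}, {5}, {3});
  eval(y);
  assert(y.size() == 1);
  assert(y.data_size() == 1);
  assert(y.flags().contiguous);
  return 0;
}
```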


@@ -429,6 +429,14 @@ class TestFast(mlx_tests.MLXTestCase):
rx_fast = mx.fast.layer_norm(x, None, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_slice_into_layer_norm(self):
dim = 128
eps = 1e-5
x = mx.random.uniform(shape=(8, 100, 128))[:, 99:]
rx_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=eps)
rx = layer_norm(x, None, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-4)
def test_layer_norm_grad(self):
D = 32
eps = 1e-5


@@ -228,7 +228,7 @@ TEST_CASE("test slice") {
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
// Check contiguity preservation
x = ones({10, 10}) * 2;
x = ones({10, 10});
eval(x);
CHECK(x.flags().row_contiguous);
CHECK(!x.flags().col_contiguous);
@@ -252,6 +252,59 @@ TEST_CASE("test slice") {
eval(out);
CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
x = ones({6, 4, 10});
out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2});
eval(out);
CHECK(!out.flags().contiguous);
CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
// Check data size correctness
x = ones({4});
out = slice(x, {0}, {2});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {2}, {4});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {0}, {4}, {2});
eval(out);
CHECK_EQ(out.data_size(), 4);
x = ones({4, 4});
out = slice(x, {0, 0}, {2, 4});
eval(out);
CHECK_EQ(out.data_size(), 8);
out = slice(x, {0, 0}, {1, 2});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {0, 1}, {4, 4});
eval(out);
CHECK_EQ(out.data_size(), 15);
out = slice(x, {1, 2}, {3, 4});
eval(out);
CHECK_EQ(out.data_size(), 6);
x = ones({4, 4, 4});
out = slice(x, {0, 0, 0}, {4, 2, 2});
eval(out);
CHECK_EQ(out.data_size(), 54);
x = ones({4, 4, 4});
out = slice(x, {2, 2, 2}, {3, 3, 3});
eval(out);
CHECK_EQ(out.data_size(), 1);
x = ones({4, 4, 4});
out = slice(x, {2, 2, 2}, {3, 4, 3});
eval(out);
CHECK_EQ(out.data_size(), 5);
}
TEST_CASE("test slice update") {