Fix slice data size (#1394)

* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
This commit is contained in:
Awni Hannun 2024-09-04 19:10:43 -07:00 committed by GitHub
parent 11371fe251
commit 7cca1727af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 129 additions and 39 deletions

View File

@ -219,11 +219,23 @@ class array {
}; };
struct Flags { struct Flags {
// True if there are no gaps in the underlying data. Each item // True iff there are no gaps in the underlying data. Each item
// in the underlying data buffer belongs to at least one index. // in the underlying data buffer belongs to at least one index.
//
// True iff:
// prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
bool contiguous : 1; bool contiguous : 1;
// True iff:
// strides[-1] == 1 and
// all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
// range(ndim - 1))
bool row_contiguous : 1; bool row_contiguous : 1;
// True iff:
// strides[0] == 1 and
// all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
// range(1, ndim))
bool col_contiguous : 1; bool col_contiguous : 1;
}; };
@ -291,7 +303,16 @@ class array {
return array_desc_->flags; return array_desc_->flags;
} }
/** The size (in elements) of the underlying buffer the array points to. */ /** The size (in elements) of the underlying buffer the array points to.
*
* This can be different than the actual size of the array if the array has
* been broadcast or irregularly strided. If ``first`` is the offset into
* the data buffer of the first element of the array (i.e. the offset
* corresponding to ``arr[0, 0, ...]``) and last is the offset into the
* data buffer of the last element of the array (i.e. the offset
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/
size_t data_size() const { size_t data_size() const {
return array_desc_->data_size; return array_desc_->data_size;
} }
@ -412,8 +433,6 @@ class array {
void* data_ptr{nullptr}; void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses // The size in elements of the data buffer the array accesses
// This can be different than the actual size of the array if it
// has been broadcast or irregularly strided.
size_t data_size; size_t data_size;
// Contains useful meta data about the array // Contains useful meta data about the array

View File

@ -505,8 +505,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
/* int64_t o_offset = */ 0, /* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General); /* CopyType ctype = */ CopyType::General);
} else { } else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()}; std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out); shared_buffer_slice(in, ostrides, data_offset, data_size, out);
} }
} }

View File

@ -6,18 +6,16 @@ namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice( std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in, const array& in,
std::vector<int>& start_indices, const std::vector<int>& start_indices,
std::vector<int>& strides) { const std::vector<int>& strides) {
int64_t data_offset = 0; int64_t data_offset = 0;
bool copy_needed = false; bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0); std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) { for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i]; data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i]; inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0; copy_needed |= strides[i] < 0;
} }
return std::make_tuple(copy_needed, data_offset, inp_strides); return std::make_tuple(copy_needed, data_offset, inp_strides);
} }
@ -25,26 +23,16 @@ void shared_buffer_slice(
const array& in, const array& in,
const std::vector<size_t>& out_strides, const std::vector<size_t>& out_strides,
size_t data_offset, size_t data_offset,
size_t data_size,
array& out) { array& out) {
// Compute row/col contiguity // Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] = auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides); check_contiguity(out.shape(), out_strides);
auto flags = in.flags(); auto flags = in.flags();
flags.row_contiguous = is_row_contiguous; flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous; flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size);
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
} }

View File

@ -8,13 +8,14 @@ namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice( std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in, const array& in,
std::vector<int>& start_indices, const std::vector<int>& start_indices,
std::vector<int>& strides); const std::vector<int>& strides);
void shared_buffer_slice( void shared_buffer_slice(
const array& in, const array& in,
const std::vector<size_t>& out_strides, const std::vector<size_t>& out_strides,
size_t data_offset, size_t data_offset,
size_t data_size,
array& out); array& out);
} // namespace mlx::core } // namespace mlx::core

View File

@ -135,7 +135,7 @@ template <typename stride_t>
inline auto check_contiguity( inline auto check_contiguity(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<stride_t>& strides) { const std::vector<stride_t>& strides) {
size_t data_size = 1; size_t no_broadcast_data_size = 1;
size_t f_stride = 1; size_t f_stride = 1;
size_t b_stride = 1; size_t b_stride = 1;
bool is_row_contiguous = true; bool is_row_contiguous = true;
@ -147,11 +147,12 @@ inline auto check_contiguity(
f_stride *= shape[i]; f_stride *= shape[i];
b_stride *= shape[ri]; b_stride *= shape[ri];
if (strides[i] > 0) { if (strides[i] > 0) {
data_size *= shape[i]; no_broadcast_data_size *= shape[i];
} }
} }
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous); return std::make_tuple(
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -20,8 +20,8 @@ void RMSNorm::eval_gpu(
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
std::vector<array> copies; std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& { auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back());
} }
@ -208,8 +208,8 @@ void LayerNorm::eval_gpu(
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
std::vector<array> copies; std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& { auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back());
} }

View File

@ -11,8 +11,8 @@ namespace mlx::core {
void slice_gpu( void slice_gpu(
const array& in, const array& in,
array& out, array& out,
std::vector<int> start_indices, const std::vector<int>& start_indices,
std::vector<int> strides, const std::vector<int>& strides,
const Stream& s) { const Stream& s) {
// Calculate out strides, initial offset and if copy needs to be made // Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = auto [copy_needed, data_offset, inp_strides] =
@ -34,7 +34,15 @@ void slice_gpu(
/* const Stream& s = */ s); /* const Stream& s = */ s);
} else { } else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()}; std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out); size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
} }
} }

View File

@ -9,8 +9,8 @@ namespace mlx::core {
void slice_gpu( void slice_gpu(
const array& in, const array& in,
array& out, array& out,
std::vector<int> start_indices, const std::vector<int>& start_indices,
std::vector<int> strides, const std::vector<int>& strides,
const Stream& s); const Stream& s);
void concatenate_gpu( void concatenate_gpu(

View File

@ -24,8 +24,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
std::vector<array> copies; std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& { auto check_input = [&copies, &s](const array& x) -> const array& {
bool no_copy = x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back());
} }

View File

@ -615,6 +615,10 @@ inline auto normalize_slice(
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i]; out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
} }
// Simplify the stride if it's unused
if (out_shape[i] == 1) {
strides[i] = 1;
}
} }
return std::make_pair(has_neg_strides, out_shape); return std::make_pair(has_neg_strides, out_shape);

View File

@ -429,6 +429,14 @@ class TestFast(mlx_tests.MLXTestCase):
rx_fast = mx.fast.layer_norm(x, None, None, eps) rx_fast = mx.fast.layer_norm(x, None, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
def test_slice_into_layer_norm(self):
dim = 128
eps = 1e-5
x = mx.random.uniform(shape=(8, 100, 128))[:, 99:]
rx_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=eps)
rx = layer_norm(x, None, None, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-4)
def test_layer_norm_grad(self): def test_layer_norm_grad(self):
D = 32 D = 32
eps = 1e-5 eps = 1e-5

View File

@ -228,7 +228,7 @@ TEST_CASE("test slice") {
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>()); CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
// Check contiguity preservation // Check contiguity preservation
x = ones({10, 10}) * 2; x = ones({10, 10});
eval(x); eval(x);
CHECK(x.flags().row_contiguous); CHECK(x.flags().row_contiguous);
CHECK(!x.flags().col_contiguous); CHECK(!x.flags().col_contiguous);
@ -252,6 +252,59 @@ TEST_CASE("test slice") {
eval(out); eval(out);
CHECK(!out.flags().row_contiguous); CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous); CHECK(!out.flags().col_contiguous);
x = ones({6, 4, 10});
out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2});
eval(out);
CHECK(!out.flags().contiguous);
CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
// Check data size correctness
x = ones({4});
out = slice(x, {0}, {2});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {2}, {4});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {0}, {4}, {2});
eval(out);
CHECK_EQ(out.data_size(), 4);
x = ones({4, 4});
out = slice(x, {0, 0}, {2, 4});
eval(out);
CHECK_EQ(out.data_size(), 8);
out = slice(x, {0, 0}, {1, 2});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {0, 1}, {4, 4});
eval(out);
CHECK_EQ(out.data_size(), 15);
out = slice(x, {1, 2}, {3, 4});
eval(out);
CHECK_EQ(out.data_size(), 6);
x = ones({4, 4, 4});
out = slice(x, {0, 0, 0}, {4, 2, 2});
eval(out);
CHECK_EQ(out.data_size(), 54);
x = ones({4, 4, 4});
out = slice(x, {2, 2, 2}, {3, 3, 3});
eval(out);
CHECK_EQ(out.data_size(), 1);
x = ones({4, 4, 4});
out = slice(x, {2, 2, 2}, {3, 4, 3});
eval(out);
CHECK_EQ(out.data_size(), 5);
} }
TEST_CASE("test slice update") { TEST_CASE("test slice update") {