Fix slice data size (#1394)

* fix slice data size and add tests
* fix contiguous flag
* simplify stride and perform copy for non-contiguous arrays
* fix cpu
* comment
This commit is contained in:
parent 11371fe251
commit 7cca1727af
mlx/array.h: 27 changed lines
@@ -219,11 +219,23 @@ class array {
   };
 
   struct Flags {
-    // True if there are no gaps in the underlying data. Each item
+    // True iff there are no gaps in the underlying data. Each item
     // in the underlying data buffer belongs to at least one index.
+    //
+    // True iff:
+    // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
    bool contiguous : 1;
 
+    // True iff:
+    // strides[-1] == 1 and
+    // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
+    // range(ndim - 1))
    bool row_contiguous : 1;
+
+    // True iff:
+    // strides[0] == 1 and
+    // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
+    // range(1, ndim))
    bool col_contiguous : 1;
   };
 
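The rewritten comments turn the flags into explicit predicates over shape and strides. As a minimal standalone sketch (a hypothetical helper, not the MLX API), the two stride predicates transcribe directly:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper: evaluates the row/col-contiguity predicates
// documented in the Flags comments above. Strides are in elements.
struct ContiguityResult {
  bool row_contiguous;
  bool col_contiguous;
};

ContiguityResult check(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  int ndim = shape.size();
  // row: strides[-1] == 1 and
  //      strides[i] == shape[i+1]*strides[i+1] or shape[i] == 1
  bool row = strides[ndim - 1] == 1;
  for (int i = 0; i < ndim - 1; ++i) {
    row &= (strides[i] == shape[i + 1] * strides[i + 1]) || shape[i] == 1;
  }
  // col: strides[0] == 1 and
  //      strides[i] == shape[i-1]*strides[i-1] or shape[i] == 1
  bool col = strides[0] == 1;
  for (int i = 1; i < ndim; ++i) {
    col &= (strides[i] == shape[i - 1] * strides[i - 1]) || shape[i] == 1;
  }
  return {row, col};
}

int main() {
  // A 2x3 row-major array: strides {3, 1}.
  auto r = check({2, 3}, {3, 1});
  std::cout << r.row_contiguous << " " << r.col_contiguous << "\n"; // 1 0
  // The same shape laid out column-major: strides {1, 2}.
  auto c = check({2, 3}, {1, 2});
  std::cout << c.row_contiguous << " " << c.col_contiguous << "\n"; // 0 1
}
```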
@@ -291,7 +303,16 @@ class array {
     return array_desc_->flags;
   }
 
-  /** The size (in elements) of the underlying buffer the array points to. */
+  /** The size (in elements) of the underlying buffer the array points to.
+   *
+   * This can be different than the actual size of the array if the array has
+   * been broadcast or irregularly strided. If ``first`` is the offset into
+   * the data buffer of the first element of the array (i.e. the offset
+   * corresponding to ``arr[0, 0, ...]``) and last is the offset into the
+   * data buffer of the last element of the array (i.e. the offset
+   * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
+   * Note, ``data_size`` is in units of ``item_size`` (not bytes).
+   **/
   size_t data_size() const {
     return array_desc_->data_size;
   }
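The expanded comment pins data_size down via the first/last element offsets. A hand-checked sketch of that arithmetic, assuming a positively strided view (a hypothetical helper, not the MLX implementation); the span counted is inclusive of the last element, matching the data_end loops and tests later in this diff:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: span of the buffer touched by a view, in elements.
// `strides` are the element strides of the view into its underlying buffer.
size_t view_data_size(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t last = 0; // offset of arr[-1, -1, ...] relative to arr[0, 0, ...]
  for (size_t i = 0; i < shape.size(); ++i) {
    last += (shape[i] - 1) * strides[i];
  }
  return last + 1; // inclusive of the last element
}

int main() {
  // slice(ones({4, 4}), {0, 1}, {4, 4}): shape {4, 3}, strides {4, 1}.
  assert(view_data_size({4, 3}, {4, 1}) == 15); // matches the new test below
}
```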
@@ -412,8 +433,6 @@ class array {
     void* data_ptr{nullptr};
 
     // The size in elements of the data buffer the array accesses
-    // This can be different than the actual size of the array if it
-    // has been broadcast or irregularly strided.
     size_t data_size;
 
     // Contains useful meta data about the array
@@ -505,8 +505,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
         /* int64_t o_offset = */ 0,
         /* CopyType ctype = */ CopyType::General);
   } else {
+    size_t data_end = 1;
+    for (int i = 0; i < end_indices_.size(); ++i) {
+      if (in.shape()[i] > 1) {
+        auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
+        data_end += end_idx * in.strides()[i];
+      }
+    }
+    size_t data_size = data_end - data_offset;
     std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
-    shared_buffer_slice(in, ostrides, data_offset, out);
+    shared_buffer_slice(in, ostrides, data_offset, data_size, out);
   }
 }
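In the CPU Slice::eval path above, data_end walks to one past the last touched element. A hand-traced instance, with values taken from the ones({4, 4, 4}) test case added at the end of this diff:

```cpp
#include <cstddef>

// Hand-traced (not MLX API): in = ones({4, 4, 4}) has strides {16, 4, 1};
// slicing {0, 0, 0} -> {4, 2, 2} with unit steps gives out shape {4, 2, 2}.
constexpr size_t in_strides[3] = {16, 4, 1};
constexpr size_t end_idx[3] = {3, 1, 1}; // start + out_shape * stride - 1
constexpr size_t data_end = 1 + end_idx[0] * in_strides[0] +
    end_idx[1] * in_strides[1] + end_idx[2] * in_strides[2]; // = 54
constexpr size_t data_offset = 0; // the slice begins at {0, 0, 0}
static_assert(data_end - data_offset == 54, "matches CHECK_EQ below");
```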
@@ -6,18 +6,16 @@ namespace mlx::core {
 
 std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
     const array& in,
-    std::vector<int>& start_indices,
-    std::vector<int>& strides) {
+    const std::vector<int>& start_indices,
+    const std::vector<int>& strides) {
   int64_t data_offset = 0;
   bool copy_needed = false;
   std::vector<int64_t> inp_strides(in.ndim(), 0);
   for (int i = 0; i < in.ndim(); ++i) {
     data_offset += start_indices[i] * in.strides()[i];
     inp_strides[i] = in.strides()[i] * strides[i];
 
     copy_needed |= strides[i] < 0;
   }
 
   return std::make_tuple(copy_needed, data_offset, inp_strides);
 }
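prepare_slice now takes its index vectors by const reference and keeps flagging negative steps for the copy path. A minimal 1-D sketch of why a reversed slice cannot share its buffer here (assumed values, not the MLX API): the element stride would have to be negative, which the size_t out-strides handed to shared_buffer_slice cannot represent.

```cpp
#include <cstdint>
#include <vector>

int main() {
  // x[::-1] on a length-4 array: start index 3, step -1.
  std::vector<int> start_indices{3};
  std::vector<int> strides{-1};
  std::vector<int64_t> in_strides{1};

  int64_t data_offset = start_indices[0] * in_strides[0]; // 3
  int64_t inp_stride = in_strides[0] * strides[0];        // -1
  bool copy_needed = strides[0] < 0;                      // true

  // shared_buffer_slice receives std::vector<size_t> out strides, which
  // cannot hold -1, so the callers fall back to a general copy instead.
  return (copy_needed && data_offset == 3 && inp_stride == -1) ? 0 : 1;
}
```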
@@ -25,26 +23,16 @@ void shared_buffer_slice(
     const array& in,
     const std::vector<size_t>& out_strides,
     size_t data_offset,
+    size_t data_size,
     array& out) {
   // Compute row/col contiguity
-  auto [data_size, is_row_contiguous, is_col_contiguous] =
+  auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
       check_contiguity(out.shape(), out_strides);
 
   auto flags = in.flags();
   flags.row_contiguous = is_row_contiguous;
   flags.col_contiguous = is_col_contiguous;
-
-  if (data_size == 1) {
-    // Broadcasted scalar array is contiguous.
-    flags.contiguous = true;
-  } else if (data_size == in.data_size()) {
-    // Means we sliced a broadcasted dimension so leave the "no holes" flag
-    // alone.
-  } else {
-    // We sliced something. So either we are row or col contiguous or we
-    // punched a hole.
-    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
-  }
+  flags.contiguous = (no_bsx_size == data_size);
 
   out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
 }
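The branchy flag logic collapses to one invariant: a view is gap-free exactly when the count of distinct elements it addresses (the product of dimensions with nonzero stride, which check_contiguity now returns as no_bsx_size) equals the buffer span data_size supplied by the caller. A standalone sketch of that invariant (a hypothetical helper, not the MLX API):

```cpp
#include <cstddef>
#include <vector>

// True iff the view has no holes: every element in the data span it
// touches belongs to at least one index of the view.
bool no_holes(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    size_t data_size) {
  size_t no_bsx_size = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (strides[i] > 0) {
      no_bsx_size *= shape[i]; // broadcast dims (stride 0) add no new data
    }
  }
  return no_bsx_size == data_size;
}
// e.g. slice(ones({4}), {0}, {4}, {2}): shape {2}, strides {2}, data_size 4
// -> no_bsx_size 2 != 4, so the every-other-element view is not contiguous.
```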
@@ -8,13 +8,14 @@ namespace mlx::core {
 
 std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
     const array& in,
-    std::vector<int>& start_indices,
-    std::vector<int>& strides);
+    const std::vector<int>& start_indices,
+    const std::vector<int>& strides);
 
 void shared_buffer_slice(
     const array& in,
     const std::vector<size_t>& out_strides,
     size_t data_offset,
+    size_t data_size,
     array& out);
 
 } // namespace mlx::core
@@ -135,7 +135,7 @@ template <typename stride_t>
 inline auto check_contiguity(
     const std::vector<int>& shape,
     const std::vector<stride_t>& strides) {
-  size_t data_size = 1;
+  size_t no_broadcast_data_size = 1;
   size_t f_stride = 1;
   size_t b_stride = 1;
   bool is_row_contiguous = true;
@@ -147,11 +147,12 @@ inline auto check_contiguity(
     f_stride *= shape[i];
     b_stride *= shape[ri];
     if (strides[i] > 0) {
-      data_size *= shape[i];
+      no_broadcast_data_size *= shape[i];
     }
   }
 
-  return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
+  return std::make_tuple(
+      no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
 }
 
 } // namespace mlx::core
@@ -20,8 +20,8 @@ void RMSNorm::eval_gpu(
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
   auto check_input = [&copies, &s](const array& x) -> const array& {
-    bool no_copy = x.strides()[x.ndim() - 1] == 1;
-    if (x.ndim() > 1) {
+    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
+    if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
       no_copy &= (s == 0 || s == x.shape().back());
     }
@@ -208,8 +208,8 @@ void LayerNorm::eval_gpu(
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
   auto check_input = [&copies, &s](const array& x) -> const array& {
-    bool no_copy = x.strides()[x.ndim() - 1] == 1;
-    if (x.ndim() > 1) {
+    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
+    if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
       no_copy &= (s == 0 || s == x.shape().back());
     }
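With data_size now accurate, a sliced array can have a unit-stride last dimension yet still contain gaps between rows, so the fast normalization kernels additionally require the overall contiguous flag before skipping the copy. A sketch of the tightened check pulled out of the lambda (the accessors are real mlx::core::array methods; the free function itself is hypothetical):

```cpp
#include "mlx/array.h"

// Mirrors the updated check_input lambda: share the input only when the
// buffer is gap-free, the last dimension is unit-stride, and rows are
// packed (or broadcast) along the second-to-last dimension.
bool rows_are_readable(const mlx::core::array& x) {
  bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
  if (no_copy && x.ndim() > 1) {
    auto stride = x.strides()[x.ndim() - 2];
    no_copy &= (stride == 0 || stride == x.shape().back());
  }
  return no_copy;
}
```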
@@ -11,8 +11,8 @@ namespace mlx::core {
 void slice_gpu(
     const array& in,
     array& out,
-    std::vector<int> start_indices,
-    std::vector<int> strides,
+    const std::vector<int>& start_indices,
+    const std::vector<int>& strides,
     const Stream& s) {
   // Calculate out strides, initial offset and if copy needs to be made
   auto [copy_needed, data_offset, inp_strides] =
@@ -34,7 +34,15 @@ void slice_gpu(
         /* const Stream& s = */ s);
   } else {
     std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
-    shared_buffer_slice(in, ostrides, data_offset, out);
+    size_t data_end = 1;
+    for (int i = 0; i < strides.size(); ++i) {
+      if (in.shape()[i] > 1) {
+        auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
+        data_end += end_idx * in.strides()[i];
+      }
+    }
+    size_t data_size = data_end - data_offset;
+    shared_buffer_slice(in, ostrides, data_offset, data_size, out);
   }
 }
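The GPU slice path repeats the CPU data_end loop verbatim. A shared helper along these lines (a sketch only, not part of this commit) would keep the two implementations in lockstep:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical shared helper factoring out the duplicated loop.
size_t slice_data_size(
    const std::vector<int>& in_shape,
    const std::vector<size_t>& in_strides,
    const std::vector<int>& start_indices,
    const std::vector<int>& out_shape,
    const std::vector<int>& strides,
    size_t data_offset) {
  size_t data_end = 1;
  for (size_t i = 0; i < strides.size(); ++i) {
    if (in_shape[i] > 1) { // size-1 input dims contribute no extent
      auto end_idx = start_indices[i] + out_shape[i] * strides[i] - 1;
      data_end += end_idx * in_strides[i];
    }
  }
  return data_end - data_offset;
}
```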
@@ -9,8 +9,8 @@ namespace mlx::core {
 void slice_gpu(
     const array& in,
     array& out,
-    std::vector<int> start_indices,
-    std::vector<int> strides,
+    const std::vector<int>& start_indices,
+    const std::vector<int>& strides,
     const Stream& s);
 
 void concatenate_gpu(
@@ -24,8 +24,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
   auto check_input = [&copies, &s](const array& x) -> const array& {
-    bool no_copy = x.strides()[x.ndim() - 1] == 1;
-    if (x.ndim() > 1) {
+    bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
+    if (no_copy && x.ndim() > 1) {
       auto s = x.strides()[x.ndim() - 2];
       no_copy &= (s == 0 || s == x.shape().back());
     }
@@ -615,6 +615,10 @@ inline auto normalize_slice(
 
       out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
     }
+    // Simplify the stride if it's unused
+    if (out_shape[i] == 1) {
+      strides[i] = 1;
+    }
   }
 
   return std::make_pair(has_neg_strides, out_shape);
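normalize_slice now resets the step on axes that produce a single output element, where the step can never be applied; this keeps such views eligible for buffer sharing instead of tripping the negative-stride or contiguity checks. A tiny hand-check with assumed values (not the MLX API):

```cpp
#include <cassert>

int main() {
  // start 99, stop 100, step 7 on an axis of size 100: the output has a
  // single element along this axis, so the step is never applied.
  int start = 99, stop = 100, stride = 7;
  int out_extent = (stop - start + stride - 1) / stride; // (1 + 6) / 7 == 1
  assert(out_extent == 1);
  if (out_extent == 1) {
    stride = 1; // mirrors `strides[i] = 1` in normalize_slice above
  }
  assert(stride == 1);
  return 0;
}
```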
@@ -429,6 +429,14 @@ class TestFast(mlx_tests.MLXTestCase):
             rx_fast = mx.fast.layer_norm(x, None, None, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
+    def test_slice_into_layer_norm(self):
+        dim = 128
+        eps = 1e-5
+        x = mx.random.uniform(shape=(8, 100, 128))[:, 99:]
+        rx_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=eps)
+        rx = layer_norm(x, None, None, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), 1e-4)
+
     def test_layer_norm_grad(self):
         D = 32
         eps = 1e-5
@ -228,7 +228,7 @@ TEST_CASE("test slice") {
|
||||
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
|
||||
|
||||
// Check contiguity preservation
|
||||
x = ones({10, 10}) * 2;
|
||||
x = ones({10, 10});
|
||||
eval(x);
|
||||
CHECK(x.flags().row_contiguous);
|
||||
CHECK(!x.flags().col_contiguous);
|
||||
@ -252,6 +252,59 @@ TEST_CASE("test slice") {
|
||||
eval(out);
|
||||
CHECK(!out.flags().row_contiguous);
|
||||
CHECK(!out.flags().col_contiguous);
|
||||
|
||||
x = ones({6, 4, 10});
|
||||
out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2});
|
||||
eval(out);
|
||||
CHECK(!out.flags().contiguous);
|
||||
CHECK(!out.flags().row_contiguous);
|
||||
CHECK(!out.flags().col_contiguous);
|
||||
|
||||
// Check data size correctness
|
||||
x = ones({4});
|
||||
out = slice(x, {0}, {2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 2);
|
||||
|
||||
out = slice(x, {2}, {4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 2);
|
||||
|
||||
out = slice(x, {0}, {4}, {2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 4);
|
||||
|
||||
x = ones({4, 4});
|
||||
out = slice(x, {0, 0}, {2, 4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
|
||||
out = slice(x, {0, 0}, {1, 2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 2);
|
||||
|
||||
out = slice(x, {0, 1}, {4, 4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 15);
|
||||
|
||||
out = slice(x, {1, 2}, {3, 4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 6);
|
||||
|
||||
x = ones({4, 4, 4});
|
||||
out = slice(x, {0, 0, 0}, {4, 2, 2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 54);
|
||||
|
||||
x = ones({4, 4, 4});
|
||||
out = slice(x, {2, 2, 2}, {3, 3, 3});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 1);
|
||||
|
||||
x = ones({4, 4, 4});
|
||||
out = slice(x, {2, 2, 2}, {3, 4, 3});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 5);
|
||||
}
|
||||
|
||||
TEST_CASE("test slice update") {
|
||||
|
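The final expectation starts mid-buffer, which is the subtle case for the new size computation. The arithmetic below hand-checks it against the data_end loop introduced earlier in this diff:

```cpp
#include <cstddef>

// Hand-check (not MLX API): ones({4, 4, 4}) has strides {16, 4, 1};
// slice {2, 2, 2} -> {3, 4, 3} has out shape {1, 2, 1}.
constexpr size_t data_offset = 2 * 16 + 2 * 4 + 2 * 1; // 42
// end_idx = start + out_shape * stride - 1 per axis: {2, 3, 2}
constexpr size_t data_end = 1 + 2 * 16 + 3 * 4 + 2 * 1; // 47
static_assert(data_end - data_offset == 5, "matches CHECK_EQ(..., 5)");
```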