Reshape improvement (#818)

Angelos Katharopoulos 2024-03-12 17:54:31 -07:00 committed by GitHub
parent 5ad133f8bb
commit 29d0c10ee5
8 changed files with 199 additions and 84 deletions


@@ -468,20 +468,78 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
  }
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays
if (in.size() == 0) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// First, collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If the output shape fits exactly within the collapsed contiguous dims then
// no copy is necessary, so check for that.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// Here i > 0 is guaranteed: otherwise j < shape.size() && shape[j] % 1 == 0
// would have matched above, so out_strides.back() is safe to read.
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
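
// A minimal standalone sketch of the stride-matching step above, using plain
// vectors in place of mlx arrays (match_strides and its parameter names are
// illustrative, not part of the MLX API; assumes <vector> and <utility> are
// included). It takes the collapsed shape and strides that
// collapse_contiguous_dims produces, and peels factors of each output dim off
// the front of the current collapsed dim, exactly like the loop in
// prepare_reshape.
std::pair<bool, std::vector<size_t>> match_strides(
    std::vector<int> in_shape, // collapsed shape, consumed as factors peel off
    const std::vector<size_t>& in_strides,
    const std::vector<int>& out_shape) {
  std::vector<size_t> out_strides;
  int j = 0;
  for (int i = 0; i < static_cast<int>(out_shape.size()); i++) {
    int N = out_shape[i];
    if (j < static_cast<int>(in_shape.size()) && in_shape[j] % N == 0) {
      in_shape[j] /= N; // peel a factor of N off collapsed dim j
      out_strides.push_back(in_shape[j] * in_strides[j]);
      j += (in_shape[j] == 1); // advance once dim j is fully consumed
    } else if (N == 1) {
      out_strides.push_back(out_strides.back()); // singletons take any stride
    } else {
      return {true, {}}; // out dim straddles a discontiguity: copy needed
    }
  }
  return {false, out_strides};
}
// For example, a transposed (2, 8, 4) view with strides {32, 1, 8} reshaped
// to {2, 8, 2, 2} yields strides {32, 1, 16, 8} with no copy, matching the
// tests added below.
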
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
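
// A minimal sketch of the row-contiguity half of the flag computation above
// (is_row_contiguous is an illustrative helper, not part of the MLX API;
// assumes <vector> is included): scanning dims from the back, every
// non-singleton dim must have a stride equal to the product of the dims
// behind it. The col_contiguous test in the loop above is the mirror image,
// scanning from the front.
bool is_row_contiguous(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t expected = 1; // the stride a row-contiguous dim should have
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    if (shape[i] != 1 && strides[i] != expected) {
      return false;
    }
    expected *= shape[i];
  }
  return true;
}
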
 void Reshape::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  if (in.flags().row_contiguous) {
-    // For row contiguous reshapes:
-    // - Shallow copy the buffer
-    // - If reshaping into a vector (all singleton dimensions except one) it
-    //   becomes col contiguous again.
-    auto flags = in.flags();
-    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
-    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
-    out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
-  } else {
-    copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
-  }
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+
+  if (copy_necessary) {
+    copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
 }


@@ -28,4 +28,70 @@ inline size_t elem_to_loc(int elem, const array& a) {
  return elem_to_loc(elem, a.shape(), a.strides());
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<size_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<size_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
} // namespace mlx::core


@@ -3,6 +3,7 @@
 #include <sstream>

 #include "mlx/backend/common/compiled.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/compiled_preamble.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/utils.h"


@@ -805,13 +805,13 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  if (in.flags().row_contiguous) {
-    auto flags = in.flags();
-    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
-    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
-    out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
-  } else {
-    copy_gpu(in, out, CopyType::General);
-  }
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+
+  if (copy_necessary) {
+    copy_gpu(in, out, CopyType::General);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
 }


@@ -96,72 +96,6 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<size_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<size_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
} // namespace
} // namespace mlx::core


@@ -1518,9 +1518,11 @@ array min(
 array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmin(reshape(a, {size}, s), 0, false, s);
+  auto result = argmin(reshape(a, {size}, s), 0, true, s);
   if (keepdims) {
     result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
+  } else {
+    result = squeeze(result, s);
   }
   return result;
 }
@@ -1549,9 +1551,11 @@ array argmin(
 array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmax(reshape(a, {size}, s), 0, false, s);
+  auto result = argmax(reshape(a, {size}, s), 0, true, s);
   if (keepdims) {
     result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
+  } else {
+    result = squeeze(result, s);
   }
   return result;
 }
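
// A minimal usage sketch of the resulting behavior (check_argmax is an
// illustrative helper, not part of the codebase): the full reduction keeps
// one singleton axis per input dim when keepdims is set, and is squeezed to
// a scalar otherwise.
void check_argmax() {
  auto a = reshape(arange(6), {2, 3});
  auto scalar_idx = argmax(a, /* keepdims */ false); // shape {}
  auto kept_idx = argmax(a, /* keepdims */ true); // shape {1, 1}
}
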


@@ -1359,6 +1359,14 @@ class Reshape : public UnaryPrimitive {
  std::vector<int> shape_;

  void eval(const std::vector<array>& inputs, array& out);
std::pair<bool, std::vector<size_t>> prepare_reshape(
const array& in,
const array& out);
void shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out);
};

class Reduce : public UnaryPrimitive {


@@ -56,6 +56,50 @@ TEST_CASE("test reshape") {
  CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
  y = reshape(x, {1, 5, 0});
  CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
// Check that reshaping a transposed array doesn't result in a copy
x = reshape(arange(64), {2, 4, 8});
x.eval();
CHECK_EQ(x.strides()[0], 32);
CHECK_EQ(x.strides()[1], 8);
CHECK_EQ(x.strides()[2], 1);
y = reshape(transpose(x, {0, 2, 1}), {2, 4, 2, 4});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 2);
CHECK_EQ(y.strides()[2], 1);
CHECK_EQ(y.strides()[3], 8);
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
// Split transposed (2, 8, 4) -> (2, 8, 2, 2)
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 2});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 1);
CHECK_EQ(y.strides()[2], 16);
CHECK_EQ(y.strides()[3], 8);
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2)
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 1);
CHECK_EQ(y.strides()[2], 16);
// y.strides()[3] can be anything since y.shape()[3] == 1
CHECK_EQ(y.strides()[4], 8);
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2, 1)
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2, 1});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 1);
CHECK_EQ(y.strides()[2], 16);
// y.strides()[3] can be anything since y.shape()[3] == 1
CHECK_EQ(y.strides()[4], 8);
// y.strides()[5] can be anything since y.shape()[5] == 1
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
}

TEST_CASE("test flatten") {