mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Reshape improvement (#818)
This commit is contained in:
parent
5ad133f8bb
commit
29d0c10ee5
@ -468,20 +468,78 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
// Special case for empty arrays
|
||||
if (in.size() == 0) {
|
||||
return {false, out.strides()};
|
||||
}
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
auto [shape, _strides] = collapse_contiguous_dims(in);
|
||||
auto& strides = _strides[0];
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
std::vector<size_t> out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
int N = out.shape(i);
|
||||
if (j < shape.size() && shape[j] % N == 0) {
|
||||
shape[j] /= N;
|
||||
out_strides.push_back(shape[j] * strides[j]);
|
||||
j += (shape[j] == 1);
|
||||
} else if (N == 1) {
|
||||
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
|
||||
out_strides.push_back(out_strides.back());
|
||||
} else {
|
||||
copy_necessary = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous) {
|
||||
// For row contiguous reshapes:
|
||||
// - Shallow copy the buffer
|
||||
// - If reshaping into a vector (all singleton dimensions except one) it
|
||||
// becomes col contiguous again.
|
||||
auto flags = in.flags();
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
|
||||
} else {
|
||||
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -28,4 +28,70 @@ inline size_t elem_to_loc(int elem, const array& a) {
|
||||
return elem_to_loc(elem, a.shape(), a.strides());
|
||||
}
|
||||
|
||||
// Collapse dims that are contiguous to possibly route to a better kernel
|
||||
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
||||
// should return {{2, 4}, {{1, 2}}}.
|
||||
//
|
||||
// When multiple arrays are passed they should all have the same shape. The
|
||||
// collapsed axes are also the same so one shape is returned.
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>> strides) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
std::vector<int> to_collapse;
|
||||
if (shape.size() > 0) {
|
||||
to_collapse.push_back(0);
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
for (const std::vector<size_t>& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1]) {
|
||||
contiguous = false;
|
||||
}
|
||||
if (!contiguous) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!contiguous) {
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
to_collapse.push_back(i);
|
||||
}
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<size_t>> out_strides(strides.size());
|
||||
for (int i = 0; i < to_collapse.size(); i++) {
|
||||
int current_shape = shape[to_collapse[i]];
|
||||
while (to_collapse[++i] != -1) {
|
||||
current_shape *= shape[to_collapse[i]];
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const std::vector<size_t>& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[i - 1]]);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(out_shape, out_strides);
|
||||
}
|
||||
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (auto& x : xs) {
|
||||
strides.emplace_back(x.strides());
|
||||
}
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||
}
|
||||
|
||||
template <typename... Arrays>
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(Arrays... xs) {
|
||||
return collapse_contiguous_dims(
|
||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/compiled_preamble.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
@ -805,13 +805,13 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous) {
|
||||
auto flags = in.flags();
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
|
||||
} else {
|
||||
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,72 +96,6 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
||||
}
|
||||
|
||||
// Collapse dims that are contiguous to possibly route to a better kernel
|
||||
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
||||
// should return {{2, 4}, {{1, 2}}}.
|
||||
//
|
||||
// When multiple arrays are passed they should all have the same shape. The
|
||||
// collapsed axes are also the same so one shape is returned.
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>> strides) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
std::vector<int> to_collapse;
|
||||
if (shape.size() > 0) {
|
||||
to_collapse.push_back(0);
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
for (const std::vector<size_t>& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1]) {
|
||||
contiguous = false;
|
||||
}
|
||||
if (!contiguous) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!contiguous) {
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
to_collapse.push_back(i);
|
||||
}
|
||||
to_collapse.push_back(-1);
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<size_t>> out_strides(strides.size());
|
||||
for (int i = 0; i < to_collapse.size(); i++) {
|
||||
int current_shape = shape[to_collapse[i]];
|
||||
while (to_collapse[++i] != -1) {
|
||||
current_shape *= shape[to_collapse[i]];
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const std::vector<size_t>& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[i - 1]]);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(out_shape, out_strides);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (auto& x : xs) {
|
||||
strides.emplace_back(x.strides());
|
||||
}
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||
}
|
||||
|
||||
template <typename... Arrays>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(Arrays... xs) {
|
||||
return collapse_contiguous_dims(
|
||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1518,9 +1518,11 @@ array min(
|
||||
|
||||
array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
|
||||
int size = a.size();
|
||||
auto result = argmin(reshape(a, {size}, s), 0, false, s);
|
||||
auto result = argmin(reshape(a, {size}, s), 0, true, s);
|
||||
if (keepdims) {
|
||||
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
|
||||
} else {
|
||||
result = squeeze(result, s);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -1549,9 +1551,11 @@ array argmin(
|
||||
|
||||
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
|
||||
int size = a.size();
|
||||
auto result = argmax(reshape(a, {size}, s), 0, false, s);
|
||||
auto result = argmax(reshape(a, {size}, s), 0, true, s);
|
||||
if (keepdims) {
|
||||
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
|
||||
} else {
|
||||
result = squeeze(result, s);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -1359,6 +1359,14 @@ class Reshape : public UnaryPrimitive {
|
||||
std::vector<int> shape_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
|
||||
std::pair<bool, std::vector<size_t>> prepare_reshape(
|
||||
const array& in,
|
||||
const array& out);
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
array& out);
|
||||
};
|
||||
|
||||
class Reduce : public UnaryPrimitive {
|
||||
|
@ -56,6 +56,50 @@ TEST_CASE("test reshape") {
|
||||
CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
|
||||
y = reshape(x, {1, 5, 0});
|
||||
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
|
||||
|
||||
// Check that reshaping a transposed array doesn't result in a copy
|
||||
x = reshape(arange(64), {2, 4, 8});
|
||||
x.eval();
|
||||
CHECK_EQ(x.strides()[0], 32);
|
||||
CHECK_EQ(x.strides()[1], 8);
|
||||
CHECK_EQ(x.strides()[2], 1);
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 4, 2, 4});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 2);
|
||||
CHECK_EQ(y.strides()[2], 1);
|
||||
CHECK_EQ(y.strides()[3], 8);
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
|
||||
// Split transposed (2, 8, 4) -> (2, 8, 2, 2)
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 2});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 1);
|
||||
CHECK_EQ(y.strides()[2], 16);
|
||||
CHECK_EQ(y.strides()[3], 8);
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
|
||||
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2)
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 1);
|
||||
CHECK_EQ(y.strides()[2], 16);
|
||||
// y.strides()[3] can be anything since y.shape()[3] == 1
|
||||
CHECK_EQ(y.strides()[4], 8);
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
|
||||
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2, 1)
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2, 1});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 1);
|
||||
CHECK_EQ(y.strides()[2], 16);
|
||||
// y.strides()[3] can be anything since y.shape()[3] == 1
|
||||
CHECK_EQ(y.strides()[4], 8);
|
||||
// y.strides()[5] can be anything since y.shape()[5] == 1
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
}
|
||||
|
||||
TEST_CASE("test flatten") {
|
||||
|
Loading…
Reference in New Issue
Block a user