Make reshape faster for row_contiguous cases (#829)

This commit is contained in:
Angelos Katharopoulos 2024-03-13 16:22:03 -07:00 committed by GitHub
parent 76c919b4ec
commit 3f8b1668c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -527,8 +527,8 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays
if (in.size() == 0) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
@ -570,18 +570,13 @@ void Reshape::shared_buffer_reshape(
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}