Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
@@ -141,7 +141,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
std::pair<bool, Strides> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
@@ -151,8 +151,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
return {false, Strides(out.ndim(), 0)};
}
// Firstly let's collapse all the contiguous dimensions of the input
@@ -160,7 +159,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
Strides out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
@@ -183,7 +182,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
@@ -249,18 +248,6 @@ void Split::eval(
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
@@ -268,7 +255,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
Strides out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
@@ -285,8 +272,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {