Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Use int64 stride everywhere (#1671)
* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more
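The single change running through every hunk below is the type of the stride containers: unsigned `std::vector<size_t>` becomes a signed 64-bit `Strides` type. A minimal sketch of what such an alias amounts to, assuming a plain typedef (the actual definition lives elsewhere in the tree and is not part of this diff):

#include <cstdint>
#include <vector>

// Assumed alias, not shown in this diff: signed 64-bit strides can express
// reversed views (negative steps), which unsigned size_t cannot, and they
// avoid silent wrap-around in mixed signed/unsigned stride arithmetic.
using Strides = std::vector<int64_t>;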
@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
     out.set_data(nullptr);
     return;
   }
-  std::vector<size_t> strides(out.ndim(), 0);
+  Strides strides(out.ndim(), 0);
   int diff = out.ndim() - in.ndim();
   for (int i = in.ndim() - 1; i >= 0; --i) {
     strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
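As a self-contained illustration of the broadcast rule in this hunk (a hypothetical free function; the real code works on `array` members):

#include <cstdint>
#include <vector>

using Strides = std::vector<int64_t>;

// Output dims are right-aligned with input dims. New leading dims and
// size-1 input dims get stride 0, so every output index along those axes
// reads the same input element.
Strides broadcast_strides(
    const std::vector<int>& in_shape,
    const Strides& in_strides,
    int out_ndim) {
  Strides strides(out_ndim, 0);
  int diff = out_ndim - static_cast<int>(in_shape.size());
  for (int i = static_cast<int>(in_shape.size()) - 1; i >= 0; --i) {
    strides[i + diff] = (in_shape[i] == 1) ? 0 : in_strides[i];
  }
  return strides;
}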
@@ -141,7 +141,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
   }
 }

-std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
+std::pair<bool, Strides> Reshape::prepare_reshape(
     const array& in,
     const array& out) {
   // Special case for empty arrays or row contiguous arrays
@@ -151,8 +151,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(

   // Special case for scalars
   if (in.ndim() == 0) {
-    std::vector<size_t> out_strides(out.ndim(), 0);
-    return {false, out_strides};
+    return {false, Strides(out.ndim(), 0)};
   }

   // Firstly let's collapse all the contiguous dimensions of the input
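The dimension-collapsing step named in the comment above can be sketched as a standalone helper (simplified and hypothetical; unlike the real implementation it does not special-case size-1 dims):

#include <cstdint>
#include <utility>
#include <vector>

using Strides = std::vector<int64_t>;

// Merge adjacent dims (i-1, i) whenever they are jointly contiguous,
// i.e. stride[i-1] == stride[i] * shape[i].
std::pair<std::vector<int>, Strides> collapse_contiguous_dims(
    const std::vector<int>& shape,
    const Strides& strides) {
  std::vector<int> c_shape;
  Strides c_strides;
  if (shape.empty()) {
    return {c_shape, c_strides};
  }
  c_shape.push_back(shape[0]);
  c_strides.push_back(strides[0]);
  for (size_t i = 1; i < shape.size(); ++i) {
    if (c_strides.back() == strides[i] * shape[i]) {
      c_shape.back() *= shape[i]; // fold dim i into the previous run
      c_strides.back() = strides[i];
    } else {
      c_shape.push_back(shape[i]);
      c_strides.push_back(strides[i]);
    }
  }
  return {c_shape, c_strides};
}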
@@ -160,7 +159,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(

   // If shapes fit exactly in the contiguous dims then no copy is necessary so
   // let's check.
-  std::vector<size_t> out_strides;
+  Strides out_strides;
   bool copy_necessary = false;
   int j = 0;
   for (int i = 0; i < out.ndim(); i++) {
@@ -183,7 +182,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(

 void Reshape::shared_buffer_reshape(
     const array& in,
-    const std::vector<size_t>& out_strides,
+    const Strides& out_strides,
     array& out) {
   auto flags = in.flags();
   if (flags.row_contiguous) {
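When `copy_necessary` stays false, `shared_buffer_reshape` reinterprets the same buffer under new strides. For a row-contiguous input those are simply the row-major strides of the output shape, sketched here as a hypothetical helper:

#include <cstdint>
#include <vector>

using Strides = std::vector<int64_t>;

// Row-major (C-order) strides: the last dim has stride 1 and each earlier
// dim strides over everything to its right.
Strides row_major_strides(const std::vector<int>& shape) {
  Strides strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  return strides;
}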
@@ -249,18 +248,6 @@ void Split::eval(
   }
 }

-std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
-    const array& in) {
-  int64_t data_offset = 0;
-  std::vector<int64_t> inp_strides(in.ndim(), 0);
-  for (int i = 0; i < in.ndim(); ++i) {
-    data_offset += start_indices_[i] * in.strides()[i];
-    inp_strides[i] = in.strides()[i] * strides_[i];
-  }
-
-  return std::make_tuple(data_offset, inp_strides);
-}
-
 void StopGradient::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   move_or_copy(inputs[0], out);
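The deleted `SliceUpdate::prepare_slice` has presumably been consolidated into a shared helper not visible in this excerpt. Its arithmetic, restated as a standalone sketch, also shows why strides need to be signed: a negative step yields a negative stride:

#include <cstdint>
#include <tuple>
#include <vector>

using Strides = std::vector<int64_t>;

// The view starts data_offset elements into the parent buffer; stepping by
// steps[i] along axis i scales the parent stride by steps[i].
std::tuple<int64_t, Strides> prepare_slice(
    const Strides& in_strides,
    const std::vector<int>& starts,
    const std::vector<int>& steps) {
  int64_t data_offset = 0;
  Strides out_strides(in_strides.size(), 0);
  for (size_t i = 0; i < in_strides.size(); ++i) {
    data_offset += static_cast<int64_t>(starts[i]) * in_strides[i];
    out_strides[i] = in_strides[i] * steps[i];
  }
  return {data_offset, out_strides};
}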
@@ -268,7 +255,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {

 void Transpose::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  std::vector<size_t> out_strides(out.ndim());
+  Strides out_strides(out.ndim());
   auto& in = inputs[0];
   for (int ax = 0; ax < axes_.size(); ++ax) {
     out_strides[ax] = in.strides()[axes_[ax]];
@@ -285,8 +272,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
   // true, they stay true)
   auto flags = in.flags();
   if (flags.contiguous && in.data_size() == in.size()) {
-    size_t f_stride = 1;
-    size_t b_stride = 1;
+    int64_t f_stride = 1;
+    int64_t b_stride = 1;
     flags.col_contiguous = true;
     flags.row_contiguous = true;
     for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
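In the last hunk, `f_stride` and `b_stride` accumulate the strides a row-major and a column-major layout would have, so the contiguity flags can be rechecked after the permutation. A hypothetical standalone version of that check:

#include <cstdint>
#include <vector>

using Strides = std::vector<int64_t>;

// A view is row contiguous iff its strides match the row-major strides of
// its shape (f_stride, accumulated from the back) and column contiguous iff
// they match the column-major ones (b_stride, accumulated from the front).
// Size-1 dims impose no constraint.
void check_contiguity(
    const std::vector<int>& shape,
    const Strides& strides,
    bool& row_contiguous,
    bool& col_contiguous) {
  int64_t f_stride = 1;
  int64_t b_stride = 1;
  row_contiguous = true;
  col_contiguous = true;
  int ndim = static_cast<int>(shape.size());
  for (int i = 0, ri = ndim - 1; i < ndim; ++i, --ri) {
    row_contiguous &= (strides[ri] == f_stride) || (shape[ri] == 1);
    col_contiguous &= (strides[i] == b_stride) || (shape[i] == 1);
    f_stride *= shape[ri];
    b_stride *= shape[i];
  }
}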