WIP (common)

This commit is contained in:
Ronan Collobert
2025-10-29 16:51:42 -07:00
parent 53525cba23
commit cacc3ab7fd
4 changed files with 22 additions and 21 deletions

View File

@@ -21,8 +21,8 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// Compute the flags given the shape and strides // Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true; bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1; int64_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) { for (int i = std::ssize(strides_) - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1); row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1); col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i]; r *= shape_[i];
@@ -60,7 +60,8 @@ void CustomTransforms::eval(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); for (int i = 0, j = std::ssize(inputs) - std::ssize(outputs);
i < std::ssize(outputs);
i++, j++) { i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]); outputs[i].copy_shared_buffer(inputs[j]);
} }
@@ -70,7 +71,7 @@ void Depends::eval(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) { for (int i = 0; i < std::ssize(outputs); i++) {
outputs[i].copy_shared_buffer(inputs[i]); outputs[i].copy_shared_buffer(inputs[i]);
} }
} }
@@ -206,11 +207,11 @@ void Split::eval(
auto compute_new_flags = [](const auto& shape, auto compute_new_flags = [](const auto& shape,
const auto& strides, const auto& strides,
size_t in_data_size, int64_t in_data_size,
auto flags) { auto flags) {
size_t data_size = 1; int64_t data_size = 1;
size_t f_stride = 1; int64_t f_stride = 1;
size_t b_stride = 1; int64_t b_stride = 1;
flags.row_contiguous = true; flags.row_contiguous = true;
flags.col_contiguous = true; flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
@@ -240,7 +241,7 @@ void Split::eval(
std::vector<int> indices(1, 0); std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end()); indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) { for (int i = 0; i < std::ssize(indices); i++) {
size_t offset = indices[i] * in.strides()[axis_]; size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags( auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags()); outputs[i].shape(), in.strides(), in.data_size(), in.flags());
@@ -254,7 +255,7 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
Strides strides; Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) { for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < axes_.size() && i == axes_[j]) { if (j < std::ssize(axes_) && i == axes_[j]) {
j++; j++;
} else { } else {
strides.push_back(in.strides(i)); strides.push_back(in.strides(i));
@@ -272,7 +273,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
Strides out_strides(out.ndim()); Strides out_strides(out.ndim());
auto& in = inputs[0]; auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) { for (int ax = 0; ax < std::ssize(axes_); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]]; out_strides[ax] = in.strides()[axes_[ax]];
} }

View File

@@ -120,7 +120,7 @@ void compiled_allocate_outputs(
Strides strides; Strides strides;
size_t data_size; size_t data_size;
array::Flags flags; array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
auto& in = inputs[i]; auto& in = inputs[i];
// Conditions for donation // Conditions for donation
// - Correct size // - Correct size
@@ -138,7 +138,7 @@ void compiled_allocate_outputs(
data_size = in.data_size(); data_size = in.data_size();
} }
} }
for (; o < outputs.size(); ++o) { for (; o < std::ssize(outputs); ++o) {
outputs[o].set_data( outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()), allocator::malloc(data_size * outputs[o].itemsize()),
data_size, data_size,
@@ -147,7 +147,7 @@ void compiled_allocate_outputs(
} }
} else { } else {
int o = 0; int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
auto& in = inputs[i]; auto& in = inputs[i];
// Conditions for donation // Conditions for donation
// - Row contiguous // - Row contiguous
@@ -162,7 +162,7 @@ void compiled_allocate_outputs(
o++; o++;
} }
} }
for (; o < outputs.size(); ++o) { for (; o < std::ssize(outputs); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes())); outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
} }
} }
@@ -193,7 +193,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
// Broadcast the inputs to the output shape. // Broadcast the inputs to the output shape.
Strides xstrides; Strides xstrides;
size_t j = 0; int j = 0;
for (; j < shape.size() - x.ndim(); ++j) { for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) { if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]); xstrides.push_back(out.strides()[j]);
@@ -201,7 +201,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
xstrides.push_back(0); xstrides.push_back(0);
} }
} }
for (size_t i = 0; i < x.ndim(); ++i, ++j) { for (int i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) { if (x.shape(i) == 1) {
if (shape[j] == 1) { if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]); xstrides.push_back(out.strides()[j]);
@@ -224,13 +224,13 @@ bool compiled_use_large_index(
const std::vector<array>& outputs, const std::vector<array>& outputs,
bool contiguous) { bool contiguous) {
if (contiguous) { if (contiguous) {
size_t max_size = 0; int64_t max_size = 0;
for (const auto& in : inputs) { for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size()); max_size = std::max(max_size, in.data_size());
} }
return max_size > UINT32_MAX; return max_size > UINT32_MAX;
} else { } else {
size_t max_size = 0; int64_t max_size = 0;
for (const auto& o : outputs) { for (const auto& o : outputs) {
max_size = std::max(max_size, o.size()); max_size = std::max(max_size, o.size());
} }

View File

@@ -27,7 +27,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core { namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) { void Load::eval_cpu(const std::vector<array>& /* inputs */, array& out) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(), auto read_task = [out_ptr = out.data<char>(),
size = out.size(), size = out.size(),

View File

@@ -183,7 +183,7 @@ inline auto check_contiguity(const Shape& shape, const Strides& strides) {
} }
inline bool is_donatable(const array& in, const array& out) { inline bool is_donatable(const array& in, const array& out) {
constexpr size_t donation_extra = 16384; constexpr int64_t donation_extra = 16384;
return in.is_donatable() && in.itemsize() == out.itemsize() && return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra; in.buffer_size() <= out.nbytes() + donation_extra;