More buffer donation with no-ops (#1591)

* more donation

* fix test

* fix build
This commit is contained in:
Awni Hannun
2024-11-18 08:35:41 -08:00
committed by GitHub
parent 6931f84412
commit 9bd03dd9b4
7 changed files with 82 additions and 13 deletions

View File

@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
return move_or_copy(in, out, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
move_or_copy(in, out, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
move_or_copy(inputs[0], out);
}
void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
move_or_copy(inputs[j], outputs[i]);
}
}
@@ -81,7 +81,7 @@ void Depends::eval(
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
move_or_copy(inputs[i], outputs[i]);
}
}
@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
move_or_copy(in, out, out_strides, flags, in.data_size());
}
void Split::eval(
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
move_or_copy(inputs[0], out);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
move_or_copy(in, out, out_strides, flags, in.data_size());
}
} // namespace mlx::core