Reduce implicit copies in make_arrays (#874)

1. Take shapes by value and move each shape into its output array instead
   of copying it.
2. Pass primitive by const reference: it is always copied into each output,
   so taking it by value only added an extra copy when calling make_arrays.
This commit is contained in:
Cheng 2024-03-22 22:29:16 +09:00 committed by GitHub
parent 44390bd3d0
commit f0ae00da12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 18 additions and 14 deletions

View File

@ -48,15 +48,16 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
const std::vector<std::vector<int>>& shapes,
std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (int i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
}
for (int i = 0; i < outputs.size(); ++i) {
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);

View File

@ -180,9 +180,9 @@ class array {
std::vector<array> inputs);
static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes,
std::vector<std::vector<int>> shapes,
const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */

View File

@ -645,7 +645,7 @@ void compile_fuse(
}
}
auto compiled_outputs = array::make_arrays(
shapes,
std::move(shapes),
types,
std::make_shared<Compiled>(
old_outputs.back().primitive().stream(),
@ -738,8 +738,8 @@ std::vector<array> compile_replace(
shapes.push_back(o.shape());
}
}
auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
auto real_out = array::make_arrays(
std::move(shapes), types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) {
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
}

View File

@ -614,7 +614,7 @@ std::vector<array> split(
shapes.back()[ax] = a.shape(ax) - indices.back();
return array::make_arrays(
shapes,
std::move(shapes),
dtypes,
std::make_shared<Split>(to_stream(s), indices, ax),
{a});
@ -1949,7 +1949,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{inputs[0].shape(), inputs[0].shape()},
{inputs[0].dtype(), inputs[0].dtype()},
std::make_unique<DivMod>(to_stream(s)),
std::make_shared<DivMod>(to_stream(s)),
inputs);
}
@ -3531,7 +3531,10 @@ std::vector<array> depends(
}
return array::make_arrays(
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
std::move(shapes),
dtypes,
std::make_shared<Depends>(to_stream(s)),
all_inputs);
}
array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {

View File

@ -782,7 +782,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
}
return array::make_arrays(
shapes,
std::move(shapes),
dtypes,
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
inputs);