mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Reduce implicit copies in make_arrays (#874)
1. Move the shapes into the outputs instead of copying them. 2. Pass the primitive by const reference: it is always copied into each output anyway, so taking it by value only added an extra copy when calling make_arrays.
This commit is contained in:
parent
44390bd3d0
commit
f0ae00da12
@ -48,15 +48,16 @@ array::array(
|
||||
std::move(inputs))) {}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
std::vector<std::vector<int>> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<array> outputs;
|
||||
for (int i = 0; i < shapes.size(); ++i) {
|
||||
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
|
||||
}
|
||||
for (int i = 0; i < outputs.size(); ++i) {
|
||||
// For each node in |outputs|, its siblings are the other nodes.
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
auto siblings = outputs;
|
||||
siblings.erase(siblings.begin() + i);
|
||||
outputs[i].set_siblings(std::move(siblings), i);
|
||||
|
@ -180,9 +180,9 @@ class array {
|
||||
std::vector<array> inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
std::vector<std::vector<int>> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
|
@ -645,7 +645,7 @@ void compile_fuse(
|
||||
}
|
||||
}
|
||||
auto compiled_outputs = array::make_arrays(
|
||||
shapes,
|
||||
std::move(shapes),
|
||||
types,
|
||||
std::make_shared<Compiled>(
|
||||
old_outputs.back().primitive().stream(),
|
||||
@ -738,8 +738,8 @@ std::vector<array> compile_replace(
|
||||
shapes.push_back(o.shape());
|
||||
}
|
||||
}
|
||||
auto real_out =
|
||||
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
||||
auto real_out = array::make_arrays(
|
||||
std::move(shapes), types, a.primitive_ptr(), real_inputs);
|
||||
for (int i = 0; i < trace_out.size(); ++i) {
|
||||
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
|
||||
}
|
||||
|
@ -614,7 +614,7 @@ std::vector<array> split(
|
||||
shapes.back()[ax] = a.shape(ax) - indices.back();
|
||||
|
||||
return array::make_arrays(
|
||||
shapes,
|
||||
std::move(shapes),
|
||||
dtypes,
|
||||
std::make_shared<Split>(to_stream(s), indices, ax),
|
||||
{a});
|
||||
@ -1949,7 +1949,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
return array::make_arrays(
|
||||
{inputs[0].shape(), inputs[0].shape()},
|
||||
{inputs[0].dtype(), inputs[0].dtype()},
|
||||
std::make_unique<DivMod>(to_stream(s)),
|
||||
std::make_shared<DivMod>(to_stream(s)),
|
||||
inputs);
|
||||
}
|
||||
|
||||
@ -3531,7 +3531,10 @@ std::vector<array> depends(
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
|
||||
std::move(shapes),
|
||||
dtypes,
|
||||
std::make_shared<Depends>(to_stream(s)),
|
||||
all_inputs);
|
||||
}
|
||||
|
||||
array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
@ -782,7 +782,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
shapes,
|
||||
std::move(shapes),
|
||||
dtypes,
|
||||
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
|
||||
inputs);
|
||||
|
Loading…
Reference in New Issue
Block a user