Make ops aware of rvalues: astype/as_strided/copy/full (#895)

When compositing transforms lots of temporary of arrays will be created
and passed to next primitive, and by making ops accepting args by value
we can avoid lots of copies of temporary arrays.
This commit is contained in:
Cheng 2024-03-28 14:35:55 +09:00 committed by GitHub
parent aca7584635
commit bab5386306
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 43 deletions

View File

@ -158,50 +158,64 @@ array linspace(
to_stream(s));
}
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
if (dtype == a.dtype()) {
return a;
return std::move(a);
}
auto copied_shape = a.shape(); // |a| will be moved
return array(
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
std::move(copied_shape),
dtype,
std::make_shared<AsType>(to_stream(s), dtype),
{std::move(a)});
}
array as_strided(
const array& a,
array a,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset,
StreamOrDevice s /* = {} */) {
// Force the input array to be contiguous
auto x = reshape(a, {-1}, s);
auto copied_shape = shape; // |shape| will be moved
auto dtype = a.dtype(); // |a| will be moved
return array(
shape,
a.dtype(),
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
{x});
std::move(copied_shape),
dtype,
std::make_shared<AsStrided>(
to_stream(s), std::move(shape), std::move(strides), offset),
// Force the input array to be contiguous.
{reshape(std::move(a), {-1}, s)});
}
array copy(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
array copy(array a, StreamOrDevice s /* = {} */) {
auto copied_shape = a.shape(); // |a| will be moved
auto dtype = a.dtype();
return array(
std::move(copied_shape),
dtype,
std::make_shared<Copy>(to_stream(s)),
{std::move(a)});
}
array full(
const std::vector<int>& shape,
const array& vals,
std::vector<int> shape,
array vals,
Dtype dtype,
StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
throw std::invalid_argument("[full] Negative dimensions not allowed.");
}
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
auto copied_shape = shape; // |shape| will be moved
return array(
std::move(copied_shape),
dtype,
std::make_shared<Full>(to_stream(s)),
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
}
array full(
const std::vector<int>& shape,
const array& vals,
StreamOrDevice s /* = {} */) {
return full(shape, vals, vals.dtype(), to_stream(s));
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
auto dtype = vals.dtype(); // |vals| will be moved
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
}
array zeros(

View File

@ -41,40 +41,33 @@ array linspace(
StreamOrDevice s = {});
/** Convert an array to the given data type. */
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
array astype(array a, Dtype dtype, StreamOrDevice s = {});
/** Create a view of an array with the given shape and strides. */
array as_strided(
const array& a,
array a,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset,
StreamOrDevice s = {});
/** Copy another array. */
array copy(const array& a, StreamOrDevice s = {});
array copy(array a, StreamOrDevice s = {});
/** Fill an array of the given shape with the given value(s). */
array full(
const std::vector<int>& shape,
const array& vals,
std::vector<int> shape,
array vals,
Dtype dtype,
StreamOrDevice s = {});
array full(
const std::vector<int>& shape,
const array& vals,
StreamOrDevice s = {});
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
template <typename T>
array full(
const std::vector<int>& shape,
T val,
Dtype dtype,
StreamOrDevice s = {}) {
return full(shape, array(val, dtype), to_stream(s));
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
return full(std::move(shape), array(val, dtype), to_stream(s));
}
template <typename T>
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
return full(shape, array(val), to_stream(s));
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
return full(std::move(shape), array(val), to_stream(s));
}
/** Fill an array of the given shape with zeros. */

View File

@ -420,12 +420,12 @@ class AsStrided : public UnaryPrimitive {
public:
explicit AsStrided(
Stream stream,
const std::vector<int>& shape,
const std::vector<size_t>& strides,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset)
: UnaryPrimitive(stream),
shape_(shape),
strides_(strides),
shape_(std::move(shape)),
strides_(std::move(strides)),
offset_(offset){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;