mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Make ops aware of rvalues: astype/as_strided/copy/full (#895)
When compositing transforms lots of temporary of arrays will be created and passed to next primitive, and by making ops accepting args by value we can avoid lots of copies of temporary arrays.
This commit is contained in:
58
mlx/ops.cpp
58
mlx/ops.cpp
@@ -158,50 +158,64 @@ array linspace(
|
||||
to_stream(s));
|
||||
}
|
||||
|
||||
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
if (dtype == a.dtype()) {
|
||||
return a;
|
||||
return std::move(a);
|
||||
}
|
||||
auto copied_shape = a.shape(); // |a| will be moved
|
||||
return array(
|
||||
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<AsType>(to_stream(s), dtype),
|
||||
{std::move(a)});
|
||||
}
|
||||
|
||||
array as_strided(
|
||||
const array& a,
|
||||
array a,
|
||||
std::vector<int> shape,
|
||||
std::vector<size_t> strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Force the input array to be contiguous
|
||||
auto x = reshape(a, {-1}, s);
|
||||
auto copied_shape = shape; // |shape| will be moved
|
||||
auto dtype = a.dtype(); // |a| will be moved
|
||||
return array(
|
||||
shape,
|
||||
a.dtype(),
|
||||
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
|
||||
{x});
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<AsStrided>(
|
||||
to_stream(s), std::move(shape), std::move(strides), offset),
|
||||
// Force the input array to be contiguous.
|
||||
{reshape(std::move(a), {-1}, s)});
|
||||
}
|
||||
|
||||
array copy(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
|
||||
array copy(array a, StreamOrDevice s /* = {} */) {
|
||||
auto copied_shape = a.shape(); // |a| will be moved
|
||||
auto dtype = a.dtype();
|
||||
return array(
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<Copy>(to_stream(s)),
|
||||
{std::move(a)});
|
||||
}
|
||||
|
||||
array full(
|
||||
const std::vector<int>& shape,
|
||||
const array& vals,
|
||||
std::vector<int> shape,
|
||||
array vals,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
|
||||
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
|
||||
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
||||
}
|
||||
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
|
||||
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
|
||||
auto copied_shape = shape; // |shape| will be moved
|
||||
return array(
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<Full>(to_stream(s)),
|
||||
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
|
||||
}
|
||||
|
||||
array full(
|
||||
const std::vector<int>& shape,
|
||||
const array& vals,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return full(shape, vals, vals.dtype(), to_stream(s));
|
||||
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = vals.dtype(); // |vals| will be moved
|
||||
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
|
||||
}
|
||||
|
||||
array zeros(
|
||||
|
||||
Reference in New Issue
Block a user