mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 05:31:18 +08:00
Make ops aware of rvalues: astype/as_strided/copy/full (#895)
When compositing transforms lots of temporary of arrays will be created and passed to next primitive, and by making ops accepting args by value we can avoid lots of copies of temporary arrays.
This commit is contained in:
parent
aca7584635
commit
bab5386306
58
mlx/ops.cpp
58
mlx/ops.cpp
@ -158,50 +158,64 @@ array linspace(
|
|||||||
to_stream(s));
|
to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
if (dtype == a.dtype()) {
|
if (dtype == a.dtype()) {
|
||||||
return a;
|
return std::move(a);
|
||||||
}
|
}
|
||||||
|
auto copied_shape = a.shape(); // |a| will be moved
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
|
std::move(copied_shape),
|
||||||
|
dtype,
|
||||||
|
std::make_shared<AsType>(to_stream(s), dtype),
|
||||||
|
{std::move(a)});
|
||||||
}
|
}
|
||||||
|
|
||||||
array as_strided(
|
array as_strided(
|
||||||
const array& a,
|
array a,
|
||||||
std::vector<int> shape,
|
std::vector<int> shape,
|
||||||
std::vector<size_t> strides,
|
std::vector<size_t> strides,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
// Force the input array to be contiguous
|
auto copied_shape = shape; // |shape| will be moved
|
||||||
auto x = reshape(a, {-1}, s);
|
auto dtype = a.dtype(); // |a| will be moved
|
||||||
return array(
|
return array(
|
||||||
shape,
|
std::move(copied_shape),
|
||||||
a.dtype(),
|
dtype,
|
||||||
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
|
std::make_shared<AsStrided>(
|
||||||
{x});
|
to_stream(s), std::move(shape), std::move(strides), offset),
|
||||||
|
// Force the input array to be contiguous.
|
||||||
|
{reshape(std::move(a), {-1}, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
array copy(const array& a, StreamOrDevice s /* = {} */) {
|
array copy(array a, StreamOrDevice s /* = {} */) {
|
||||||
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
|
auto copied_shape = a.shape(); // |a| will be moved
|
||||||
|
auto dtype = a.dtype();
|
||||||
|
return array(
|
||||||
|
std::move(copied_shape),
|
||||||
|
dtype,
|
||||||
|
std::make_shared<Copy>(to_stream(s)),
|
||||||
|
{std::move(a)});
|
||||||
}
|
}
|
||||||
|
|
||||||
array full(
|
array full(
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
const array& vals,
|
array vals,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
|
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
|
||||||
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
||||||
}
|
}
|
||||||
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
|
auto copied_shape = shape; // |shape| will be moved
|
||||||
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
|
return array(
|
||||||
|
std::move(copied_shape),
|
||||||
|
dtype,
|
||||||
|
std::make_shared<Full>(to_stream(s)),
|
||||||
|
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
array full(
|
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
|
||||||
const std::vector<int>& shape,
|
auto dtype = vals.dtype(); // |vals| will be moved
|
||||||
const array& vals,
|
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
|
||||||
StreamOrDevice s /* = {} */) {
|
|
||||||
return full(shape, vals, vals.dtype(), to_stream(s));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array zeros(
|
array zeros(
|
||||||
|
27
mlx/ops.h
27
mlx/ops.h
@ -41,40 +41,33 @@ array linspace(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Convert an array to the given data type. */
|
/** Convert an array to the given data type. */
|
||||||
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
|
array astype(array a, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Create a view of an array with the given shape and strides. */
|
/** Create a view of an array with the given shape and strides. */
|
||||||
array as_strided(
|
array as_strided(
|
||||||
const array& a,
|
array a,
|
||||||
std::vector<int> shape,
|
std::vector<int> shape,
|
||||||
std::vector<size_t> strides,
|
std::vector<size_t> strides,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Copy another array. */
|
/** Copy another array. */
|
||||||
array copy(const array& a, StreamOrDevice s = {});
|
array copy(array a, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Fill an array of the given shape with the given value(s). */
|
/** Fill an array of the given shape with the given value(s). */
|
||||||
array full(
|
array full(
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
const array& vals,
|
array vals,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
array full(
|
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
|
||||||
const std::vector<int>& shape,
|
|
||||||
const array& vals,
|
|
||||||
StreamOrDevice s = {});
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array full(
|
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||||
const std::vector<int>& shape,
|
return full(std::move(shape), array(val, dtype), to_stream(s));
|
||||||
T val,
|
|
||||||
Dtype dtype,
|
|
||||||
StreamOrDevice s = {}) {
|
|
||||||
return full(shape, array(val, dtype), to_stream(s));
|
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
|
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
|
||||||
return full(shape, array(val), to_stream(s));
|
return full(std::move(shape), array(val), to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Fill an array of the given shape with zeros. */
|
/** Fill an array of the given shape with zeros. */
|
||||||
|
@ -420,12 +420,12 @@ class AsStrided : public UnaryPrimitive {
|
|||||||
public:
|
public:
|
||||||
explicit AsStrided(
|
explicit AsStrided(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
const std::vector<size_t>& strides,
|
std::vector<size_t> strides,
|
||||||
size_t offset)
|
size_t offset)
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
shape_(shape),
|
shape_(std::move(shape)),
|
||||||
strides_(strides),
|
strides_(std::move(strides)),
|
||||||
offset_(offset){};
|
offset_(offset){};
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
Loading…
Reference in New Issue
Block a user