mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 03:06:39 +08:00
Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can unify array's constructors to just take the arguments by value. There are 2 cases: 1. When shape is a lvalue, it will be copied into array's constructor and then moved into ArrayDesc's member. So only 1 copy happens. 2. When shape is a rvalue, it will be moved into array's constructor and then moved into ArrayDesc's member. So no copy happens. So having 1 constructor that takes by value is equivalent to having 2 constructors that const reference and rvalue separately.
This commit is contained in:
parent
db6796ac61
commit
73a8c090e0
@ -36,22 +36,11 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
init(&cval);
|
||||
}
|
||||
|
||||
array::array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
shape,
|
||||
dtype,
|
||||
std::move(primitive),
|
||||
inputs)) {}
|
||||
|
||||
array::array(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs)
|
||||
std::vector<array> inputs)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::move(shape),
|
||||
dtype,
|
||||
@ -92,10 +81,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
set_data(data, deleter);
|
||||
}
|
||||
|
||||
@ -181,31 +170,16 @@ void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||
: shape(shape), dtype(dtype) {
|
||||
std::tie(size, strides) = cum_prod(shape);
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs)
|
||||
: shape(shape),
|
||||
dtype(dtype),
|
||||
primitive(std::move(primitive)),
|
||||
inputs(inputs) {
|
||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||
: shape(std::move(shape)), dtype(dtype) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
std::vector<int>&& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs)
|
||||
std::vector<array> inputs)
|
||||
: shape(std::move(shape)),
|
||||
dtype(dtype),
|
||||
primitive(std::move(primitive)),
|
||||
|
34
mlx/array.h
34
mlx/array.h
@ -32,7 +32,7 @@ class array {
|
||||
template <typename It>
|
||||
array(
|
||||
It data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype =
|
||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||
|
||||
@ -48,13 +48,13 @@ class array {
|
||||
template <typename T>
|
||||
array(
|
||||
std::initializer_list<T> data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a buffer */
|
||||
array(
|
||||
allocator::Buffer data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter = allocator::free);
|
||||
|
||||
@ -173,17 +173,11 @@ class array {
|
||||
* API may change.
|
||||
*/
|
||||
|
||||
array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
array(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
std::vector<array> inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
@ -391,19 +385,13 @@ class array {
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
explicit ArrayDesc(
|
||||
std::vector<int>&& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
std::vector<array> inputs);
|
||||
};
|
||||
|
||||
// The ArrayDesc contains the details of the materialized array including the
|
||||
@ -422,9 +410,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
template <typename It>
|
||||
array::array(
|
||||
It data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
init(data);
|
||||
}
|
||||
|
||||
@ -441,9 +429,9 @@ array::array(
|
||||
template <typename T>
|
||||
array::array(
|
||||
std::initializer_list<T> data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
if (data.size() != size()) {
|
||||
throw std::invalid_argument(
|
||||
"Data size and provided shape mismatch in array construction.");
|
||||
|
Loading…
Reference in New Issue
Block a user