mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 03:36:40 +08:00
Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can unify array's constructors to just take the arguments by value. There are 2 cases: 1. When shape is a lvalue, it will be copied into array's constructor and then moved into ArrayDesc's member. So only 1 copy happens. 2. When shape is a rvalue, it will be moved into array's constructor and then moved into ArrayDesc's member. So no copy happens. So having 1 constructor that takes by value is equivalent to having 2 constructors that const reference and rvalue separately.
This commit is contained in:
parent
db6796ac61
commit
73a8c090e0
@ -36,22 +36,11 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
|||||||
init(&cval);
|
init(&cval);
|
||||||
}
|
}
|
||||||
|
|
||||||
array::array(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
Dtype dtype,
|
|
||||||
std::shared_ptr<Primitive> primitive,
|
|
||||||
const std::vector<array>& inputs)
|
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(
|
|
||||||
shape,
|
|
||||||
dtype,
|
|
||||||
std::move(primitive),
|
|
||||||
inputs)) {}
|
|
||||||
|
|
||||||
array::array(
|
array::array(
|
||||||
std::vector<int> shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::shared_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
std::vector<array>&& inputs)
|
std::vector<array> inputs)
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
std::move(shape),
|
std::move(shape),
|
||||||
dtype,
|
dtype,
|
||||||
@ -92,10 +81,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
|||||||
/* Build an array from a shared buffer */
|
/* Build an array from a shared buffer */
|
||||||
array::array(
|
array::array(
|
||||||
allocator::Buffer data,
|
allocator::Buffer data,
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
deleter_t deleter)
|
deleter_t deleter)
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||||
set_data(data, deleter);
|
set_data(data, deleter);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,31 +170,16 @@ void array::move_shared_buffer(array other) {
|
|||||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||||
: shape(shape), dtype(dtype) {
|
: shape(std::move(shape)), dtype(dtype) {
|
||||||
std::tie(size, strides) = cum_prod(shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
Dtype dtype,
|
|
||||||
std::shared_ptr<Primitive> primitive,
|
|
||||||
const std::vector<array>& inputs)
|
|
||||||
: shape(shape),
|
|
||||||
dtype(dtype),
|
|
||||||
primitive(std::move(primitive)),
|
|
||||||
inputs(inputs) {
|
|
||||||
std::tie(size, strides) = cum_prod(this->shape);
|
std::tie(size, strides) = cum_prod(this->shape);
|
||||||
for (auto& in : this->inputs) {
|
|
||||||
is_tracer |= in.is_tracer();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(
|
array::ArrayDesc::ArrayDesc(
|
||||||
std::vector<int>&& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::shared_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
std::vector<array>&& inputs)
|
std::vector<array> inputs)
|
||||||
: shape(std::move(shape)),
|
: shape(std::move(shape)),
|
||||||
dtype(dtype),
|
dtype(dtype),
|
||||||
primitive(std::move(primitive)),
|
primitive(std::move(primitive)),
|
||||||
|
34
mlx/array.h
34
mlx/array.h
@ -32,7 +32,7 @@ class array {
|
|||||||
template <typename It>
|
template <typename It>
|
||||||
array(
|
array(
|
||||||
It data,
|
It data,
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype =
|
Dtype dtype =
|
||||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||||
|
|
||||||
@ -48,13 +48,13 @@ class array {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
array(
|
array(
|
||||||
std::initializer_list<T> data,
|
std::initializer_list<T> data,
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype = TypeToDtype<T>());
|
Dtype dtype = TypeToDtype<T>());
|
||||||
|
|
||||||
/* Build an array from a buffer */
|
/* Build an array from a buffer */
|
||||||
array(
|
array(
|
||||||
allocator::Buffer data,
|
allocator::Buffer data,
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
deleter_t deleter = allocator::free);
|
deleter_t deleter = allocator::free);
|
||||||
|
|
||||||
@ -173,17 +173,11 @@ class array {
|
|||||||
* API may change.
|
* API may change.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
array(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
Dtype dtype,
|
|
||||||
std::shared_ptr<Primitive> primitive,
|
|
||||||
const std::vector<array>& inputs);
|
|
||||||
|
|
||||||
array(
|
array(
|
||||||
std::vector<int> shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::shared_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
std::vector<array>&& inputs);
|
std::vector<array> inputs);
|
||||||
|
|
||||||
static std::vector<array> make_arrays(
|
static std::vector<array> make_arrays(
|
||||||
const std::vector<std::vector<int>>& shapes,
|
const std::vector<std::vector<int>>& shapes,
|
||||||
@ -391,19 +385,13 @@ class array {
|
|||||||
// The arrays position in the output list
|
// The arrays position in the output list
|
||||||
uint32_t position{0};
|
uint32_t position{0};
|
||||||
|
|
||||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
|
||||||
|
|
||||||
explicit ArrayDesc(
|
explicit ArrayDesc(
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::shared_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
const std::vector<array>& inputs);
|
std::vector<array> inputs);
|
||||||
|
|
||||||
explicit ArrayDesc(
|
|
||||||
std::vector<int>&& shape,
|
|
||||||
Dtype dtype,
|
|
||||||
std::shared_ptr<Primitive> primitive,
|
|
||||||
std::vector<array>&& inputs);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// The ArrayDesc contains the details of the materialized array including the
|
// The ArrayDesc contains the details of the materialized array including the
|
||||||
@ -422,9 +410,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
|||||||
template <typename It>
|
template <typename It>
|
||||||
array::array(
|
array::array(
|
||||||
It data,
|
It data,
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||||
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||||
init(data);
|
init(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -441,9 +429,9 @@ array::array(
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
array::array(
|
array::array(
|
||||||
std::initializer_list<T> data,
|
std::initializer_list<T> data,
|
||||||
const std::vector<int>& shape,
|
std::vector<int> shape,
|
||||||
Dtype dtype /* = TypeToDtype<T>() */)
|
Dtype dtype /* = TypeToDtype<T>() */)
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||||
if (data.size() != size()) {
|
if (data.size() != size()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Data size and provided shape mismatch in array construction.");
|
"Data size and provided shape mismatch in array construction.");
|
||||||
|
Loading…
Reference in New Issue
Block a user