mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can unify array's constructors to just take the arguments by value. There are 2 cases: 1. When shape is a lvalue, it will be copied into array's constructor and then moved into ArrayDesc's member. So only 1 copy happens. 2. When shape is a rvalue, it will be moved into array's constructor and then moved into ArrayDesc's member. So no copy happens. So having 1 constructor that takes by value is equivalent to having 2 constructors that const reference and rvalue separately.
This commit is contained in:
34
mlx/array.h
34
mlx/array.h
@@ -32,7 +32,7 @@ class array {
|
||||
template <typename It>
|
||||
array(
|
||||
It data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype =
|
||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||
|
||||
@@ -48,13 +48,13 @@ class array {
|
||||
template <typename T>
|
||||
array(
|
||||
std::initializer_list<T> data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a buffer */
|
||||
array(
|
||||
allocator::Buffer data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter = allocator::free);
|
||||
|
||||
@@ -173,17 +173,11 @@ class array {
|
||||
* API may change.
|
||||
*/
|
||||
|
||||
array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
array(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
std::vector<array> inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
@@ -391,19 +385,13 @@ class array {
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
explicit ArrayDesc(
|
||||
std::vector<int>&& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
std::vector<array> inputs);
|
||||
};
|
||||
|
||||
// The ArrayDesc contains the details of the materialized array including the
|
||||
@@ -422,9 +410,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
template <typename It>
|
||||
array::array(
|
||||
It data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
init(data);
|
||||
}
|
||||
|
||||
@@ -441,9 +429,9 @@ array::array(
|
||||
template <typename T>
|
||||
array::array(
|
||||
std::initializer_list<T> data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
if (data.size() != size()) {
|
||||
throw std::invalid_argument(
|
||||
"Data size and provided shape mismatch in array construction.");
|
||||
|
||||
Reference in New Issue
Block a user