33 template <
typename It>
36 std::vector<int>
shape,
38 TypeToDtype<
typename std::iterator_traits<It>::value_type>());
51 std::initializer_list<T>
data,
52 std::vector<int>
shape,
58 std::vector<int>
shape,
72 if (this->
id() != other.id()) {
73 this->array_desc_ = other.array_desc_;
85 return array_desc_->size;
95 return array_desc_->shape.size();
99 const std::vector<int>&
shape()
const {
100 return array_desc_->shape;
109 return shape().at(dim < 0 ? dim +
ndim() : dim);
114 return array_desc_->strides;
128 return array_desc_->dtype;
135 template <
typename T>
138 template <
typename T>
162 return a.arr.
id() == b.arr.
id() && a.idx == b.idx;
187 std::vector<int>
shape,
190 std::vector<array>
inputs);
193 std::vector<std::vector<int>> shapes,
194 const std::vector<Dtype>& dtypes,
195 const std::shared_ptr<Primitive>&
primitive,
196 const std::vector<array>&
inputs);
199 std::uintptr_t
id()
const {
200 return reinterpret_cast<std::uintptr_t
>(array_desc_.get());
205 return reinterpret_cast<std::uintptr_t
>(array_desc_->primitive.get());
244 return *(array_desc_->primitive);
249 return array_desc_->primitive;
254 return array_desc_->primitive !=
nullptr;
258 const std::vector<array>&
inputs()
const {
259 return array_desc_->inputs;
263 return array_desc_->inputs;
268 return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
273 return array_desc_->siblings;
278 return array_desc_->siblings;
282 array_desc_->siblings = std::move(
siblings);
283 array_desc_->position = position;
289 auto idx = array_desc_->position;
303 return array_desc_->flags;
317 return array_desc_->data_size;
321 return array_desc_->data->buffer;
324 return array_desc_->data->buffer;
334 return array_desc_->data;
337 template <
typename T>
339 return static_cast<T*
>(array_desc_->data_ptr);
342 template <
typename T>
344 return static_cast<T*
>(array_desc_->data_ptr);
354 return array_desc_->status;
358 array_desc_->status = s;
363 return array_desc_->event;
368 array_desc_->event = std::move(e);
389 const std::vector<size_t>&
strides,
398 const std::vector<size_t>&
strides,
406 array_desc_ = other.array_desc_;
413 template <
typename It>
414 void init(
const It src);
417 std::vector<int> shape;
418 std::vector<size_t> strides;
421 std::shared_ptr<Primitive> primitive;
430 bool is_tracer{
false};
434 std::shared_ptr<Data> data;
437 void* data_ptr{
nullptr};
445 std::vector<array> inputs;
448 std::vector<array> siblings;
450 uint32_t position{0};
452 explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
455 std::vector<int> shape,
457 std::shared_ptr<Primitive> primitive,
458 std::vector<array> inputs);
471 std::shared_ptr<ArrayDesc> array_desc_;
476 : array_desc_(
std::make_shared<ArrayDesc>(
std::vector<int>{}, dtype)) {
480template <
typename It>
483 std::vector<int> shape,
485 array_desc_(
std::make_shared<ArrayDesc>(
std::move(shape), dtype)) {
491 std::initializer_list<T> data,
493 : array_desc_(
std::make_shared<ArrayDesc>(
494 std::vector<int>{
static_cast<int>(
data.size())},
501 std::initializer_list<T> data,
502 std::vector<int> shape,
504 : array_desc_(
std::make_shared<ArrayDesc>(
std::move(shape), dtype)) {
506 throw std::invalid_argument(
507 "Data size and provided shape mismatch in array construction.");
515 throw std::invalid_argument(
"item can only be called on arrays of size 1.");
524 throw std::invalid_argument(
"item can only be called on arrays of size 1.");
527 throw std::invalid_argument(
528 "item() const can only be called on evaled arrays");
534template <
typename It>
535void array::init(It src) {
539 std::copy(src, src +
size(), data<bool>());
542 std::copy(src, src +
size(), data<uint8_t>());
545 std::copy(src, src +
size(), data<uint16_t>());
548 std::copy(src, src +
size(), data<uint32_t>());
551 std::copy(src, src +
size(), data<uint64_t>());
554 std::copy(src, src +
size(), data<int8_t>());
557 std::copy(src, src +
size(), data<int16_t>());
560 std::copy(src, src +
size(), data<int32_t>());
563 std::copy(src, src +
size(), data<int64_t>());
566 std::copy(src, src +
size(), data<float16_t>());
569 std::copy(src, src +
size(), data<float>());
572 std::copy(src, src +
size(), data<bfloat16_t>());
575 std::copy(src, src +
size(), data<complex64_t>());
583 std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
array>;
585template <
typename... T>
588template <
typename... T>
Definition primitives.h:48
virtual size_t size(Buffer buffer) const =0
Definition allocator.h:12
void attach_event(Event e) const
Definition array.h:367
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:302
Event & event() const
Definition array.h:362
static std::vector< array > make_arrays(std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
Status
Definition array.h:347
@ available
Definition array.h:347
@ unscheduled
Definition array.h:347
@ scheduled
Definition array.h:347
void set_data(allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)
void eval()
Evaluate the array.
void copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
const std::vector< array > & inputs() const
The array's inputs.
Definition array.h:258
array(const array &other)=default
std::vector< array > outputs() const
The outputs of the array's primitive (i.e.
Definition array.h:288
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
void move_shared_buffer(array other)
array(std::initializer_list< float > data)
bool is_donatable() const
True indicates the arrays buffer is safe to reuse.
Definition array.h:267
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
std::shared_ptr< Primitive > & primitive_ptr() const
A shared pointer to the array's primitive.
Definition array.h:248
int shape(int dim) const
Get the size of the corresponding dimension.
Definition array.h:108
size_t ndim() const
The number of dimensions of the array.
Definition array.h:94
size_t size() const
The number of elements in the array.
Definition array.h:84
array(allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)
array & operator=(array &&other) &&=delete
array & operator=(const array &other) &
Definition array.h:71
ArrayIterator end() const
Definition array.h:176
array(std::initializer_list< int > data, Dtype dtype)
void set_data(allocator::Buffer buffer, deleter_t d=allocator::free)
const allocator::Buffer & buffer() const
Definition array.h:323
void set_status(Status s) const
Definition array.h:357
array(const std::complex< float > &val, Dtype dtype=complex64)
Status status() const
Definition array.h:353
std::vector< array > & siblings()
The array's siblings.
Definition array.h:277
T * data()
Definition array.h:338
array(T val, Dtype dtype=TypeToDtype< T >())
Construct a scalar array with zero dimensions.
Definition array.h:475
ArrayIterator begin() const
Definition array.h:173
Primitive & primitive() const
The array's primitive.
Definition array.h:243
void detach()
Detach the array from the graph.
array & operator=(const array &other) &&=delete
Assignment to rvalue does not compile.
void set_siblings(std::vector< array > siblings, uint16_t position)
Definition array.h:281
T item()
Get the value from a scalar array.
Definition array.h:513
size_t buffer_size() const
Definition array.h:327
size_t strides(int dim) const
Get the stride of the corresponding dimension.
Definition array.h:122
void copy_shared_buffer(const array &other)
void overwrite_descriptor(const array &other)
Definition array.h:405
const T * data() const
Definition array.h:343
bool has_primitive() const
Check if the array has an attached primitive or is a leaf node.
Definition array.h:253
allocator::Buffer & buffer()
Definition array.h:320
array(array &&other)=default
std::shared_ptr< Data > data_shared_ptr() const
Definition array.h:333
void move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
const std::vector< array > & siblings() const
The array's siblings.
Definition array.h:272
std::vector< array > & inputs()
Definition array.h:262
array & operator=(array &&other) &=default
Default copy and move constructors otherwise.
array(std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
The following methods should be used with caution.
std::uintptr_t id() const
A unique identifier for an array.
Definition array.h:199
Dtype dtype() const
Get the arrays data type.
Definition array.h:127
bool is_available() const
Definition array.h:349
void set_tracer(bool is_tracer)
Definition array.h:372
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:79
std::uintptr_t primitive_id() const
A unique identifier for an arrays primitive.
Definition array.h:204
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:316
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc(size_t size)
constexpr bool is_array_v
Definition array.h:582
constexpr Dtype bool_
Definition dtype.h:58
std::function< void(allocator::Buffer)> deleter_t
Definition array.h:18
constexpr Dtype uint64
Definition dtype.h:63
constexpr Dtype uint16
Definition dtype.h:61
constexpr Dtype bfloat16
Definition dtype.h:72
constexpr Dtype int32
Definition dtype.h:67
constexpr Dtype float32
Definition dtype.h:71
constexpr Dtype int16
Definition dtype.h:66
constexpr Dtype int8
Definition dtype.h:65
constexpr Dtype int64
Definition dtype.h:68
constexpr bool is_arrays_v
Definition array.h:586
constexpr Dtype uint8
Definition dtype.h:60
constexpr Dtype float16
Definition dtype.h:70
constexpr Dtype uint32
Definition dtype.h:62
uint8_t size_of(const Dtype &t)
Definition dtype.h:93
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:589
constexpr Dtype complex64
Definition dtype.h:73
reference operator*() const
friend bool operator==(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:161
std::random_access_iterator_tag iterator_category
Definition array.h:142
ArrayIterator & operator++()
Definition array.h:156
friend bool operator!=(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:164
ArrayIterator(const array &arr, int idx=0)
size_t difference_type
Definition array.h:143
const array value_type
Definition array.h:144
ArrayIterator & operator+(difference_type diff)
Definition array.h:151
~Data()
Definition array.h:216
deleter_t d
Definition array.h:210
Data(const Data &d)=delete
Data & operator=(const Data &d)=delete
Data(allocator::Buffer buffer, deleter_t d=allocator::free)
Definition array.h:211
allocator::Buffer buffer
Definition array.h:209
bool row_contiguous
Definition array.h:233
bool col_contiguous
Definition array.h:239
bool contiguous
Definition array.h:227