MLX
Loading...
Searching...
No Matches
array.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2#pragma once
3
4#include <algorithm>
5#include <cstdint>
6#include <functional>
7#include <memory>
8#include <vector>
9
10#include "mlx/allocator.h"
11#include "mlx/dtype.h"
12#include "mlx/event.h"
13
14namespace mlx::core {
15
16// Forward declaration
17class Primitive;
18
19using Deleter = std::function<void(allocator::Buffer)>;
20using Shape = std::vector<int32_t>;
21using Strides = std::vector<size_t>;
22
23class array {
24 /* An array is really a node in a graph. It contains a shared ArrayDesc
25 * object */
26
27 public:
29 template <typename T>
30 explicit array(T val, Dtype dtype = TypeToDtype<T>());
31
32 /* Special case since std::complex can't be implicitly converted to other
33 * types. */
34 explicit array(const std::complex<float>& val, Dtype dtype = complex64);
35
36 template <typename It>
37 array(
38 It data,
40 Dtype dtype =
41 TypeToDtype<typename std::iterator_traits<It>::value_type>());
42
43 template <typename T>
44 array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
45
46 /* Special case so empty lists default to float32. */
47 array(std::initializer_list<float> data);
48
49 /* Special case so array({}, type) is an empty array. */
50 array(std::initializer_list<int> data, Dtype dtype);
51
52 template <typename T>
53 array(
54 std::initializer_list<T> data,
57
58 /* Build an array from a buffer */
63 Deleter deleter = allocator::free);
64
66 array& operator=(const array& other) && = delete;
67 array& operator=(array&& other) && = delete;
68
70 array& operator=(array&& other) & = default;
71 array(const array& other) = default;
72 array(array&& other) = default;
73
74 array& operator=(const array& other) & {
75 if (this->id() != other.id()) {
76 this->array_desc_ = other.array_desc_;
77 }
78 return *this;
79 }
80
82 size_t itemsize() const {
83 return size_of(dtype());
84 }
85
87 size_t size() const {
88 return array_desc_->size;
89 }
90
92 size_t nbytes() const {
93 return size() * itemsize();
94 }
95
97 size_t ndim() const {
98 return array_desc_->shape.size();
99 }
100
102 const Shape& shape() const {
103 return array_desc_->shape;
104 }
105
111 auto shape(int dim) const {
112 return shape().at(dim < 0 ? dim + ndim() : dim);
113 }
114
116 const Strides& strides() const {
117 return array_desc_->strides;
118 }
119
125 auto strides(int dim) const {
126 return strides().at(dim < 0 ? dim + ndim() : dim);
127 }
128
130 Dtype dtype() const {
131 return array_desc_->dtype;
132 }
133
135 void eval();
136
138 template <typename T>
139 T item();
140
141 template <typename T>
142 T item() const;
143
145 using iterator_category = std::random_access_iterator_tag;
146 using difference_type = size_t;
147 using value_type = const array;
149
150 explicit ArrayIterator(const array& arr, int idx = 0);
151
153
155 idx += diff;
156 return *this;
157 }
158
160 idx++;
161 return *this;
162 }
163
164 friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
165 return a.arr.id() == b.arr.id() && a.idx == b.idx;
166 }
167 friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
168 return !(a == b);
169 }
170
171 private:
172 const array& arr;
173 int idx;
174 };
175
177 return ArrayIterator(*this);
178 }
180 return ArrayIterator(*this, shape(0));
181 }
182
190 Shape shape,
191 Dtype dtype,
192 std::shared_ptr<Primitive> primitive,
193 std::vector<array> inputs);
194
195 static std::vector<array> make_arrays(
196 std::vector<Shape> shapes,
197 const std::vector<Dtype>& dtypes,
198 const std::shared_ptr<Primitive>& primitive,
199 const std::vector<array>& inputs);
200
202 std::uintptr_t id() const {
203 return reinterpret_cast<std::uintptr_t>(array_desc_.get());
204 }
205
207 std::uintptr_t primitive_id() const {
208 return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
209 }
210
211 struct Data {
216 // Not copyable
217 Data(const Data& d) = delete;
218 Data& operator=(const Data& d) = delete;
220 d(buffer);
221 }
222 };
223
224 struct Flags {
225 // True iff there are no gaps in the underlying data. Each item
226 // in the underlying data buffer belongs to at least one index.
227 //
228 // True iff:
229 // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
230 bool contiguous : 1;
231
232 // True iff:
233 // strides[-1] == 1 and
234 // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
235 // range(ndim - 1))
237
238 // True iff:
239 // strides[0] == 1 and
240 // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
241 // range(1, ndim))
243 };
244
247 return *(array_desc_->primitive);
248 }
249
251 std::shared_ptr<Primitive>& primitive_ptr() const {
252 return array_desc_->primitive;
253 }
254
256 bool has_primitive() const {
257 return array_desc_->primitive != nullptr;
258 }
259
261 const std::vector<array>& inputs() const {
262 return array_desc_->inputs;
263 }
264
265 std::vector<array>& inputs() {
266 return array_desc_->inputs;
267 }
268
270 bool is_donatable() const {
271 return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
272 }
273
275 const std::vector<array>& siblings() const {
276 return array_desc_->siblings;
277 }
278
280 std::vector<array>& siblings() {
281 return array_desc_->siblings;
282 }
283
284 void set_siblings(std::vector<array> siblings, uint16_t position) {
285 array_desc_->siblings = std::move(siblings);
286 array_desc_->position = position;
287 }
288
291 std::vector<array> outputs() const {
292 auto idx = array_desc_->position;
293 std::vector<array> outputs;
294 outputs.reserve(siblings().size() + 1);
295 outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
296 outputs.push_back(*this);
297 outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
298 return outputs;
299 }
300
302 void detach();
303
305 const Flags& flags() const {
306 return array_desc_->flags;
307 }
308
319 size_t data_size() const {
320 return array_desc_->data_size;
321 }
322
324 return array_desc_->data->buffer;
325 }
326 const allocator::Buffer& buffer() const {
327 return array_desc_->data->buffer;
328 }
329
330 size_t buffer_size() const {
331 return allocator::allocator().size(buffer());
332 }
333
334 // Return a copy of the shared pointer
335 // to the array::Data struct
336 std::shared_ptr<Data> data_shared_ptr() const {
337 return array_desc_->data;
338 }
339 // Return a raw pointer to the arrays data
340 template <typename T>
341 T* data() {
342 return static_cast<T*>(array_desc_->data_ptr);
343 }
344
345 template <typename T>
346 const T* data() const {
347 return static_cast<T*>(array_desc_->data_ptr);
348 }
349
350 enum Status {
351 // The output of a computation which has not been scheduled.
352 // For example, the status of `x` in `auto x = a + b`.
354
355 // The output of a computation which has been scheduled but `eval_*` has
356 // not yet been called on the array's primitive. A possible
357 // status of `x` in `auto x = a + b; eval(x);`
359
360 // The array's `eval_*` function has been run, but the computation is not
361 // necessarily complete. The array will have memory allocated and if it is
362 // not a tracer then it will be detached from the graph.
364
365 // If the array is the output of a computation then the computation
366 // is complete. Constant arrays are always available (e.g. `array({1, 2,
367 // 3})`)
369 };
370
371 // Check if the array is safe to read.
372 bool is_available() const;
373
374 // Wait on the array to be available. After this `is_available` returns
375 // `true`.
376 void wait();
377
378 Status status() const {
379 return array_desc_->status;
380 }
381
382 void set_status(Status s) const {
383 array_desc_->status = s;
384 }
385
386 // Get the array's shared event
387 Event& event() const {
388 return array_desc_->event;
389 }
390
391 // Attach an event to a not yet evaluated array
392 void attach_event(Event e) const {
393 array_desc_->event = std::move(e);
394 }
395
396 // Mark the array as a tracer array (true) or not.
398 array_desc_->is_tracer = is_tracer;
399 }
400 // Check if the array is a tracer array
401 bool is_tracer() const;
402
404
407 size_t data_size,
409 Flags flags,
411
413 const array& other,
414 const Strides& strides,
415 Flags flags,
416 size_t data_size,
417 size_t offset = 0);
418
419 void copy_shared_buffer(const array& other);
420
422 array other,
423 const Strides& strides,
424 Flags flags,
425 size_t data_size,
426 size_t offset = 0);
427
429
430 void overwrite_descriptor(const array& other) {
431 array_desc_ = other.array_desc_;
432 }
433
435
436 private:
437 // Initialize the arrays data
438 template <typename It>
439 void init(const It src);
440
441 struct ArrayDesc {
442 Shape shape;
443 Strides strides;
444 size_t size;
445 Dtype dtype;
446 std::shared_ptr<Primitive> primitive;
447
448 Status status;
449
450 // An event on the array used for synchronization
451 Event event;
452
453 // Indicates an array is being used in a graph transform
454 // and should not be detached from the graph
455 bool is_tracer{false};
456
457 // This is a shared pointer so that *different* arrays
458 // can share the underlying data buffer.
459 std::shared_ptr<Data> data;
460
461 // Properly offset data pointer
462 void* data_ptr{nullptr};
463
464 // The size in elements of the data buffer the array accesses
465 size_t data_size;
466
467 // Contains useful meta data about the array
468 Flags flags;
469
470 std::vector<array> inputs;
471 // An array to keep track of the siblings from a multi-output
472 // primitive.
473 std::vector<array> siblings;
474 // The arrays position in the output list
475 uint32_t position{0};
476
477 explicit ArrayDesc(Shape shape, Dtype dtype);
478
479 explicit ArrayDesc(
480 Shape shape,
481 Dtype dtype,
482 std::shared_ptr<Primitive> primitive,
483 std::vector<array> inputs);
484
485 ~ArrayDesc();
486
487 private:
488 // Initialize size, strides, and other metadata
489 void init();
490 };
491
492 // The ArrayDesc contains the details of the materialized array including the
493 // shape, strides, the data type. It also includes
494 // the primitive which knows how to compute the array's data from its inputs
495 // and the list of array's inputs for the primitive.
496 std::shared_ptr<ArrayDesc> array_desc_;
497};
498
499template <typename T>
500array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
501 : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
502 init(&val);
503}
504
505template <typename It>
507 It data,
508 Shape shape,
509 Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
510 array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
511 init(data);
512}
513
514template <typename T>
516 std::initializer_list<T> data,
517 Dtype dtype /* = TypeToDtype<T>() */)
518 : array_desc_(std::make_shared<ArrayDesc>(
519 std::vector<int>{static_cast<int>(data.size())},
520 dtype)) {
521 init(data.begin());
522}
523
524template <typename T>
526 std::initializer_list<T> data,
527 Shape shape,
528 Dtype dtype /* = TypeToDtype<T>() */)
529 : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
530 if (data.size() != size()) {
531 throw std::invalid_argument(
532 "Data size and provided shape mismatch in array construction.");
533 }
534 init(data.begin());
535}
536
537template <typename T>
539 if (size() != 1) {
540 throw std::invalid_argument("item can only be called on arrays of size 1.");
541 }
542 eval();
543 return *data<T>();
544}
545
546template <typename T>
547T array::item() const {
548 if (size() != 1) {
549 throw std::invalid_argument("item can only be called on arrays of size 1.");
550 }
551 if (status() == Status::unscheduled) {
552 throw std::invalid_argument(
553 "item() const can only be called on evaled arrays");
554 }
555 const_cast<array*>(this)->eval();
556 return *data<T>();
557}
558
559template <typename It>
560void array::init(It src) {
562 switch (dtype()) {
563 case bool_:
564 std::copy(src, src + size(), data<bool>());
565 break;
566 case uint8:
567 std::copy(src, src + size(), data<uint8_t>());
568 break;
569 case uint16:
570 std::copy(src, src + size(), data<uint16_t>());
571 break;
572 case uint32:
573 std::copy(src, src + size(), data<uint32_t>());
574 break;
575 case uint64:
576 std::copy(src, src + size(), data<uint64_t>());
577 break;
578 case int8:
579 std::copy(src, src + size(), data<int8_t>());
580 break;
581 case int16:
582 std::copy(src, src + size(), data<int16_t>());
583 break;
584 case int32:
585 std::copy(src, src + size(), data<int32_t>());
586 break;
587 case int64:
588 std::copy(src, src + size(), data<int64_t>());
589 break;
590 case float16:
591 std::copy(src, src + size(), data<float16_t>());
592 break;
593 case float32:
594 std::copy(src, src + size(), data<float>());
595 break;
596 case bfloat16:
597 std::copy(src, src + size(), data<bfloat16_t>());
598 break;
599 case complex64:
600 std::copy(src, src + size(), data<complex64_t>());
601 break;
602 }
603}
604
605/* Utilities for determining whether a template parameter is array. */
606template <typename T>
607inline constexpr bool is_array_v =
608 std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
609
610template <typename... T>
611inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
612
613template <typename... T>
614using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
615
616} // namespace mlx::core
Definition event.h:11
Definition primitives.h:48
virtual size_t size(Buffer buffer) const =0
Definition allocator.h:12
Definition array.h:23
void attach_event(Event e) const
Definition array.h:392
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:305
Event & event() const
Definition array.h:387
Status
Definition array.h:350
@ available
Definition array.h:368
@ evaluated
Definition array.h:363
@ unscheduled
Definition array.h:353
@ scheduled
Definition array.h:358
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:102
void eval()
Evaluate the array.
const Strides & strides() const
The strides of the array.
Definition array.h:116
const std::vector< array > & inputs() const
The array's inputs.
Definition array.h:261
array(const array &other)=default
std::vector< array > outputs() const
The outputs of the array's primitive (i.e. this array and its siblings), ordered by output position.
Definition array.h:291
size_t nbytes() const
The number of bytes in the array.
Definition array.h:92
void move_shared_buffer(array other)
static std::vector< array > make_arrays(std::vector< Shape > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
array(std::initializer_list< float > data)
bool is_donatable() const
True indicates the array's buffer is safe to reuse.
Definition array.h:270
array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter=allocator::free)
std::shared_ptr< Primitive > & primitive_ptr() const
A shared pointer to the array's primitive.
Definition array.h:251
size_t ndim() const
The number of dimensions of the array.
Definition array.h:97
size_t size() const
The number of elements in the array.
Definition array.h:87
array & operator=(array &&other) &&=delete
array & operator=(const array &other) &
Definition array.h:74
ArrayIterator end() const
Definition array.h:179
array(std::initializer_list< int > data, Dtype dtype)
void set_data(allocator::Buffer buffer, size_t data_size, Strides strides, Flags flags, Deleter d=allocator::free)
const allocator::Buffer & buffer() const
Definition array.h:326
void set_status(Status s) const
Definition array.h:382
array(const std::complex< float > &val, Dtype dtype=complex64)
Status status() const
Definition array.h:378
std::vector< array > & siblings()
The array's siblings.
Definition array.h:280
T * data()
Definition array.h:341
array(T val, Dtype dtype=TypeToDtype< T >())
Construct a scalar array with zero dimensions.
Definition array.h:500
ArrayIterator begin() const
Definition array.h:176
Primitive & primitive() const
The array's primitive.
Definition array.h:246
void detach()
Detach the array from the graph.
array & operator=(const array &other) &&=delete
Assignment to rvalue does not compile.
void set_siblings(std::vector< array > siblings, uint16_t position)
Definition array.h:284
T item()
Get the value from a scalar array.
Definition array.h:538
size_t buffer_size() const
Definition array.h:330
void copy_shared_buffer(const array &other)
void overwrite_descriptor(const array &other)
Definition array.h:430
const T * data() const
Definition array.h:346
bool has_primitive() const
Check if the array has an attached primitive or is a leaf node.
Definition array.h:256
allocator::Buffer & buffer()
Definition array.h:323
array(array &&other)=default
std::shared_ptr< Data > data_shared_ptr() const
Definition array.h:336
array(Shape shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
The following methods should be used with caution.
auto shape(int dim) const
Get the size of the corresponding dimension.
Definition array.h:111
auto strides(int dim) const
Get the stride of the corresponding dimension.
Definition array.h:125
const std::vector< array > & siblings() const
The array's siblings.
Definition array.h:275
std::vector< array > & inputs()
Definition array.h:265
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
array & operator=(array &&other) &=default
Default copy and move constructors otherwise.
void move_shared_buffer(array other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
std::uintptr_t id() const
A unique identifier for an array.
Definition array.h:202
Dtype dtype() const
Get the array's data type.
Definition array.h:130
bool is_available() const
void set_tracer(bool is_tracer)
Definition array.h:397
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:82
std::uintptr_t primitive_id() const
A unique identifier for an array's primitive.
Definition array.h:207
bool is_tracer() const
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:319
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc(size_t size)
void free(Buffer buffer)
Allocator & allocator()
Definition allocator.h:7
constexpr bool is_array_v
Definition array.h:607
constexpr Dtype bool_
Definition dtype.h:67
constexpr Dtype uint64
Definition dtype.h:72
constexpr Dtype uint16
Definition dtype.h:70
constexpr Dtype bfloat16
Definition dtype.h:81
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
constexpr Dtype int16
Definition dtype.h:75
constexpr Dtype int8
Definition dtype.h:74
constexpr Dtype int64
Definition dtype.h:77
constexpr bool is_arrays_v
Definition array.h:611
constexpr Dtype uint8
Definition dtype.h:69
std::vector< int32_t > Shape
Definition array.h:20
constexpr Dtype float16
Definition dtype.h:79
constexpr Dtype uint32
Definition dtype.h:71
std::vector< size_t > Strides
Definition array.h:21
uint8_t size_of(const Dtype &t)
Definition dtype.h:102
std::function< void(allocator::Buffer)> Deleter
Definition array.h:19
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:614
constexpr Dtype complex64
Definition dtype.h:82
Definition dtype.h:13
Definition dtype.h:109
Definition array.h:144
friend bool operator==(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:164
std::random_access_iterator_tag iterator_category
Definition array.h:145
ArrayIterator & operator++()
Definition array.h:159
friend bool operator!=(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:167
ArrayIterator(const array &arr, int idx=0)
size_t difference_type
Definition array.h:146
const array value_type
Definition array.h:147
ArrayIterator & operator+(difference_type diff)
Definition array.h:154
Definition array.h:211
Deleter d
Definition array.h:213
Data(allocator::Buffer buffer, Deleter d=allocator::free)
Definition array.h:214
~Data()
Definition array.h:219
Data(const Data &d)=delete
Data & operator=(const Data &d)=delete
allocator::Buffer buffer
Definition array.h:212
Definition array.h:224
bool row_contiguous
Definition array.h:236
bool col_contiguous
Definition array.h:242
bool contiguous
Definition array.h:230