MLX
 
Loading...
Searching...
No Matches
array.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2#pragma once
3
4#include <algorithm>
5#include <cstdint>
6#include <functional>
7#include <memory>
8#include <vector>
9
10#include "mlx/allocator.h"
11#include "mlx/dtype.h"
12#include "mlx/event.h"
13
14namespace mlx::core {
15
16// Forward declaration
17class Primitive;
18
19using Deleter = std::function<void(allocator::Buffer)>;
20using ShapeElem = int32_t;
21using Shape = std::vector<ShapeElem>;
22using Strides = std::vector<int64_t>;
23
24class array {
25 /* An array is really a node in a graph. It contains a shared ArrayDesc
26 * object */
27
28 public:
30 template <typename T>
31 explicit array(T val, Dtype dtype = TypeToDtype<T>());
32
33 /* Special case since std::complex can't be implicitly converted to other
34 * types. */
35 explicit array(const std::complex<float>& val, Dtype dtype = complex64);
36
37 template <typename It>
38 explicit array(
39 It data,
41 Dtype dtype =
42 TypeToDtype<typename std::iterator_traits<It>::value_type>());
43
44 template <typename T>
45 explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
46
47 /* Special case so empty lists default to float32. */
48 explicit array(std::initializer_list<float> data);
49
50 /* Special case so array({}, type) is an empty array. */
51 explicit array(std::initializer_list<int> data, Dtype dtype);
52
53 template <typename T>
54 explicit array(
55 std::initializer_list<T> data,
58
59 /* Build an array from a buffer */
60 explicit array(
64 Deleter deleter = allocator::free);
65
67 array& operator=(const array& other) && = delete;
68 array& operator=(array&& other) && = delete;
69
71 array& operator=(array&& other) & = default;
72 array(const array& other) = default;
73 array(array&& other) = default;
74
75 array& operator=(const array& other) & {
76 if (this->id() != other.id()) {
77 this->array_desc_ = other.array_desc_;
78 }
79 return *this;
80 }
81
83 size_t itemsize() const {
84 return size_of(dtype());
85 }
86
88 size_t size() const {
89 return array_desc_->size;
90 }
91
93 size_t nbytes() const {
94 return size() * itemsize();
95 }
96
98 size_t ndim() const {
99 return array_desc_->shape.size();
100 }
101
103 const Shape& shape() const {
104 return array_desc_->shape;
105 }
106
112 auto shape(int dim) const {
113 return shape().at(dim < 0 ? dim + ndim() : dim);
114 }
115
117 const Strides& strides() const {
118 return array_desc_->strides;
119 }
120
126 auto strides(int dim) const {
127 return strides().at(dim < 0 ? dim + ndim() : dim);
128 }
129
131 Dtype dtype() const {
132 return array_desc_->dtype;
133 }
134
// Evaluate the array.
136 void eval();
137
// Get the value from a scalar array: throws std::invalid_argument unless
// size() == 1, otherwise evaluates the array and returns its single element.
139 template <typename T>
140 T item();
141
// Const variant of item(); it additionally throws std::invalid_argument when
// the array's status is still unscheduled.
142 template <typename T>
143 T item() const;
144
146 using iterator_category = std::random_access_iterator_tag;
147 using difference_type = size_t;
148 using value_type = const array;
150
151 explicit ArrayIterator(const array& arr, int idx = 0);
152
154
156 idx += diff;
157 return *this;
158 }
159
161 idx++;
162 return *this;
163 }
164
165 friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
166 return a.arr.id() == b.arr.id() && a.idx == b.idx;
167 }
168 friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
169 return !(a == b);
170 }
171
172 private:
173 const array& arr;
174 int idx;
175 };
176
178 return ArrayIterator(*this);
179 }
181 return ArrayIterator(*this, shape(0));
182 }
183
189
191 Shape shape,
192 Dtype dtype,
193 std::shared_ptr<Primitive> primitive,
194 std::vector<array> inputs);
195
// Build several arrays from the parallel `shapes` / `dtypes` lists, all using
// the same `primitive` and `inputs`.
// NOTE(review): the definition is not visible here; presumably these are the
// outputs of a multi-output primitive, linked as siblings — confirm.
196 static std::vector<array> make_arrays(
197 std::vector<Shape> shapes,
198 const std::vector<Dtype>& dtypes,
199 const std::shared_ptr<Primitive>& primitive,
200 const std::vector<array>& inputs);
201
203 std::uintptr_t id() const {
204 return reinterpret_cast<std::uintptr_t>(array_desc_.get());
205 }
206
208 std::uintptr_t primitive_id() const {
209 return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
210 }
211
212 struct Data {
217 // Not copyable
218 Data(const Data& d) = delete;
219 Data& operator=(const Data& d) = delete;
221 d(buffer);
222 }
223 };
224
// Layout flags describing how the array's strides map onto its data buffer.
struct Flags {
  // True iff there are no gaps in the underlying data. Each item
  // in the underlying data buffer belongs to at least one index.
  //
  // True iff:
  // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
  bool contiguous : 1;

  // True iff:
  // strides[-1] == 1 and
  // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
  // range(ndim - 1))
  bool row_contiguous : 1;

  // True iff:
  // strides[0] == 1 and
  // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
  // range(1, ndim))
  bool col_contiguous : 1;
};
245
249 explicit array(
251 Shape shape,
252 Dtype dtype,
254 size_t data_size,
255 Flags flags,
256 Deleter deleter = allocator::free);
257
260 return *(array_desc_->primitive);
261 }
262
264 std::shared_ptr<Primitive>& primitive_ptr() const {
265 return array_desc_->primitive;
266 }
267
269 bool has_primitive() const {
270 return array_desc_->primitive != nullptr;
271 }
272
274 const std::vector<array>& inputs() const {
275 return array_desc_->inputs;
276 }
277
278 std::vector<array>& inputs() {
279 return array_desc_->inputs;
280 }
281
283 bool is_donatable() const {
284 return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
285 }
286
288 const std::vector<array>& siblings() const {
289 return array_desc_->siblings;
290 }
291
293 std::vector<array>& siblings() {
294 return array_desc_->siblings;
295 }
296
297 void set_siblings(std::vector<array> siblings, uint16_t position) {
298 array_desc_->siblings = std::move(siblings);
299 array_desc_->position = position;
300 }
301
304 std::vector<array> outputs() const {
305 auto idx = array_desc_->position;
306 std::vector<array> outputs;
307 outputs.reserve(siblings().size() + 1);
308 outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
309 outputs.push_back(*this);
310 outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
311 return outputs;
312 }
313
// Detach the array from the graph.
315 void detach();
316
318 const Flags& flags() const {
319 return array_desc_->flags;
320 }
321
332 size_t data_size() const {
333 return array_desc_->data_size;
334 }
335
337 return array_desc_->data->buffer;
338 }
339 const allocator::Buffer& buffer() const {
340 return array_desc_->data->buffer;
341 }
342
343 size_t buffer_size() const {
344 return allocator::allocator().size(buffer());
345 }
346
347 // Return a copy of the shared pointer
348 // to the array::Data struct
349 std::shared_ptr<Data> data_shared_ptr() const {
350 return array_desc_->data;
351 }
352 // Return a raw pointer to the arrays data
353 template <typename T>
354 T* data() {
355 return static_cast<T*>(array_desc_->data_ptr);
356 }
357
358 template <typename T>
359 const T* data() const {
360 return static_cast<T*>(array_desc_->data_ptr);
361 }
362
// Evaluation status of an array, in the order an array progresses through it.
enum Status {
  // The output of a computation which has not been scheduled.
  // For example, the status of `x` in `auto x = a + b`.
  unscheduled,

  // The output of a computation which has been scheduled but `eval_*` has
  // not yet been called on the array's primitive. A possible
  // status of `x` in `auto x = a + b; eval(x);`
  scheduled,

  // The array's `eval_*` function has been run, but the computation is not
  // necessarily complete. The array will have memory allocated and if it is
  // not a tracer then it will be detached from the graph.
  evaluated,

  // If the array is the output of a computation then the computation
  // is complete. Constant arrays are always available (e.g. `array({1, 2,
  // 3})`)
  available
};
383
384 // Check if the array is safe to read.
// NOTE(review): definition not visible here; presumably equivalent to the
// computation having reached Status::available — confirm.
385 bool is_available() const;
386
387 // Wait on the array to be available. After this `is_available` returns
388 // `true`.
// NOTE(review): presumably blocks on the array's event() — confirm.
389 void wait();
390
391 Status status() const {
392 return array_desc_->status;
393 }
394
395 void set_status(Status s) const {
396 array_desc_->status = s;
397 }
398
399 // Get the array's shared event
400 Event& event() const {
401 return array_desc_->event;
402 }
403
404 // Attach an event to a not yet evaluated array
405 void attach_event(Event e) const {
406 array_desc_->event = std::move(e);
407 }
408
409 // Mark the array as a tracer array (true) or not.
411 array_desc_->is_tracer = is_tracer;
412 }
413 // Check if the array is a tracer array
414 bool is_tracer() const;
415
417
420 size_t data_size,
422 Flags flags,
424
426 const array& other,
427 const Strides& strides,
428 Flags flags,
429 size_t data_size,
430 size_t offset = 0);
431
432 void copy_shared_buffer(const array& other);
433
435 array other,
436 const Strides& strides,
437 Flags flags,
438 size_t data_size,
439 size_t offset = 0);
440
442
443 void overwrite_descriptor(const array& other) {
444 array_desc_ = other.array_desc_;
445 }
446
448
449 private:
450 // Initialize the array's data by copying size() elements starting at `src`,
// converting each element to the C++ type matching dtype().
451 template <typename It>
452 void init(const It src);
453
454 struct ArrayDesc {
455 Shape shape;
457 size_t size;
458 Dtype dtype;
459 std::shared_ptr<Primitive> primitive;
460
462
463 // An event on the array used for synchronization
464 Event event;
465
466 // Indicates an array is being used in a graph transform
467 // and should not be detached from the graph
468 bool is_tracer{false};
469
470 // This is a shared pointer so that *different* arrays
471 // can share the underlying data buffer.
472 std::shared_ptr<Data> data;
473
474 // Properly offset data pointer
475 void* data_ptr{nullptr};
476
477 // The size in elements of the data buffer the array accesses
478 size_t data_size;
479
480 // Contains useful meta data about the array
481 Flags flags;
482
483 std::vector<array> inputs;
484 // An array to keep track of the siblings from a multi-output
485 // primitive.
486 std::vector<array> siblings;
487 // The arrays position in the output list
488 uint32_t position{0};
489
490 explicit ArrayDesc(Shape shape, Dtype dtype);
491
492 explicit ArrayDesc(
493 Shape shape,
494 Dtype dtype,
495 std::shared_ptr<Primitive> primitive,
496 std::vector<array> inputs);
497
498 ~ArrayDesc();
499
500 private:
501 // Initialize size, strides, and other metadata
502 void init();
503 };
504
505 // The ArrayDesc contains the details of the materialized array including the
506 // shape, strides, the data type. It also includes
507 // the primitive which knows how to compute the array's data from its inputs
508 // and the list of the array's inputs for the primitive.
// Copies of an array share this pointer, and therefore every field above.
509 std::shared_ptr<ArrayDesc> array_desc_;
510};
511
512template <typename T>
513array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
514 : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
515 init(&val);
516}
517
518template <typename It>
520 It data,
521 Shape shape,
522 Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
523 array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
524 init(data);
525}
526
527template <typename T>
529 std::initializer_list<T> data,
530 Dtype dtype /* = TypeToDtype<T>() */)
531 : array_desc_(std::make_shared<ArrayDesc>(
532 Shape{static_cast<ShapeElem>(data.size())},
533 dtype)) {
534 init(data.begin());
535}
536
537template <typename T>
539 std::initializer_list<T> data,
540 Shape shape,
541 Dtype dtype /* = TypeToDtype<T>() */)
542 : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
543 if (data.size() != size()) {
544 throw std::invalid_argument(
545 "Data size and provided shape mismatch in array construction.");
546 }
547 init(data.begin());
548}
549
550template <typename T>
552 if (size() != 1) {
553 throw std::invalid_argument("item can only be called on arrays of size 1.");
554 }
555 eval();
556 return *data<T>();
557}
558
559template <typename T>
560T array::item() const {
561 if (size() != 1) {
562 throw std::invalid_argument("item can only be called on arrays of size 1.");
563 }
564 if (status() == Status::unscheduled) {
565 throw std::invalid_argument(
566 "item() const can only be called on evaled arrays");
567 }
568 const_cast<array*>(this)->eval();
569 return *data<T>();
570}
571
572template <typename It>
573void array::init(It src) {
575 switch (dtype()) {
576 case bool_:
577 std::copy(src, src + size(), data<bool>());
578 break;
579 case uint8:
580 std::copy(src, src + size(), data<uint8_t>());
581 break;
582 case uint16:
583 std::copy(src, src + size(), data<uint16_t>());
584 break;
585 case uint32:
586 std::copy(src, src + size(), data<uint32_t>());
587 break;
588 case uint64:
589 std::copy(src, src + size(), data<uint64_t>());
590 break;
591 case int8:
592 std::copy(src, src + size(), data<int8_t>());
593 break;
594 case int16:
595 std::copy(src, src + size(), data<int16_t>());
596 break;
597 case int32:
598 std::copy(src, src + size(), data<int32_t>());
599 break;
600 case int64:
601 std::copy(src, src + size(), data<int64_t>());
602 break;
603 case float16:
604 std::copy(src, src + size(), data<float16_t>());
605 break;
606 case float32:
607 std::copy(src, src + size(), data<float>());
608 break;
609 case float64:
610 std::copy(src, src + size(), data<double>());
611 break;
612 case bfloat16:
613 std::copy(src, src + size(), data<bfloat16_t>());
614 break;
615 case complex64:
616 std::copy(src, src + size(), data<complex64_t>());
617 break;
618 }
619}
620
621/* Utilities for determining whether a template parameter is array. */
622template <typename T>
623inline constexpr bool is_array_v =
624 std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
625
626template <typename... T>
627inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
628
629template <typename... T>
630using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
631
632} // namespace mlx::core
Definition event.h:11
Definition primitives.h:48
virtual size_t size(Buffer buffer) const =0
Definition allocator.h:12
Definition array.h:24
void attach_event(Event e) const
Definition array.h:405
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:318
Event & event() const
Definition array.h:400
Status
Definition array.h:363
@ available
Definition array.h:381
@ evaluated
Definition array.h:376
@ unscheduled
Definition array.h:366
@ scheduled
Definition array.h:371
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
array(allocator::Buffer data, Shape shape, Dtype dtype, Strides strides, size_t data_size, Flags flags, Deleter deleter=allocator::free)
Build an array from all the info held by the array description.
void eval()
Evaluate the array.
const Strides & strides() const
The strides of the array.
Definition array.h:117
const std::vector< array > & inputs() const
The array's inputs.
Definition array.h:274
array(const array &other)=default
std::vector< array > outputs() const
The outputs of the array's primitive (i.e.
Definition array.h:304
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
void move_shared_buffer(array other)
static std::vector< array > make_arrays(std::vector< Shape > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
array(std::initializer_list< float > data)
bool is_donatable() const
True indicates the arrays buffer is safe to reuse.
Definition array.h:283
array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter=allocator::free)
std::shared_ptr< Primitive > & primitive_ptr() const
A shared pointer to the array's primitive.
Definition array.h:264
size_t ndim() const
The number of dimensions of the array.
Definition array.h:98
size_t size() const
The number of elements in the array.
Definition array.h:88
array & operator=(array &&other) &&=delete
array & operator=(const array &other) &
Definition array.h:75
ArrayIterator end() const
Definition array.h:180
array(std::initializer_list< int > data, Dtype dtype)
void set_data(allocator::Buffer buffer, size_t data_size, Strides strides, Flags flags, Deleter d=allocator::free)
const allocator::Buffer & buffer() const
Definition array.h:339
void set_status(Status s) const
Definition array.h:395
array(const std::complex< float > &val, Dtype dtype=complex64)
Status status() const
Definition array.h:391
std::vector< array > & siblings()
The array's siblings.
Definition array.h:293
T * data()
Definition array.h:354
array(T val, Dtype dtype=TypeToDtype< T >())
Construct a scalar array with zero dimensions.
Definition array.h:513
ArrayIterator begin() const
Definition array.h:177
Primitive & primitive() const
The array's primitive.
Definition array.h:259
void detach()
Detach the array from the graph.
array & operator=(const array &other) &&=delete
Assignment to rvalue does not compile.
void set_siblings(std::vector< array > siblings, uint16_t position)
Definition array.h:297
T item()
Get the value from a scalar array.
Definition array.h:551
size_t buffer_size() const
Definition array.h:343
void copy_shared_buffer(const array &other)
void overwrite_descriptor(const array &other)
Definition array.h:443
const T * data() const
Definition array.h:359
bool has_primitive() const
Check if the array has an attached primitive or is a leaf node.
Definition array.h:269
allocator::Buffer & buffer()
Definition array.h:336
array(array &&other)=default
std::shared_ptr< Data > data_shared_ptr() const
Definition array.h:349
array(Shape shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
The following methods should be used with caution.
auto shape(int dim) const
Get the size of the corresponding dimension.
Definition array.h:112
auto strides(int dim) const
Get the stride of the corresponding dimension.
Definition array.h:126
const std::vector< array > & siblings() const
The array's siblings.
Definition array.h:288
std::vector< array > & inputs()
Definition array.h:278
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
array & operator=(array &&other) &=default
Default copy and move constructors otherwise.
void move_shared_buffer(array other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
std::uintptr_t id() const
A unique identifier for an array.
Definition array.h:203
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
bool is_available() const
void set_tracer(bool is_tracer)
Definition array.h:410
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
std::uintptr_t primitive_id() const
A unique identifier for an arrays primitive.
Definition array.h:208
bool is_tracer() const
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc(size_t size)
void free(Buffer buffer)
Allocator & allocator()
Definition allocator.h:7
constexpr bool is_array_v
Definition array.h:623
constexpr Dtype bool_
Definition dtype.h:68
int32_t ShapeElem
Definition array.h:20
constexpr Dtype uint64
Definition dtype.h:73
constexpr Dtype uint16
Definition dtype.h:71
constexpr Dtype float64
Definition dtype.h:82
constexpr Dtype bfloat16
Definition dtype.h:83
constexpr Dtype int32
Definition dtype.h:77
constexpr Dtype float32
Definition dtype.h:81
std::vector< ShapeElem > Shape
Definition array.h:21
constexpr Dtype int16
Definition dtype.h:76
std::vector< int64_t > Strides
Definition array.h:22
constexpr Dtype int8
Definition dtype.h:75
constexpr Dtype int64
Definition dtype.h:78
constexpr bool is_arrays_v
Definition array.h:627
constexpr Dtype uint8
Definition dtype.h:70
constexpr Dtype float16
Definition dtype.h:80
constexpr Dtype uint32
Definition dtype.h:72
uint8_t size_of(const Dtype &t)
Definition dtype.h:104
std::function< void(allocator::Buffer)> Deleter
Definition array.h:19
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:630
constexpr Dtype complex64
Definition dtype.h:84
Definition dtype.h:13
Definition dtype.h:111
Definition array.h:145
friend bool operator==(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:165
std::random_access_iterator_tag iterator_category
Definition array.h:146
ArrayIterator & operator++()
Definition array.h:160
value_type reference
Definition array.h:149
friend bool operator!=(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:168
ArrayIterator(const array &arr, int idx=0)
size_t difference_type
Definition array.h:147
const array value_type
Definition array.h:148
ArrayIterator & operator+(difference_type diff)
Definition array.h:155
Deleter d
Definition array.h:214
Data(allocator::Buffer buffer, Deleter d=allocator::free)
Definition array.h:215
~Data()
Definition array.h:220
Data(const Data &d)=delete
Data & operator=(const Data &d)=delete
allocator::Buffer buffer
Definition array.h:213
Definition array.h:225
bool row_contiguous
Definition array.h:237
bool col_contiguous
Definition array.h:243
bool contiguous
Definition array.h:231