MLX
Loading...
Searching...
No Matches
array.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2#pragma once
3
4#include <algorithm>
5#include <cstdint>
6#include <functional>
7#include <memory>
8#include <vector>
9
10#include "mlx/allocator.h"
11#include "mlx/dtype.h"
12#include "mlx/event.h"
13
14namespace mlx::core {
15
16// Forward declaration
17class Primitive;
18using deleter_t = std::function<void(allocator::Buffer)>;
19
20class array {
21 /* An array is really a node in a graph. It contains a shared ArrayDesc
22 * object */
23
24 public:
26 template <typename T>
27 explicit array(T val, Dtype dtype = TypeToDtype<T>());
28
29 /* Special case since std::complex can't be implicitly converted to other
30 * types. */
31 explicit array(const std::complex<float>& val, Dtype dtype = complex64);
32
33 template <typename It>
34 array(
35 It data,
36 std::vector<int> shape,
37 Dtype dtype =
38 TypeToDtype<typename std::iterator_traits<It>::value_type>());
39
40 template <typename T>
41 array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
42
43 /* Special case so empty lists default to float32. */
44 array(std::initializer_list<float> data);
45
46 /* Special case so array({}, type) is an empty array. */
47 array(std::initializer_list<int> data, Dtype dtype);
48
49 template <typename T>
50 array(
51 std::initializer_list<T> data,
52 std::vector<int> shape,
54
55 /* Build an array from a buffer */
58 std::vector<int> shape,
60 deleter_t deleter = allocator::free);
61
63 array& operator=(const array& other) && = delete;
64 array& operator=(array&& other) && = delete;
65
67 array& operator=(array&& other) & = default;
68 array(const array& other) = default;
69 array(array&& other) = default;
70
71 array& operator=(const array& other) & {
72 if (this->id() != other.id()) {
73 this->array_desc_ = other.array_desc_;
74 }
75 return *this;
76 }
77
79 size_t itemsize() const {
80 return size_of(dtype());
81 }
82
84 size_t size() const {
85 return array_desc_->size;
86 }
87
89 size_t nbytes() const {
90 return size() * itemsize();
91 }
92
94 size_t ndim() const {
95 return array_desc_->shape.size();
96 }
97
99 const std::vector<int>& shape() const {
100 return array_desc_->shape;
101 }
102
108 int shape(int dim) const {
109 return shape().at(dim < 0 ? dim + ndim() : dim);
110 }
111
113 const std::vector<size_t>& strides() const {
114 return array_desc_->strides;
115 }
116
122 size_t strides(int dim) const {
123 return strides().at(dim < 0 ? dim + ndim() : dim);
124 }
125
127 Dtype dtype() const {
128 return array_desc_->dtype;
129 }
130
132 void eval();
133
135 template <typename T>
136 T item();
137
138 template <typename T>
139 T item() const;
140
142 using iterator_category = std::random_access_iterator_tag;
143 using difference_type = size_t;
144 using value_type = const array;
146
147 explicit ArrayIterator(const array& arr, int idx = 0);
148
150
152 idx += diff;
153 return *this;
154 }
155
157 idx++;
158 return *this;
159 }
160
161 friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
162 return a.arr.id() == b.arr.id() && a.idx == b.idx;
163 }
164 friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
165 return !(a == b);
166 }
167
168 private:
169 const array& arr;
170 int idx;
171 };
172
174 return ArrayIterator(*this);
175 }
177 return ArrayIterator(*this, shape(0));
178 }
179
187 std::vector<int> shape,
188 Dtype dtype,
189 std::shared_ptr<Primitive> primitive,
190 std::vector<array> inputs);
191
192 static std::vector<array> make_arrays(
193 std::vector<std::vector<int>> shapes,
194 const std::vector<Dtype>& dtypes,
195 const std::shared_ptr<Primitive>& primitive,
196 const std::vector<array>& inputs);
197
199 std::uintptr_t id() const {
200 return reinterpret_cast<std::uintptr_t>(array_desc_.get());
201 }
202
204 std::uintptr_t primitive_id() const {
205 return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
206 }
207
208 struct Data {
213 // Not copyable
214 Data(const Data& d) = delete;
215 Data& operator=(const Data& d) = delete;
217 d(buffer);
218 }
219 };
220
221 struct Flags {
222 // True iff there are no gaps in the underlying data. Each item
223 // in the underlying data buffer belongs to at least one index.
224 //
225 // True iff:
226 // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
227 bool contiguous : 1;
228
229 // True iff:
230 // strides[-1] == 1 and
231 // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
232 // range(ndim - 1))
234
235 // True iff:
236 // strides[0] == 1 and
237 // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
238 // range(1, ndim))
240 };
241
244 return *(array_desc_->primitive);
245 }
246
248 std::shared_ptr<Primitive>& primitive_ptr() const {
249 return array_desc_->primitive;
250 }
251
253 bool has_primitive() const {
254 return array_desc_->primitive != nullptr;
255 }
256
258 const std::vector<array>& inputs() const {
259 return array_desc_->inputs;
260 }
261
262 std::vector<array>& inputs() {
263 return array_desc_->inputs;
264 }
265
267 bool is_donatable() const {
268 return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
269 }
270
272 const std::vector<array>& siblings() const {
273 return array_desc_->siblings;
274 }
275
277 std::vector<array>& siblings() {
278 return array_desc_->siblings;
279 }
280
281 void set_siblings(std::vector<array> siblings, uint16_t position) {
282 array_desc_->siblings = std::move(siblings);
283 array_desc_->position = position;
284 }
285
288 std::vector<array> outputs() const {
289 auto idx = array_desc_->position;
290 std::vector<array> outputs;
291 outputs.reserve(siblings().size() + 1);
292 outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
293 outputs.push_back(*this);
294 outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
295 return outputs;
296 }
297
299 void detach();
300
302 const Flags& flags() const {
303 return array_desc_->flags;
304 }
305
316 size_t data_size() const {
317 return array_desc_->data_size;
318 }
319
321 return array_desc_->data->buffer;
322 }
323 const allocator::Buffer& buffer() const {
324 return array_desc_->data->buffer;
325 }
326
327 size_t buffer_size() const {
328 return allocator::allocator().size(buffer());
329 }
330
331 // Return a copy of the shared pointer
332 // to the array::Data struct
333 std::shared_ptr<Data> data_shared_ptr() const {
334 return array_desc_->data;
335 }
336 // Return a raw pointer to the arrays data
337 template <typename T>
338 T* data() {
339 return static_cast<T*>(array_desc_->data_ptr);
340 }
341
342 template <typename T>
343 const T* data() const {
344 return static_cast<T*>(array_desc_->data_ptr);
345 }
346
348
349 bool is_available() const {
350 return status() == Status::available;
351 }
352
353 Status status() const {
354 return array_desc_->status;
355 }
356
357 void set_status(Status s) const {
358 array_desc_->status = s;
359 }
360
361 // Get the array's shared event
362 Event& event() const {
363 return array_desc_->event;
364 }
365
366 // Attach an event to a not yet evaluated array
367 void attach_event(Event e) const {
368 array_desc_->event = std::move(e);
369 }
370
371 // Mark the array as a tracer array (true) or not.
373 array_desc_->is_tracer = is_tracer;
374 }
375 // Check if the array is a tracer array
376 bool is_tracer() const;
377
379
382 size_t data_size,
383 std::vector<size_t> strides,
384 Flags flags,
386
388 const array& other,
389 const std::vector<size_t>& strides,
390 Flags flags,
391 size_t data_size,
392 size_t offset = 0);
393
394 void copy_shared_buffer(const array& other);
395
397 array other,
398 const std::vector<size_t>& strides,
399 Flags flags,
400 size_t data_size,
401 size_t offset = 0);
402
404
405 void overwrite_descriptor(const array& other) {
406 array_desc_ = other.array_desc_;
407 }
408
410
411 private:
412 // Initialize the arrays data
413 template <typename It>
414 void init(const It src);
415
416 struct ArrayDesc {
417 std::vector<int> shape;
418 std::vector<size_t> strides;
419 size_t size;
420 Dtype dtype;
421 std::shared_ptr<Primitive> primitive;
422
423 Status status;
424
425 // An event on the array used for synchronization
426 Event event;
427
428 // Indicates an array is being used in a graph transform
429 // and should not be detached from the graph
430 bool is_tracer{false};
431
432 // This is a shared pointer so that *different* arrays
433 // can share the underlying data buffer.
434 std::shared_ptr<Data> data;
435
436 // Properly offset data pointer
437 void* data_ptr{nullptr};
438
439 // The size in elements of the data buffer the array accesses
440 size_t data_size;
441
442 // Contains useful meta data about the array
443 Flags flags;
444
445 std::vector<array> inputs;
446 // An array to keep track of the siblings from a multi-output
447 // primitive.
448 std::vector<array> siblings;
449 // The arrays position in the output list
450 uint32_t position{0};
451
452 explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
453
454 explicit ArrayDesc(
455 std::vector<int> shape,
456 Dtype dtype,
457 std::shared_ptr<Primitive> primitive,
458 std::vector<array> inputs);
459
460 ~ArrayDesc();
461
462 private:
463 // Initialize size, strides, and other metadata
464 void init();
465 };
466
467 // The ArrayDesc contains the details of the materialized array including the
468 // shape, strides, the data type. It also includes
469 // the primitive which knows how to compute the array's data from its inputs
470 // and the list of array's inputs for the primitive.
471 std::shared_ptr<ArrayDesc> array_desc_;
472};
473
474template <typename T>
475array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
476 : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
477 init(&val);
478}
479
480template <typename It>
482 It data,
483 std::vector<int> shape,
484 Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
485 array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
486 init(data);
487}
488
489template <typename T>
491 std::initializer_list<T> data,
492 Dtype dtype /* = TypeToDtype<T>() */)
493 : array_desc_(std::make_shared<ArrayDesc>(
494 std::vector<int>{static_cast<int>(data.size())},
495 dtype)) {
496 init(data.begin());
497}
498
499template <typename T>
501 std::initializer_list<T> data,
502 std::vector<int> shape,
503 Dtype dtype /* = TypeToDtype<T>() */)
504 : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
505 if (data.size() != size()) {
506 throw std::invalid_argument(
507 "Data size and provided shape mismatch in array construction.");
508 }
509 init(data.begin());
510}
511
512template <typename T>
514 if (size() != 1) {
515 throw std::invalid_argument("item can only be called on arrays of size 1.");
516 }
517 eval();
518 return *data<T>();
519}
520
521template <typename T>
522T array::item() const {
523 if (size() != 1) {
524 throw std::invalid_argument("item can only be called on arrays of size 1.");
525 }
526 if (status() == Status::unscheduled) {
527 throw std::invalid_argument(
528 "item() const can only be called on evaled arrays");
529 }
530 const_cast<array*>(this)->eval();
531 return *data<T>();
532}
533
534template <typename It>
535void array::init(It src) {
537 switch (dtype()) {
538 case bool_:
539 std::copy(src, src + size(), data<bool>());
540 break;
541 case uint8:
542 std::copy(src, src + size(), data<uint8_t>());
543 break;
544 case uint16:
545 std::copy(src, src + size(), data<uint16_t>());
546 break;
547 case uint32:
548 std::copy(src, src + size(), data<uint32_t>());
549 break;
550 case uint64:
551 std::copy(src, src + size(), data<uint64_t>());
552 break;
553 case int8:
554 std::copy(src, src + size(), data<int8_t>());
555 break;
556 case int16:
557 std::copy(src, src + size(), data<int16_t>());
558 break;
559 case int32:
560 std::copy(src, src + size(), data<int32_t>());
561 break;
562 case int64:
563 std::copy(src, src + size(), data<int64_t>());
564 break;
565 case float16:
566 std::copy(src, src + size(), data<float16_t>());
567 break;
568 case float32:
569 std::copy(src, src + size(), data<float>());
570 break;
571 case bfloat16:
572 std::copy(src, src + size(), data<bfloat16_t>());
573 break;
574 case complex64:
575 std::copy(src, src + size(), data<complex64_t>());
576 break;
577 }
578}
579
580/* Utilities for determining whether a template parameter is array. */
581template <typename T>
582inline constexpr bool is_array_v =
583 std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
584
585template <typename... T>
586inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
587
588template <typename... T>
589using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
590
591} // namespace mlx::core
Definition event.h:11
Definition primitives.h:48
virtual size_t size(Buffer buffer) const =0
Definition allocator.h:12
Definition array.h:20
void attach_event(Event e) const
Definition array.h:367
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:302
Event & event() const
Definition array.h:362
static std::vector< array > make_arrays(std::vector< std::vector< int > > shapes, const std::vector< Dtype > &dtypes, const std::shared_ptr< Primitive > &primitive, const std::vector< array > &inputs)
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
Status
Definition array.h:347
@ available
Definition array.h:347
@ unscheduled
Definition array.h:347
@ scheduled
Definition array.h:347
void set_data(allocator::Buffer buffer, size_t data_size, std::vector< size_t > strides, Flags flags, deleter_t d=allocator::free)
void eval()
Evaluate the array.
void copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
const std::vector< array > & inputs() const
The array's inputs.
Definition array.h:258
array(const array &other)=default
std::vector< array > outputs() const
The outputs of the array's primitive (i.e. this array together with its siblings, in output order).
Definition array.h:288
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
void move_shared_buffer(array other)
array(std::initializer_list< float > data)
bool is_donatable() const
True indicates the array's buffer is safe to reuse.
Definition array.h:267
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
std::shared_ptr< Primitive > & primitive_ptr() const
A shared pointer to the array's primitive.
Definition array.h:248
int shape(int dim) const
Get the size of the corresponding dimension.
Definition array.h:108
size_t ndim() const
The number of dimensions of the array.
Definition array.h:94
size_t size() const
The number of elements in the array.
Definition array.h:84
array(allocator::Buffer data, std::vector< int > shape, Dtype dtype, deleter_t deleter=allocator::free)
array & operator=(array &&other) &&=delete
array & operator=(const array &other) &
Definition array.h:71
ArrayIterator end() const
Definition array.h:176
array(std::initializer_list< int > data, Dtype dtype)
void set_data(allocator::Buffer buffer, deleter_t d=allocator::free)
const allocator::Buffer & buffer() const
Definition array.h:323
void set_status(Status s) const
Definition array.h:357
array(const std::complex< float > &val, Dtype dtype=complex64)
Status status() const
Definition array.h:353
std::vector< array > & siblings()
The array's siblings.
Definition array.h:277
T * data()
Definition array.h:338
array(T val, Dtype dtype=TypeToDtype< T >())
Construct a scalar array with zero dimensions.
Definition array.h:475
ArrayIterator begin() const
Definition array.h:173
Primitive & primitive() const
The array's primitive.
Definition array.h:243
void detach()
Detach the array from the graph.
array & operator=(const array &other) &&=delete
Assignment to rvalue does not compile.
void set_siblings(std::vector< array > siblings, uint16_t position)
Definition array.h:281
T item()
Get the value from a scalar array.
Definition array.h:513
size_t buffer_size() const
Definition array.h:327
size_t strides(int dim) const
Get the stride of the corresponding dimension.
Definition array.h:122
void copy_shared_buffer(const array &other)
void overwrite_descriptor(const array &other)
Definition array.h:405
const T * data() const
Definition array.h:343
bool has_primitive() const
Check if the array has an attached primitive or is a leaf node.
Definition array.h:253
allocator::Buffer & buffer()
Definition array.h:320
array(array &&other)=default
std::shared_ptr< Data > data_shared_ptr() const
Definition array.h:333
void move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)
const std::vector< array > & siblings() const
The array's siblings.
Definition array.h:272
std::vector< array > & inputs()
Definition array.h:262
array & operator=(array &&other) &=default
Default copy and move constructors otherwise.
array(std::vector< int > shape, Dtype dtype, std::shared_ptr< Primitive > primitive, std::vector< array > inputs)
The following methods should be used with caution.
std::uintptr_t id() const
A unique identifier for an array.
Definition array.h:199
Dtype dtype() const
Get the array's data type.
Definition array.h:127
bool is_available() const
Definition array.h:349
void set_tracer(bool is_tracer)
Definition array.h:372
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:79
std::uintptr_t primitive_id() const
A unique identifier for an array's primitive.
Definition array.h:204
bool is_tracer() const
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:316
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
Buffer malloc(size_t size)
void free(Buffer buffer)
Allocator & allocator()
Definition allocator.h:7
constexpr bool is_array_v
Definition array.h:582
constexpr Dtype bool_
Definition dtype.h:58
std::function< void(allocator::Buffer)> deleter_t
Definition array.h:18
constexpr Dtype uint64
Definition dtype.h:63
constexpr Dtype uint16
Definition dtype.h:61
constexpr Dtype bfloat16
Definition dtype.h:72
constexpr Dtype int32
Definition dtype.h:67
constexpr Dtype float32
Definition dtype.h:71
constexpr Dtype int16
Definition dtype.h:66
constexpr Dtype int8
Definition dtype.h:65
constexpr Dtype int64
Definition dtype.h:68
constexpr bool is_arrays_v
Definition array.h:586
constexpr Dtype uint8
Definition dtype.h:60
constexpr Dtype float16
Definition dtype.h:70
constexpr Dtype uint32
Definition dtype.h:62
uint8_t size_of(const Dtype &t)
Definition dtype.h:93
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:589
constexpr Dtype complex64
Definition dtype.h:73
Definition dtype.h:13
Definition dtype.h:100
Definition array.h:141
friend bool operator==(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:161
std::random_access_iterator_tag iterator_category
Definition array.h:142
ArrayIterator & operator++()
Definition array.h:156
friend bool operator!=(const ArrayIterator &a, const ArrayIterator &b)
Definition array.h:164
ArrayIterator(const array &arr, int idx=0)
size_t difference_type
Definition array.h:143
const array value_type
Definition array.h:144
ArrayIterator & operator+(difference_type diff)
Definition array.h:151
Definition array.h:208
~Data()
Definition array.h:216
deleter_t d
Definition array.h:210
Data(const Data &d)=delete
Data & operator=(const Data &d)=delete
Data(allocator::Buffer buffer, deleter_t d=allocator::free)
Definition array.h:211
allocator::Buffer buffer
Definition array.h:209
Definition array.h:221
bool row_contiguous
Definition array.h:233
bool col_contiguous
Definition array.h:239
bool contiguous
Definition array.h:227