mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-27 19:31:16 +08:00
437 lines
11 KiB
C++
437 lines
11 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#pragma once
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "mlx/allocator.h"
|
|
#include "mlx/dtype.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
// Forward declaration
|
|
class Primitive;
|
|
using deleter_t = std::function<void(allocator::Buffer)>;
|
|
|
|
class array {
|
|
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
|
* object */
|
|
|
|
public:
|
|
/** Construct a scalar array with zero dimensions. */
|
|
template <typename T>
|
|
explicit array(T val, Dtype dtype = TypeToDtype<T>());
|
|
|
|
/* Special case since std::complex can't be implicitly converted to other
|
|
* types. */
|
|
explicit array(const std::complex<float>& val, Dtype dtype = complex64);
|
|
|
|
template <typename It>
|
|
array(
|
|
It data,
|
|
const std::vector<int>& shape,
|
|
Dtype dtype =
|
|
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
|
|
|
template <typename T>
|
|
array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
|
|
|
|
/* Special case so empty lists default to float32. */
|
|
array(std::initializer_list<float> data);
|
|
|
|
template <typename T>
|
|
array(
|
|
std::initializer_list<T> data,
|
|
const std::vector<int>& shape,
|
|
Dtype dtype = TypeToDtype<T>());
|
|
|
|
/* Build an array from a buffer */
|
|
array(
|
|
allocator::Buffer data,
|
|
const std::vector<int>& shape,
|
|
Dtype dtype,
|
|
deleter_t deleter = allocator::free);
|
|
|
|
/** Assignment to rvalue does not compile. */
|
|
array& operator=(const array& other) && = delete;
|
|
array& operator=(array&& other) && = delete;
|
|
|
|
/** Default copy and move constructors otherwise. */
|
|
array& operator=(array&& other) & = default;
|
|
array(const array& other) = default;
|
|
array(array&& other) = default;
|
|
|
|
array& operator=(const array& other) & {
|
|
if (this->id() != other.id()) {
|
|
this->array_desc_ = other.array_desc_;
|
|
}
|
|
return *this;
|
|
};
|
|
|
|
/** The size of the array's datatype in bytes. */
|
|
size_t itemsize() const {
|
|
return size_of(dtype());
|
|
};
|
|
|
|
/** The number of elements in the array. */
|
|
size_t size() const {
|
|
return array_desc_->size;
|
|
};
|
|
|
|
/** The number of bytes in the array. */
|
|
size_t nbytes() const {
|
|
return size() * itemsize();
|
|
};
|
|
|
|
/** The number of dimensions of the array. */
|
|
size_t ndim() const {
|
|
return array_desc_->shape.size();
|
|
};
|
|
|
|
/** The shape of the array as a vector of integers. */
|
|
const std::vector<int>& shape() const {
|
|
return array_desc_->shape;
|
|
};
|
|
|
|
/**
|
|
* Get the size of the corresponding dimension.
|
|
*
|
|
* This function supports negative indexing and provides
|
|
* bounds checking. */
|
|
int shape(int dim) const {
|
|
return shape().at(dim < 0 ? dim + ndim() : dim);
|
|
};
|
|
|
|
/** The strides of the array. */
|
|
const std::vector<size_t>& strides() const {
|
|
return array_desc_->strides;
|
|
};
|
|
|
|
/** Get the arrays data type. */
|
|
Dtype dtype() const {
|
|
return array_desc_->dtype;
|
|
};
|
|
|
|
/** Evaluate the array. */
|
|
void eval();
|
|
|
|
/** Get the value from a scalar array. */
|
|
template <typename T>
|
|
T item();
|
|
|
|
struct ArrayIterator {
|
|
using iterator_category = std::random_access_iterator_tag;
|
|
using difference_type = size_t;
|
|
using value_type = const array;
|
|
using reference = value_type;
|
|
|
|
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
|
|
if (arr.ndim() == 0) {
|
|
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
|
}
|
|
}
|
|
|
|
reference operator*() const;
|
|
|
|
ArrayIterator& operator+(difference_type diff) {
|
|
idx += diff;
|
|
return *this;
|
|
}
|
|
|
|
ArrayIterator& operator++() {
|
|
idx++;
|
|
return *this;
|
|
}
|
|
|
|
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
|
|
return a.arr.id() == b.arr.id() && a.idx == b.idx;
|
|
};
|
|
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
|
|
return !(a == b);
|
|
};
|
|
|
|
private:
|
|
const array& arr;
|
|
int idx;
|
|
};
|
|
|
|
ArrayIterator begin() const {
|
|
return ArrayIterator(*this);
|
|
}
|
|
ArrayIterator end() const {
|
|
return ArrayIterator(*this, shape(0));
|
|
}
|
|
|
|
/**
|
|
* The following methods should be used with caution.
|
|
* They are intended for use by the backend implementation and the
|
|
* API may change.
|
|
*/
|
|
|
|
array(
|
|
const std::vector<int>& shape,
|
|
Dtype dtype,
|
|
std::unique_ptr<Primitive> primitive,
|
|
const std::vector<array>& inputs);
|
|
|
|
/** A unique identifier for an array. */
|
|
std::uintptr_t id() const {
|
|
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
|
}
|
|
|
|
struct Data {
|
|
allocator::Buffer buffer;
|
|
deleter_t d;
|
|
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
|
: buffer(buffer), d(d){};
|
|
// Not copyable
|
|
Data(const Data& d) = delete;
|
|
Data& operator=(const Data& d) = delete;
|
|
~Data() {
|
|
d(buffer);
|
|
}
|
|
};
|
|
|
|
struct Flags {
|
|
// True if there are no gaps in the underlying data. Each item
|
|
// in the underlying data buffer belongs to at least one index.
|
|
bool contiguous : 1;
|
|
|
|
bool row_contiguous : 1;
|
|
bool col_contiguous : 1;
|
|
};
|
|
|
|
/** The array's primitive. */
|
|
Primitive& primitive() const {
|
|
return *(array_desc_->primitive);
|
|
};
|
|
|
|
/** Check if the array has an attached primitive or is a leaf node. */
|
|
bool has_primitive() const {
|
|
return array_desc_->primitive != nullptr;
|
|
};
|
|
|
|
/** The array's inputs. */
|
|
const std::vector<array>& inputs() const {
|
|
return array_desc_->inputs;
|
|
};
|
|
|
|
/** A non-const reference to the array's inputs so that they can be used to
|
|
* edit the graph. */
|
|
std::vector<array>& editable_inputs() {
|
|
return array_desc_->inputs;
|
|
}
|
|
|
|
/** Detach the array from the graph. */
|
|
void detach();
|
|
|
|
/** Get the Flags bit-field. */
|
|
const Flags& flags() const {
|
|
return array_desc_->flags;
|
|
};
|
|
|
|
/** The size (in elements) of the underlying buffer the array points to. */
|
|
size_t data_size() const {
|
|
return array_desc_->data_size;
|
|
};
|
|
|
|
allocator::Buffer& buffer() {
|
|
return array_desc_->data->buffer;
|
|
};
|
|
const allocator::Buffer& buffer() const {
|
|
return array_desc_->data->buffer;
|
|
};
|
|
|
|
template <typename T>
|
|
T* data() {
|
|
return static_cast<T*>(array_desc_->data_ptr);
|
|
};
|
|
|
|
template <typename T>
|
|
const T* data() const {
|
|
return static_cast<T*>(array_desc_->data_ptr);
|
|
};
|
|
|
|
// Check if the array has been evaluated
|
|
bool is_evaled() const {
|
|
return array_desc_->data != nullptr;
|
|
}
|
|
|
|
// Mark the array as a tracer array (true) or not.
|
|
void set_tracer(bool is_tracer) {
|
|
array_desc_->is_tracer = is_tracer;
|
|
}
|
|
// Check if the array is a tracer array
|
|
bool is_tracer() const;
|
|
|
|
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
|
|
|
void set_data(
|
|
allocator::Buffer buffer,
|
|
size_t data_size,
|
|
std::vector<size_t> strides,
|
|
Flags flags,
|
|
deleter_t d = allocator::free);
|
|
|
|
void copy_shared_buffer(
|
|
const array& other,
|
|
const std::vector<size_t>& strides,
|
|
Flags flags,
|
|
size_t data_size,
|
|
size_t offset = 0);
|
|
|
|
void copy_shared_buffer(const array& other);
|
|
|
|
void overwrite_descriptor(const array& other) {
|
|
array_desc_ = other.array_desc_;
|
|
}
|
|
|
|
private:
|
|
// Initialize the arrays data
|
|
template <typename It>
|
|
void init(const It src);
|
|
|
|
struct ArrayDesc {
|
|
std::vector<int> shape;
|
|
std::vector<size_t> strides;
|
|
size_t size;
|
|
Dtype dtype;
|
|
std::unique_ptr<Primitive> primitive{nullptr};
|
|
|
|
// Indicates an array is being used in a graph transform
|
|
// and should not be detached from the graph
|
|
bool is_tracer{false};
|
|
|
|
// This is a shared pointer so that *different* arrays
|
|
// can share the underlying data buffer.
|
|
std::shared_ptr<Data> data{nullptr};
|
|
|
|
// Properly offset data pointer
|
|
void* data_ptr{nullptr};
|
|
|
|
// The size in elements of the data buffer the array accesses
|
|
// This can be different than the actual size of the array if it
|
|
// has been broadcast or irregularly strided.
|
|
size_t data_size;
|
|
|
|
// Contains useful meta data about the array
|
|
Flags flags;
|
|
|
|
std::vector<array> inputs;
|
|
|
|
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
|
|
|
explicit ArrayDesc(
|
|
const std::vector<int>& shape,
|
|
Dtype dtype,
|
|
std::unique_ptr<Primitive> primitive,
|
|
const std::vector<array>& inputs);
|
|
|
|
~ArrayDesc();
|
|
};
|
|
|
|
// The ArrayDesc contains the details of the materialized array including the
|
|
// shape, strides, the data type. It also includes
|
|
// the primitive which knows how to compute the array's data from its inputs
|
|
// and a the list of array's inputs for the primitive.
|
|
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
|
};
|
|
|
|
template <typename T>
|
|
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
|
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
|
init(&val);
|
|
}
|
|
|
|
template <typename It>
|
|
array::array(
|
|
It data,
|
|
const std::vector<int>& shape,
|
|
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
|
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
|
init(data);
|
|
}
|
|
|
|
template <typename T>
|
|
array::array(
|
|
std::initializer_list<T> data,
|
|
Dtype dtype /* = TypeToDtype<T>() */)
|
|
: array_desc_(std::make_shared<ArrayDesc>(
|
|
std::vector<int>{static_cast<int>(data.size())},
|
|
dtype)) {
|
|
init(data.begin());
|
|
}
|
|
|
|
template <typename T>
|
|
array::array(
|
|
std::initializer_list<T> data,
|
|
const std::vector<int>& shape,
|
|
Dtype dtype /* = TypeToDtype<T>() */)
|
|
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
|
if (data.size() != size()) {
|
|
throw std::invalid_argument(
|
|
"Data size and provided shape mismatch in array construction.");
|
|
}
|
|
init(data.begin());
|
|
}
|
|
|
|
template <typename T>
|
|
T array::item() {
|
|
if (size() != 1) {
|
|
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
|
}
|
|
eval();
|
|
return *data<T>();
|
|
}
|
|
|
|
template <typename It>
|
|
void array::init(It src) {
|
|
set_data(allocator::malloc(size() * size_of(dtype())));
|
|
switch (dtype()) {
|
|
case bool_:
|
|
std::copy(src, src + size(), data<bool>());
|
|
break;
|
|
case uint8:
|
|
std::copy(src, src + size(), data<uint8_t>());
|
|
break;
|
|
case uint16:
|
|
std::copy(src, src + size(), data<uint16_t>());
|
|
break;
|
|
case uint32:
|
|
std::copy(src, src + size(), data<uint32_t>());
|
|
break;
|
|
case uint64:
|
|
std::copy(src, src + size(), data<uint64_t>());
|
|
break;
|
|
case int8:
|
|
std::copy(src, src + size(), data<int8_t>());
|
|
break;
|
|
case int16:
|
|
std::copy(src, src + size(), data<int16_t>());
|
|
break;
|
|
case int32:
|
|
std::copy(src, src + size(), data<int32_t>());
|
|
break;
|
|
case int64:
|
|
std::copy(src, src + size(), data<int64_t>());
|
|
break;
|
|
case float16:
|
|
std::copy(src, src + size(), data<float16_t>());
|
|
break;
|
|
case float32:
|
|
std::copy(src, src + size(), data<float>());
|
|
break;
|
|
case bfloat16:
|
|
std::copy(src, src + size(), data<bfloat16_t>());
|
|
break;
|
|
case complex64:
|
|
std::copy(src, src + size(), data<complex64_t>());
|
|
break;
|
|
}
|
|
}
|
|
|
|
} // namespace mlx::core
|