mlx/mlx/array.cpp
2025-02-12 10:52:30 -08:00

345 lines
9.7 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
// Construct a 0-d (scalar) array holding a single complex value.
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
  // Convert to MLX's complex64_t representation before handing the single
  // element to init().
  complex64_t converted = static_cast<complex64_t>(val);
  init(&converted);
}
// Construct a graph node produced by `primitive` from `inputs`.
// Throws std::invalid_argument if the primitive runs on the GPU and either
// any input or the output itself is float64.
array::array(
    Shape shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    std::vector<array> inputs)
    : array_desc_(std::make_shared<ArrayDesc>(
          std::move(shape),
          dtype,
          std::move(primitive),
          std::move(inputs))) {
  const bool runs_on_gpu =
      has_primitive() && this->primitive().stream().device == Device::gpu;
  if (!runs_on_gpu) {
    return;
  }
  // Inputs are checked before the output dtype.
  for (const auto& input : this->inputs()) {
    if (input.dtype() == float64) {
      throw std::invalid_argument("float64 is not supported on the GPU");
    }
  }
  if (this->dtype() == float64) {
    throw std::invalid_argument("float64 is not supported on the GPU");
  }
}
std::vector<array> array::make_arrays(
std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
}
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
}
return outputs;
}
// Construct a 1-D float32 array whose length equals the number of listed
// values.
array::array(std::initializer_list<float> data)
    : array_desc_(std::make_shared<ArrayDesc>(
          Shape{static_cast<ShapeElem>(data.size())},
          float32)) {
  const float* first = data.begin();
  init(first);
}
// Construct a 1-D array of `dtype` from a list of ints. The length equals
// the number of listed values; the ints are handed to init() for storage.
array::array(std::initializer_list<int> data, Dtype dtype)
    : array_desc_(std::make_shared<ArrayDesc>(
          Shape{static_cast<ShapeElem>(data.size())},
          dtype)) {
  const int* first = data.begin();
  init(first);
}
// Build an array on top of an existing allocator buffer. The buffer and
// `deleter` are handed to set_data(), which marks the layout as dense and
// contiguous for the given shape.
array::array(
    allocator::Buffer data,
    Shape shape,
    Dtype dtype,
    Deleter deleter)
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  set_data(data, deleter);
}
// Build an array on top of an existing allocator buffer with an explicit
// layout. The caller supplies strides, flags and data_size rather than
// having a dense row-major layout assumed.
array::array(
    allocator::Buffer data,
    Shape shape,
    Dtype dtype,
    Strides strides,
    size_t data_size,
    Flags flags,
    Deleter deleter)
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  set_data(data, data_size, std::move(strides), flags, deleter);
}
void array::detach() {
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->primitive = nullptr;
}
// True if the array's data can be read now. An `evaluated` array whose event
// has signaled is promoted to `available` as a side effect.
bool array::is_available() const {
  if (status() == Status::available) {
    return true;
  }
  const bool ready = status() == Status::evaluated && event().is_signaled();
  if (ready) {
    set_status(Status::available);
  }
  return ready;
}
// Block until the array's event completes (unless already available), then
// mark the array available.
void array::wait() {
  if (is_available()) {
    return;
  }
  event().wait();
  set_status(Status::available);
}
// Make the array ready to be read. Unscheduled arrays trigger a graph
// evaluation; anything already scheduled is simply waited on.
void array::eval() {
  if (status() != Status::unscheduled) {
    wait();
    return;
  }
  mlx::core::eval({*this});
}
// True if this array is a tracer in the current tracing context, or if the
// graph is being retained globally.
bool array::is_tracer() const {
  const bool traced = array_desc_->is_tracer && detail::in_tracing();
  return traced || detail::retain_graph();
}
// Attach `buffer` (released through `d`) as dense storage covering size()
// elements, and update the contiguity flags to match.
void array::set_data(allocator::Buffer buffer, Deleter d) {
  auto& desc = *array_desc_;
  desc.data = std::make_shared<Data>(buffer, d);
  desc.data_ptr = buffer.raw_ptr();
  desc.data_size = size();
  desc.flags.contiguous = true;
  desc.flags.row_contiguous = true;
  // Also column-contiguous when the total size equals the largest extent
  // (every other axis has length 1). For size() <= 1 the max_element lookup
  // is skipped, which also avoids dereferencing it for an empty shape.
  bool col_contiguous = size() <= 1;
  if (!col_contiguous) {
    col_contiguous =
        size() == *std::max_element(shape().begin(), shape().end());
  }
  desc.flags.col_contiguous = col_contiguous;
}
// Attach `buffer` (released through `d`) with a caller-provided layout:
// data_size, strides and flags are stored exactly as given.
void array::set_data(
    allocator::Buffer buffer,
    size_t data_size,
    Strides strides,
    Flags flags,
    Deleter d) {
  auto& desc = *array_desc_;
  desc.data = std::make_shared<Data>(buffer, d);
  desc.data_ptr = buffer.raw_ptr();
  desc.data_size = data_size;
  desc.strides = std::move(strides);
  desc.flags = flags;
}
// Share `other`'s underlying Data (the shared_ptr is copied; no element copy
// happens). This array views it with the given layout, starting `offset`
// elements (of this array's itemsize) past other's data pointer.
void array::copy_shared_buffer(
    const array& other,
    const Strides& strides,
    Flags flags,
    size_t data_size,
    size_t offset /* = 0 */) {
  auto& desc = *array_desc_;
  desc.data = other.array_desc_->data;
  desc.strides = strides;
  desc.flags = flags;
  desc.data_size = data_size;
  const size_t byte_offset = sizeof(char) * itemsize() * offset;
  char* base = static_cast<char*>(other.array_desc_->data_ptr);
  desc.data_ptr = static_cast<void*>(base + byte_offset);
}
// Share `other`'s buffer with exactly its layout and no offset.
void array::copy_shared_buffer(const array& other) {
  copy_shared_buffer(
      other, other.strides(), other.flags(), other.data_size(), /* offset */ 0);
}
// Take over `other`'s Data (moved out of other's descriptor) and view it
// with the given layout, starting `offset` elements (of this array's
// itemsize) past other's data pointer. Other's data_ptr is reset to null.
void array::move_shared_buffer(
    array other,
    const Strides& strides,
    Flags flags,
    size_t data_size,
    size_t offset /* = 0 */) {
  auto& desc = *array_desc_;
  desc.data = std::move(other.array_desc_->data);
  desc.strides = strides;
  desc.flags = flags;
  desc.data_size = data_size;
  const size_t byte_offset = sizeof(char) * itemsize() * offset;
  void* source = other.array_desc_->data_ptr;
  other.array_desc_->data_ptr = nullptr;
  desc.data_ptr = static_cast<void*>(static_cast<char*>(source) + byte_offset);
}
// Take over `other`'s buffer with exactly its layout and no offset.
// `other` is intentionally forwarded by copy, not std::move. The other
// arguments read through `other`, and function-argument evaluation order is
// unspecified, so moving it could leave them reading from an emptied array.
// array_desc_ is a shared_ptr, so the copy presumably refers to the same
// descriptor, and the buffer is still moved out of it.
void array::move_shared_buffer(array other) {
  const Strides& layout = other.strides();
  move_shared_buffer(other, layout, other.flags(), other.data_size());
}
// Destroy this array handle. Sibling outputs (see make_arrays) hold copies
// of each other in their descriptors' `siblings` vectors, which forms a
// shared_ptr reference cycle. When the only references left to this array
// and its siblings come from that cycle, this destructor breaks it so the
// descriptors can be released.
array::~array() {
// Nothing to do for handles whose descriptor was already nulled out during
// sibling cleanup (see the loop below and ~ArrayDesc).
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
// then there are no more external references
// (This descriptor: this handle plus one copy in each of the n siblings'
// lists. Each sibling: one copy in each of its n peers' lists, ours
// included. `s` below is a reference, so it adds no count.)
do_detach &= (array_desc_.use_count() == (n + 1));
for (auto& s : siblings()) {
do_detach &= (s.array_desc_.use_count() == n);
if (!do_detach) {
break;
}
}
if (do_detach) {
// Drop every cross-reference between siblings. Each sibling descriptor
// is still owned by this array's own `siblings` vector at this point,
// so none of them is freed inside this loop.
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
}
}
}
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = size;
size *= shape[i];
}
for (const auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
// Descriptor with no primitive and no inputs. It starts in the `available`
// state; init() fills in strides and size.
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
    : shape(std::move(shape)),
      dtype(dtype),
      status(Status::available) {
  init();
}
// Graph-node descriptor: produced by `primitive` from `inputs` and starting
// in the `unscheduled` state. init() fills in strides and size, and
// propagates the tracer flag from the inputs.
array::ArrayDesc::ArrayDesc(
    Shape shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    std::vector<array> inputs)
    : shape(std::move(shape)),
      dtype(dtype),
      status(Status::unscheduled),
      primitive(std::move(primitive)),
      inputs(std::move(inputs)) {
  init();
}
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destroy their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
if (inputs.empty()) {
return;
}
// Descriptors whose last owners are gone; they are released one at a time
// by the loop at the bottom instead of recursively.
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
// Clear `ad.inputs` and move into `for_deletion` every input descriptor
// (or sibling of one) that nothing outside this teardown still references.
auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
// Each input and all of its siblings, keyed by id() so an array reached
// through several paths is considered once.
std::unordered_map<std::uintptr_t, array> input_map;
for (array& a : ad.inputs) {
if (a.array_desc_) {
input_map.insert({a.id(), a});
for (auto& s : a.siblings()) {
input_map.insert({s.id(), s});
}
}
}
// Drop ad's own references; the copies in input_map keep the arrays alive
// while deletability is decided.
ad.inputs.clear();
for (auto& [_, a] : input_map) {
// Remaining owners allowed: the map's copy (1) plus one copy in each
// sibling's `siblings` vector. Anything more is an outside reference.
bool is_deletable =
(a.array_desc_.use_count() <= a.siblings().size() + 1);
// An array with siblings is deletable only if all of its siblings
// are deletable
for (auto& s : a.siblings()) {
if (!is_deletable) {
break;
}
// A sibling is referenced by its peers' sibling lists, plus one more
// if input_map also holds a copy of it.
int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &=
s.array_desc_.use_count() <= a.siblings().size() + is_input;
}
if (is_deletable) {
// Move ownership out of the map entry, so destroying input_map at the
// end of this lambda does not destroy the descriptor recursively here.
for_deletion.push_back(std::move(a.array_desc_));
}
}
};
append_deletable_inputs(*this);
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
}
}
// Iterator over the leading axis of `arr`, positioned at `idx`.
// Throws std::invalid_argument for 0-d arrays, which have no leading axis.
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
    : arr(arr), idx(idx) {
  const bool is_scalar = arr.ndim() == 0;
  if (is_scalar) {
    throw std::invalid_argument("Cannot iterate over 0-d array.");
  }
}
// Return the idx-th sub-array along axis 0 with that axis removed. The code
// slices [idx, idx + 1) on the leading axis, keeps every other axis whole,
// then reshapes to drop the length-1 leading dimension.
// (Removes a stray ';' that followed this function body.)
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
  auto start = Shape(arr.ndim(), 0);
  auto end = arr.shape();
  start[0] = idx;
  end[0] = idx + 1;
  auto out_shape = arr.shape();
  out_shape.erase(out_shape.begin());
  return reshape(slice(arr, start, end), out_shape);
}
} // namespace mlx::core