2024-01-31 08:04:45 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
#include <functional>
|
|
|
|
|
|
|
|
#include "mlx/array.h"
|
|
|
|
#include "mlx/ops.h"
|
|
|
|
#include "mlx/primitives.h"
|
|
|
|
#include "mlx/transforms.h"
|
2024-01-08 07:16:51 +08:00
|
|
|
#include "mlx/transforms_impl.h"
|
2023-11-30 02:42:59 +08:00
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
// Compute row-major strides for `shape`.
// Returns {total element count, strides}. Each stride is the product of all
// dimensions to its right; an empty shape yields {1, {}}.
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
  std::vector<size_t> strides(shape.size());
  size_t running = 1;
  // Walk innermost (last) to outermost (first) dimension in lockstep.
  auto stride_it = strides.rbegin();
  for (auto dim_it = shape.rbegin(); dim_it != shape.rend();
       ++dim_it, ++stride_it) {
    *stride_it = running;
    running *= *dim_it;
  }
  return {running, strides};
}
|
|
|
|
|
2024-01-08 07:16:51 +08:00
|
|
|
/** Return true if we are currently performing a function transformation in
|
|
|
|
* order to keep the graph when evaluating tracer arrays. */
|
|
|
|
bool in_tracing() {
|
|
|
|
return detail::InTracing::in_tracing();
|
|
|
|
}
|
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Scalar (0-d, empty shape) array holding one complex value. The
// std::complex<float> is converted to complex64_t and handed to init().
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
  complex64_t scalar(val);
  init(&scalar);
}
|
|
|
|
|
2024-01-27 05:45:30 +08:00
|
|
|
// Construct an array node from its output shape and dtype, the primitive
// associated with it, and its input arrays. All by-value arguments are moved
// straight into a new ArrayDesc.
array::array(
    std::vector<int> shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    std::vector<array> inputs)
    : array_desc_(std::make_shared<ArrayDesc>(
          std::move(shape),
          dtype,
          std::move(primitive),
          std::move(inputs))) {}
|
|
|
|
|
2024-01-09 08:39:08 +08:00
|
|
|
std::vector<array> array::make_arrays(
|
|
|
|
const std::vector<std::vector<int>>& shapes,
|
|
|
|
const std::vector<Dtype>& dtypes,
|
|
|
|
std::shared_ptr<Primitive> primitive,
|
|
|
|
const std::vector<array>& inputs) {
|
|
|
|
std::vector<array> outputs;
|
|
|
|
for (int i = 0; i < shapes.size(); ++i) {
|
|
|
|
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
|
|
|
|
}
|
|
|
|
for (int i = 0; i < outputs.size(); ++i) {
|
|
|
|
auto siblings = outputs;
|
|
|
|
siblings.erase(siblings.begin() + i);
|
|
|
|
outputs[i].set_siblings(std::move(siblings), i);
|
|
|
|
}
|
|
|
|
return outputs;
|
|
|
|
}
|
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
// 1-D float32 array with one element per entry of `data`, initialized from
// the list contents.
array::array(std::initializer_list<float> data)
    : array_desc_(std::make_shared<ArrayDesc>(
          std::vector<int>(1, static_cast<int>(data.size())),
          float32)) {
  init(data.begin());
}
|
|
|
|
|
2024-02-14 15:34:17 +08:00
|
|
|
// 1-D array of the requested `dtype` with one element per entry of `data`,
// initialized from the integer list contents.
array::array(std::initializer_list<int> data, Dtype dtype)
    : array_desc_(std::make_shared<ArrayDesc>(
          std::vector<int>(1, static_cast<int>(data.size())),
          dtype)) {
  init(data.begin());
}
|
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
/* Build an array from a shared buffer */
|
|
|
|
array::array(
|
|
|
|
allocator::Buffer data,
|
2024-03-20 22:54:30 +08:00
|
|
|
std::vector<int> shape,
|
2023-11-30 02:42:59 +08:00
|
|
|
Dtype dtype,
|
|
|
|
deleter_t deleter)
|
2024-03-20 22:54:30 +08:00
|
|
|
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
2023-11-30 02:42:59 +08:00
|
|
|
set_data(data, deleter);
|
|
|
|
}
|
|
|
|
|
|
|
|
void array::detach() {
|
2024-01-18 06:08:07 +08:00
|
|
|
for (auto& s : array_desc_->siblings) {
|
|
|
|
s.array_desc_->inputs.clear();
|
|
|
|
s.array_desc_->siblings.clear();
|
|
|
|
s.array_desc_->position = 0;
|
|
|
|
s.array_desc_->primitive = nullptr;
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
array_desc_->inputs.clear();
|
2024-01-09 08:39:08 +08:00
|
|
|
array_desc_->siblings.clear();
|
|
|
|
array_desc_->position = 0;
|
2023-11-30 02:42:59 +08:00
|
|
|
array_desc_->primitive = nullptr;
|
|
|
|
}
|
|
|
|
|
2024-01-08 07:16:51 +08:00
|
|
|
void array::eval() {
|
|
|
|
mlx::core::eval({*this});
|
|
|
|
}
|
|
|
|
|
|
|
|
bool array::is_tracer() const {
|
|
|
|
return array_desc_->is_tracer && in_tracing();
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Attach `buffer` as this array's dense storage. data_size is set to the full
// element count, and the array is marked contiguous and row-contiguous.
void array::set_data(allocator::Buffer buffer, deleter_t d) {
  auto& desc = *array_desc_;
  desc.data = std::make_shared<Data>(buffer, d);
  desc.data_ptr = buffer.raw_ptr();
  desc.data_size = size();

  const auto& dims = shape();
  // Column-contiguous when there is at most one element, or when the largest
  // dimension accounts for every element (all other dims are 1). The
  // size() <= 1 test short-circuits before the dereference, which matters
  // for 0-d arrays whose shape is empty.
  const bool col_contig =
      size() <= 1 || size() == *std::max_element(dims.begin(), dims.end());

  desc.flags.contiguous = true;
  desc.flags.row_contiguous = true;
  desc.flags.col_contiguous = col_contig;
}
|
|
|
|
|
|
|
|
// Attach `buffer` together with caller-supplied layout metadata: the number of
// stored elements, the strides and the contiguity flags.
void array::set_data(
    allocator::Buffer buffer,
    size_t data_size,
    std::vector<size_t> strides,
    Flags flags,
    deleter_t d) {
  auto& desc = *array_desc_;
  desc.data = std::make_shared<Data>(buffer, d);
  desc.data_ptr = buffer.raw_ptr();
  desc.data_size = data_size;
  desc.strides = std::move(strides);
  desc.flags = flags;
}
|
|
|
|
|
|
|
|
void array::copy_shared_buffer(
|
|
|
|
const array& other,
|
|
|
|
const std::vector<size_t>& strides,
|
|
|
|
Flags flags,
|
|
|
|
size_t data_size,
|
|
|
|
size_t offset /* = 0 */) {
|
|
|
|
array_desc_->data = other.array_desc_->data;
|
|
|
|
array_desc_->strides = strides;
|
|
|
|
array_desc_->flags = flags;
|
|
|
|
array_desc_->data_size = data_size;
|
|
|
|
auto char_offset = sizeof(char) * itemsize() * offset;
|
|
|
|
array_desc_->data_ptr = static_cast<void*>(
|
|
|
|
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Share `other`'s storage using exactly its own strides, flags and data_size,
// with no element offset.
void array::copy_shared_buffer(const array& other) {
  const auto& src_strides = other.strides();
  copy_shared_buffer(other, src_strides, other.flags(), other.data_size());
}
|
|
|
|
|
2024-03-11 21:31:31 +08:00
|
|
|
// Like copy_shared_buffer, but the Data handle is moved out of `other`'s
// descriptor instead of copied. data_ptr is read from `other` after the move;
// it is a separate field, so it is unaffected.
void array::move_shared_buffer(
    array other,
    const std::vector<size_t>& strides,
    Flags flags,
    size_t data_size,
    size_t offset /* = 0 */) {
  auto& desc = *array_desc_;
  desc.data = std::move(other.array_desc_->data);
  desc.strides = strides;
  desc.flags = flags;
  desc.data_size = data_size;
  auto* base = static_cast<char*>(other.array_desc_->data_ptr);
  desc.data_ptr = static_cast<void*>(base + itemsize() * offset);
}
|
|
|
|
|
|
|
|
// Take over `other`'s storage using exactly its own strides, flags and
// data_size, with no element offset.
void array::move_shared_buffer(array other) {
  const auto& src_strides = other.strides();
  move_shared_buffer(other, src_strides, other.flags(), other.data_size());
}
|
|
|
|
|
2024-03-20 22:54:30 +08:00
|
|
|
// Descriptor with only shape and dtype. The element count and row-major
// strides are derived from the shape.
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
    : shape(std::move(shape)), dtype(dtype) {
  // The `shape` parameter has been moved from; read the member instead.
  auto [total, row_major] = cum_prod(this->shape);
  size = total;
  strides = std::move(row_major);
}
|
|
|
|
|
|
|
|
// Descriptor for an array with a primitive and inputs. The element count and
// row-major strides are derived from the shape, and the tracer flag is set if
// any input is currently a tracer.
array::ArrayDesc::ArrayDesc(
    std::vector<int> shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    std::vector<array> inputs)
    : shape(std::move(shape)),
      dtype(dtype),
      primitive(std::move(primitive)),
      inputs(std::move(inputs)) {
  // Parameters were moved into members above, so read via this->.
  auto [total, row_major] = cum_prod(this->shape);
  size = total;
  strides = std::move(row_major);
  for (const auto& input : this->inputs) {
    if (input.is_tracer()) {
      is_tracer = true;
      break;
    }
  }
}
|
|
|
|
|
2024-01-17 05:33:55 +08:00
|
|
|
// Iterator over the leading axis of `arr`, positioned at index `idx`.
// A 0-d array has no leading axis, so it is rejected.
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
    : arr(arr), idx(idx) {
  const bool is_scalar = arr.ndim() == 0;
  if (is_scalar) {
    throw std::invalid_argument("Cannot iterate over 0-d array.");
  }
}
|
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
// Dereference: return arr[idx], i.e. the sub-array at position `idx` along the
// leading axis, with that axis removed. Relies on the constructor having
// rejected 0-d arrays, so the shape has at least one dimension.
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
  const auto& full_shape = arr.shape();
  // Slice [idx, idx + 1) on axis 0; every other axis is taken in full.
  std::vector<int> start(full_shape.size(), 0);
  std::vector<int> stop(full_shape.begin(), full_shape.end());
  start[0] = idx;
  stop[0] = idx + 1;
  // Drop the now length-1 leading axis from the result.
  std::vector<int> out_shape(full_shape.begin() + 1, full_shape.end());
  return reshape(
      slice(arr, std::move(start), std::move(stop)), std::move(out_shape));
}
|
|
|
|
|
|
|
|
} // namespace mlx::core
|