mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
55
mlx/array.h
55
mlx/array.h
@@ -73,32 +73,32 @@ class array {
|
||||
this->array_desc_ = other.array_desc_;
|
||||
}
|
||||
return *this;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size of the array's datatype in bytes. */
|
||||
size_t itemsize() const {
|
||||
return size_of(dtype());
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of elements in the array. */
|
||||
size_t size() const {
|
||||
return array_desc_->size;
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of bytes in the array. */
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of dimensions of the array. */
|
||||
size_t ndim() const {
|
||||
return array_desc_->shape.size();
|
||||
};
|
||||
}
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const std::vector<int>& shape() const {
|
||||
return array_desc_->shape;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the size of the corresponding dimension.
|
||||
@@ -107,12 +107,12 @@ class array {
|
||||
* bounds checking. */
|
||||
int shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** The strides of the array. */
|
||||
const std::vector<size_t>& strides() const {
|
||||
return array_desc_->strides;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the stride of the corresponding dimension.
|
||||
@@ -121,12 +121,12 @@ class array {
|
||||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** Get the arrays data type. */
|
||||
Dtype dtype() const {
|
||||
return array_desc_->dtype;
|
||||
};
|
||||
}
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval();
|
||||
@@ -160,10 +160,10 @@ class array {
|
||||
|
||||
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return a.arr.id() == b.arr.id() && a.idx == b.idx;
|
||||
};
|
||||
}
|
||||
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return !(a == b);
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
const array& arr;
|
||||
@@ -209,7 +209,7 @@ class array {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d) {};
|
||||
: buffer(buffer), d(d) {}
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
@@ -230,22 +230,22 @@ class array {
|
||||
/** The array's primitive. */
|
||||
Primitive& primitive() const {
|
||||
return *(array_desc_->primitive);
|
||||
};
|
||||
}
|
||||
|
||||
/** A shared pointer to the array's primitive. */
|
||||
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||
return array_desc_->primitive;
|
||||
};
|
||||
}
|
||||
|
||||
/** Check if the array has an attached primitive or is a leaf node. */
|
||||
bool has_primitive() const {
|
||||
return array_desc_->primitive != nullptr;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's inputs. */
|
||||
const std::vector<array>& inputs() const {
|
||||
return array_desc_->inputs;
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array>& inputs() {
|
||||
return array_desc_->inputs;
|
||||
@@ -259,12 +259,12 @@ class array {
|
||||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
std::vector<array>& siblings() {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
@@ -281,7 +281,7 @@ class array {
|
||||
outputs.push_back(*this);
|
||||
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||
return outputs;
|
||||
};
|
||||
}
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
@@ -289,19 +289,19 @@ class array {
|
||||
/** Get the Flags bit-field. */
|
||||
const Flags& flags() const {
|
||||
return array_desc_->flags;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size (in elements) of the underlying buffer the array points to. */
|
||||
size_t data_size() const {
|
||||
return array_desc_->data_size;
|
||||
};
|
||||
}
|
||||
|
||||
allocator::Buffer& buffer() {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
const allocator::Buffer& buffer() const {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
@@ -312,19 +312,20 @@ class array {
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
}
|
||||
|
||||
enum Status { unscheduled, scheduled, available };
|
||||
|
||||
bool is_available() const {
|
||||
return status() == Status::available;
|
||||
}
|
||||
const Status status() const {
|
||||
|
||||
Status status() const {
|
||||
return array_desc_->status;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user