mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
This commit is contained in:
committed by
GitHub
parent
5fd11c347d
commit
28eac18571
15
mlx/array.h
15
mlx/array.h
@@ -121,6 +121,9 @@ class array {
|
||||
template <typename T>
|
||||
T item();
|
||||
|
||||
template <typename T>
|
||||
T item() const;
|
||||
|
||||
struct ArrayIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
using difference_type = size_t;
|
||||
@@ -454,6 +457,18 @@ T array::item() {
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T array::item() const {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
if (!is_evaled()) {
|
||||
throw std::invalid_argument(
|
||||
"item() const can only be called on evaled arrays");
|
||||
}
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
template <typename It>
|
||||
void array::init(It src) {
|
||||
set_data(allocator::malloc(size() * size_of(dtype())));
|
||||
|
||||
Reference in New Issue
Block a user