minor fixes (#1194)

* minor fixes

* fix build errors
This commit is contained in:
Fangjun Kuang 2024-06-13 13:06:49 +08:00 committed by GitHub
parent 934683088e
commit f20e97b092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 239 additions and 238 deletions

View File

@ -17,4 +17,4 @@ jobs:
pip install pre-commit black isort clang-format pip install pre-commit black isort clang-format
- name: Run lint - name: Run lint
run: | run: |
pre-commit run --all-files pre-commit run --all-files

View File

@ -206,7 +206,7 @@ void array::ArrayDesc::init() {
strides[i] = size; strides[i] = size;
size *= shape[i]; size *= shape[i];
} }
for (auto& in : inputs) { for (const auto& in : inputs) {
is_tracer |= in.is_tracer(); is_tracer |= in.is_tracer();
} }
} }
@ -231,7 +231,7 @@ array::ArrayDesc::ArrayDesc(
array::ArrayDesc::~ArrayDesc() { array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays // When an array description is destroyed it will delete a bunch of arrays
// that may also destory their corresponding descriptions and so on and so // that may also destroy their corresponding descriptions and so on and so
// forth. // forth.
// //
// This calls recursively the destructor and can result in stack overflow, we // This calls recursively the destructor and can result in stack overflow, we

View File

@ -73,32 +73,32 @@ class array {
this->array_desc_ = other.array_desc_; this->array_desc_ = other.array_desc_;
} }
return *this; return *this;
}; }
/** The size of the array's datatype in bytes. */ /** The size of the array's datatype in bytes. */
size_t itemsize() const { size_t itemsize() const {
return size_of(dtype()); return size_of(dtype());
}; }
/** The number of elements in the array. */ /** The number of elements in the array. */
size_t size() const { size_t size() const {
return array_desc_->size; return array_desc_->size;
}; }
/** The number of bytes in the array. */ /** The number of bytes in the array. */
size_t nbytes() const { size_t nbytes() const {
return size() * itemsize(); return size() * itemsize();
}; }
/** The number of dimensions of the array. */ /** The number of dimensions of the array. */
size_t ndim() const { size_t ndim() const {
return array_desc_->shape.size(); return array_desc_->shape.size();
}; }
/** The shape of the array as a vector of integers. */ /** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const { const std::vector<int>& shape() const {
return array_desc_->shape; return array_desc_->shape;
}; }
/** /**
* Get the size of the corresponding dimension. * Get the size of the corresponding dimension.
@ -107,12 +107,12 @@ class array {
* bounds checking. */ * bounds checking. */
int shape(int dim) const { int shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim); return shape().at(dim < 0 ? dim + ndim() : dim);
}; }
/** The strides of the array. */ /** The strides of the array. */
const std::vector<size_t>& strides() const { const std::vector<size_t>& strides() const {
return array_desc_->strides; return array_desc_->strides;
}; }
/** /**
* Get the stride of the corresponding dimension. * Get the stride of the corresponding dimension.
@ -121,12 +121,12 @@ class array {
* bounds checking. */ * bounds checking. */
size_t strides(int dim) const { size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim); return strides().at(dim < 0 ? dim + ndim() : dim);
}; }
/** Get the arrays data type. */ /** Get the arrays data type. */
Dtype dtype() const { Dtype dtype() const {
return array_desc_->dtype; return array_desc_->dtype;
}; }
/** Evaluate the array. */ /** Evaluate the array. */
void eval(); void eval();
@ -160,10 +160,10 @@ class array {
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) { friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
return a.arr.id() == b.arr.id() && a.idx == b.idx; return a.arr.id() == b.arr.id() && a.idx == b.idx;
}; }
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) { friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
return !(a == b); return !(a == b);
}; }
private: private:
const array& arr; const array& arr;
@ -209,7 +209,7 @@ class array {
allocator::Buffer buffer; allocator::Buffer buffer;
deleter_t d; deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free) Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {}; : buffer(buffer), d(d) {}
// Not copyable // Not copyable
Data(const Data& d) = delete; Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete; Data& operator=(const Data& d) = delete;
@ -230,22 +230,22 @@ class array {
/** The array's primitive. */ /** The array's primitive. */
Primitive& primitive() const { Primitive& primitive() const {
return *(array_desc_->primitive); return *(array_desc_->primitive);
}; }
/** A shared pointer to the array's primitive. */ /** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const { std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive; return array_desc_->primitive;
}; }
/** Check if the array has an attached primitive or is a leaf node. */ /** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const { bool has_primitive() const {
return array_desc_->primitive != nullptr; return array_desc_->primitive != nullptr;
}; }
/** The array's inputs. */ /** The array's inputs. */
const std::vector<array>& inputs() const { const std::vector<array>& inputs() const {
return array_desc_->inputs; return array_desc_->inputs;
}; }
std::vector<array>& inputs() { std::vector<array>& inputs() {
return array_desc_->inputs; return array_desc_->inputs;
@ -259,12 +259,12 @@ class array {
/** The array's siblings. */ /** The array's siblings. */
const std::vector<array>& siblings() const { const std::vector<array>& siblings() const {
return array_desc_->siblings; return array_desc_->siblings;
}; }
/** The array's siblings. */ /** The array's siblings. */
std::vector<array>& siblings() { std::vector<array>& siblings() {
return array_desc_->siblings; return array_desc_->siblings;
}; }
void set_siblings(std::vector<array> siblings, uint16_t position) { void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings); array_desc_->siblings = std::move(siblings);
@ -281,7 +281,7 @@ class array {
outputs.push_back(*this); outputs.push_back(*this);
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end()); outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
return outputs; return outputs;
}; }
/** Detach the array from the graph. */ /** Detach the array from the graph. */
void detach(); void detach();
@ -289,19 +289,19 @@ class array {
/** Get the Flags bit-field. */ /** Get the Flags bit-field. */
const Flags& flags() const { const Flags& flags() const {
return array_desc_->flags; return array_desc_->flags;
}; }
/** The size (in elements) of the underlying buffer the array points to. */ /** The size (in elements) of the underlying buffer the array points to. */
size_t data_size() const { size_t data_size() const {
return array_desc_->data_size; return array_desc_->data_size;
}; }
allocator::Buffer& buffer() { allocator::Buffer& buffer() {
return array_desc_->data->buffer; return array_desc_->data->buffer;
}; }
const allocator::Buffer& buffer() const { const allocator::Buffer& buffer() const {
return array_desc_->data->buffer; return array_desc_->data->buffer;
}; }
// Return a copy of the shared pointer // Return a copy of the shared pointer
// to the array::Data struct // to the array::Data struct
@ -312,19 +312,20 @@ class array {
template <typename T> template <typename T>
T* data() { T* data() {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
}; }
template <typename T> template <typename T>
const T* data() const { const T* data() const {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
}; }
enum Status { unscheduled, scheduled, available }; enum Status { unscheduled, scheduled, available };
bool is_available() const { bool is_available() const {
return status() == Status::available; return status() == Status::available;
} }
const Status status() const {
Status status() const {
return array_desc_->status; return array_desc_->status;
} }

View File

@ -123,7 +123,7 @@ struct AccelerateSimdOps {
VT max(VT a, VT b) { VT max(VT a, VT b) {
return simd_max(a, b); return simd_max(a, b);
}; }
VT exp(VT x) { VT exp(VT x) {
return simd_fast_exp(x); return simd_fast_exp(x);
@ -170,7 +170,7 @@ struct NeonFp16SimdOps {
VT max(VT a, VT b) { VT max(VT a, VT b) {
return vmaxq_f16(a, b); return vmaxq_f16(a, b);
}; }
VT exp(VT x) { VT exp(VT x) {
return neon_fast_exp(x); return neon_fast_exp(x);

View File

@ -108,105 +108,105 @@ struct Abs {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::abs(x); return std::abs(x);
}; }
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; }
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; }
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; }
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; }
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; }
}; };
struct ArcCos { struct ArcCos {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::acos(x); return std::acos(x);
}; }
}; };
struct ArcCosh { struct ArcCosh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::acosh(x); return std::acosh(x);
}; }
}; };
struct ArcSin { struct ArcSin {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::asin(x); return std::asin(x);
}; }
}; };
struct ArcSinh { struct ArcSinh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::asinh(x); return std::asinh(x);
}; }
}; };
struct ArcTan { struct ArcTan {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::atan(x); return std::atan(x);
}; }
}; };
struct ArcTan2 { struct ArcTan2 {
template <typename T> template <typename T>
T operator()(T y, T x) { T operator()(T y, T x) {
return std::atan2(y, x); return std::atan2(y, x);
}; }
}; };
struct ArcTanh { struct ArcTanh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::atanh(x); return std::atanh(x);
}; }
}; };
struct Ceil { struct Ceil {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::ceil(x); return std::ceil(x);
}; }
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; }
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; }
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; }
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; }
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; }
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; }
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; }
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; }
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; }
}; };
struct Conjugate { struct Conjugate {
@ -219,35 +219,35 @@ struct Cos {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::cos(x); return std::cos(x);
}; }
}; };
struct Cosh { struct Cosh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::cosh(x); return std::cosh(x);
}; }
}; };
struct Erf { struct Erf {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x))); return static_cast<T>(fast_erf(static_cast<float>(x)));
}; }
}; };
struct ErfInv { struct ErfInv {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x))); return static_cast<T>(fast_erfinv(static_cast<float>(x)));
}; }
}; };
struct Exp { struct Exp {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return fast_exp(x); return fast_exp(x);
}; }
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return std::exp(x); return std::exp(x);
@ -258,83 +258,83 @@ struct Expm1 {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return expm1(x); return expm1(x);
}; }
}; };
struct Floor { struct Floor {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::floor(x); return std::floor(x);
}; }
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; }
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; }
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; }
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; }
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; }
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; }
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; }
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; }
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; }
}; };
struct Log { struct Log {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::log(x); return std::log(x);
}; }
}; };
struct Log2 { struct Log2 {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::log2(x); return std::log2(x);
}; }
}; };
struct Log10 { struct Log10 {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::log10(x); return std::log10(x);
}; }
}; };
struct Log1p { struct Log1p {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return log1p(x); return log1p(x);
}; }
}; };
struct LogicalNot { struct LogicalNot {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return !x; return !x;
}; }
}; };
struct Negative { struct Negative {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return -x; return -x;
}; }
}; };
struct Round { struct Round {
@ -379,49 +379,49 @@ struct Sin {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::sin(x); return std::sin(x);
}; }
}; };
struct Sinh { struct Sinh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::sinh(x); return std::sinh(x);
}; }
}; };
struct Square { struct Square {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return x * x; return x * x;
}; }
}; };
struct Sqrt { struct Sqrt {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::sqrt(x); return std::sqrt(x);
}; }
}; };
struct Rsqrt { struct Rsqrt {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x); return static_cast<decltype(x)>(1.0) / std::sqrt(x);
}; }
}; };
struct Tan { struct Tan {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::tan(x); return std::tan(x);
}; }
}; };
struct Tanh { struct Tanh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::tanh(x); return std::tanh(x);
}; }
}; };
struct Add { struct Add {
@ -554,7 +554,7 @@ struct LogAddExp {
? maxval ? maxval
: static_cast<decltype(x)>( : static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval))); maxval + std::log1p(fast_exp(minval - maxval)));
}; }
}; };
struct Multiply { struct Multiply {
@ -602,14 +602,14 @@ struct LogicalAnd {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x && y; return x && y;
}; }
}; };
struct LogicalOr { struct LogicalOr {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x || y; return x || y;
}; }
}; };
struct Select { struct Select {
@ -623,35 +623,35 @@ struct BitwiseAnd {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x & y; return x & y;
}; }
}; };
struct BitwiseOr { struct BitwiseOr {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x | y; return x | y;
}; }
}; };
struct BitwiseXor { struct BitwiseXor {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x ^ y; return x ^ y;
}; }
}; };
struct LeftShift { struct LeftShift {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x << y; return x << y;
}; }
}; };
struct RightShift { struct RightShift {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x >> y; return x >> y;
}; }
}; };
} // namespace mlx::core::detail } // namespace mlx::core::detail

View File

@ -23,7 +23,7 @@ template <typename U = bool>
struct And { struct And {
bool simd_reduce(bool val) { bool simd_reduce(bool val) {
return simd_all(val); return simd_all(val);
}; }
static constexpr constant bool init = true; static constexpr constant bool init = true;
@ -61,7 +61,7 @@ template <typename U = bool>
struct Or { struct Or {
bool simd_reduce(bool val) { bool simd_reduce(bool val) {
return simd_any(val); return simd_any(val);
}; }
static constexpr constant bool init = false; static constexpr constant bool init = false;
@ -100,7 +100,7 @@ struct Sum {
template <typename T> template <typename T>
T simd_reduce(T val) { T simd_reduce(T val) {
return simd_sum(val); return simd_sum(val);
}; }
static constexpr constant U init = U(0); static constexpr constant U init = U(0);
@ -120,7 +120,7 @@ struct Prod {
template <typename T> template <typename T>
T simd_reduce(T val) { T simd_reduce(T val) {
return simd_product(val); return simd_product(val);
}; }
static constexpr constant U init = U(1); static constexpr constant U init = U(1);
@ -140,7 +140,7 @@ struct Min {
template <typename T> template <typename T>
T simd_reduce(T val) { T simd_reduce(T val) {
return simd_min(val); return simd_min(val);
}; }
static constexpr constant U init = Limits<U>::max; static constexpr constant U init = Limits<U>::max;
@ -160,7 +160,7 @@ struct Max {
template <typename T> template <typename T>
T simd_reduce(T val) { T simd_reduce(T val) {
return simd_max(val); return simd_max(val);
}; }
static constexpr constant U init = Limits<U>::min; static constexpr constant U init = Limits<U>::min;

View File

@ -181,7 +181,7 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
} }
// Remove the source from the map to avoid fusing with it again // Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents); parents_map.erase(src_parents);
}; }
// Helper that merges two arrays in the graph by setting the parents of the // Helper that merges two arrays in the graph by setting the parents of the
// source to point to the destination. The arrays are assumed to be coming from // source to point to the destination. The arrays are assumed to be coming from
@ -194,7 +194,7 @@ void merge(array& dst, array& src, ParentsMap& parents_map) {
for (int i = 0; i < sources.size(); ++i) { for (int i = 0; i < sources.size(); ++i) {
merge_one(dests[i], sources[i], parents_map); merge_one(dests[i], sources[i], parents_map);
} }
}; }
template <typename T, typename... U> template <typename T, typename... U>
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) { std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
@ -260,7 +260,7 @@ class CompilerCache {
// Otherwise append a new cache entry // Otherwise append a new cache entry
entries.push_back(CacheEntry{}); entries.push_back(CacheEntry{});
return entries.back(); return entries.back();
}; }
void erase(std::uintptr_t fun_id) { void erase(std::uintptr_t fun_id) {
cache_.erase(fun_id); cache_.erase(fun_id);

View File

@ -13,7 +13,7 @@ struct Device {
static constexpr DeviceType cpu = DeviceType::cpu; static constexpr DeviceType cpu = DeviceType::cpu;
static constexpr DeviceType gpu = DeviceType::gpu; static constexpr DeviceType gpu = DeviceType::gpu;
Device(DeviceType type, int index = 0) : type(type), index(index) {}; Device(DeviceType type, int index = 0) : type(type), index(index) {}
DeviceType type; DeviceType type;
int index; int index;

View File

@ -51,10 +51,10 @@ struct Dtype {
Val val; Val val;
const uint8_t size; const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}; constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size) {}
constexpr operator Val() const { constexpr operator Val() const {
return val; return val;
}; }
}; };
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};

View File

@ -10,46 +10,46 @@ namespace mlx::core {
class Event { class Event {
public: public:
Event() {}; Event() = default;
Event(const Stream& steam); Event(const Stream& steam);
// Wait for the event to be signaled at its curent value // Wait for the event to be signaled at its current value
void wait(); void wait();
// Signal the event at its current value // Signal the event at its current value
void signal(); void signal();
// Check if the event is valid // Check if the event is valid
bool valid() { bool valid() const {
return event_ != nullptr; return event_ != nullptr;
}; }
uint64_t value() { uint64_t value() const {
return value_; return value_;
}; }
void set_value(uint64_t v) { void set_value(uint64_t v) {
value_ = v; value_ = v;
}; }
const Stream& stream() { const Stream& stream() const {
if (!valid()) { if (!valid()) {
throw std::runtime_error( throw std::runtime_error(
"[Event::stream] Cannot access stream on invalid event."); "[Event::stream] Cannot access stream on invalid event.");
} }
return stream_; return stream_;
}; }
const std::shared_ptr<void>& raw_event() { const std::shared_ptr<void>& raw_event() const {
return event_; return event_;
}; }
private: private:
// Default constructed stream should never be used // Default constructed stream should never be used
// since the event is not yet valid // since the event is not yet valid
Stream stream_{0, Device::cpu}; Stream stream_{0, Device::cpu};
std::shared_ptr<void> event_{nullptr}; std::shared_ptr<void> event_;
uint64_t value_{0}; uint64_t value_{0};
}; };

View File

@ -12,7 +12,7 @@ class Custom : public Primitive {
explicit Custom( explicit Custom(
Stream stream, Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback) std::function<std::vector<array>(std::vector<array>)> fallback)
: Primitive(stream), fallback_(fallback) {}; : Primitive(stream), fallback_(fallback) {}
virtual std::pair<std::vector<array>, std::vector<int>> vmap( virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
@ -39,12 +39,12 @@ class RMSNorm : public Custom {
Stream stream, Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback, std::function<std::vector<array>(std::vector<array>)> fallback,
float eps) float eps)
: Custom(stream, fallback), eps_(eps) {}; : Custom(stream, fallback), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
throw std::runtime_error("NYI"); throw std::runtime_error("NYI");
}; }
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -68,12 +68,12 @@ class RMSNormVJP : public Custom {
Stream stream, Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback, std::function<std::vector<array>(std::vector<array>)> fallback,
float eps) float eps)
: Custom(stream, fallback), eps_(eps) {}; : Custom(stream, fallback), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
throw std::runtime_error("NYI"); throw std::runtime_error("NYI");
}; }
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -91,12 +91,12 @@ class LayerNorm : public Custom {
Stream stream, Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback, std::function<std::vector<array>(std::vector<array>)> fallback,
float eps) float eps)
: Custom(stream, fallback), eps_(eps) {}; : Custom(stream, fallback), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
throw std::runtime_error("NYI"); throw std::runtime_error("NYI");
}; }
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -120,12 +120,12 @@ class LayerNormVJP : public Custom {
Stream stream, Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback, std::function<std::vector<array>(std::vector<array>)> fallback,
float eps) float eps)
: Custom(stream, fallback), eps_(eps) {}; : Custom(stream, fallback), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
throw std::runtime_error("NYI"); throw std::runtime_error("NYI");
}; }
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -154,12 +154,12 @@ class RoPE : public Custom {
base_(base), base_(base),
scale_(scale), scale_(scale),
offset_(offset), offset_(offset),
forward_(forward) {}; forward_(forward) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
throw std::runtime_error("NYI"); throw std::runtime_error("NYI");
}; }
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -189,17 +189,17 @@ class ScaledDotProductAttention : public Custom {
std::function<std::vector<array>(std::vector<array>)> fallback, std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale, const float scale,
const bool needs_mask) const bool needs_mask)
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}; : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
throw std::runtime_error("NYI"); throw std::runtime_error("NYI");
}; }
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override { override {
eval_gpu(inputs, outputs[0]); eval_gpu(inputs, outputs[0]);
}; }
void eval_gpu(const std::vector<array>& inputs, array& out); void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;

View File

@ -116,7 +116,7 @@ std::vector<array> Primitive::jvp(
print(msg); print(msg);
msg << "."; msg << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
}; }
std::vector<array> Primitive::vjp( std::vector<array> Primitive::vjp(
const std::vector<array>&, const std::vector<array>&,
@ -128,7 +128,7 @@ std::vector<array> Primitive::vjp(
print(msg); print(msg);
msg << "."; msg << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
}; }
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap( std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<array>&, const std::vector<array>&,
@ -138,7 +138,7 @@ std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
print(msg); print(msg);
msg << "."; msg << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
}; }
std::vector<std::vector<int>> Primitive::output_shapes( std::vector<std::vector<int>> Primitive::output_shapes(
const std::vector<array>&) { const std::vector<array>&) {
@ -147,7 +147,7 @@ std::vector<std::vector<int>> Primitive::output_shapes(
this->print(msg); this->print(msg);
msg << " cannot infer output shapes."; msg << " cannot infer output shapes.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
}; }
std::vector<array> Abs::vjp( std::vector<array> Abs::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
@ -3430,7 +3430,7 @@ std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
return {{stop_gradient(inputs[0], stream())}, axes}; return {{stop_gradient(inputs[0], stream())}, axes};
}; }
std::vector<array> Subtract::vjp( std::vector<array> Subtract::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,

View File

@ -40,7 +40,7 @@
std::vector<std::vector<int>> output_shapes( \ std::vector<std::vector<int>> output_shapes( \
const std::vector<array>& inputs) override { \ const std::vector<array>& inputs) override { \
return {inputs[0].shape()}; \ return {inputs[0].shape()}; \
}; }
namespace mlx::core { namespace mlx::core {
@ -154,7 +154,7 @@ class UnaryPrimitive : public Primitive {
class Abs : public UnaryPrimitive { class Abs : public UnaryPrimitive {
public: public:
explicit Abs(Stream stream) : UnaryPrimitive(stream) {}; explicit Abs(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -171,7 +171,7 @@ class Abs : public UnaryPrimitive {
class Add : public UnaryPrimitive { class Add : public UnaryPrimitive {
public: public:
explicit Add(Stream stream) : UnaryPrimitive(stream) {}; explicit Add(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -189,7 +189,7 @@ class Add : public UnaryPrimitive {
class AddMM : public UnaryPrimitive { class AddMM : public UnaryPrimitive {
public: public:
explicit AddMM(Stream stream, float alpha, float beta) explicit AddMM(Stream stream, float alpha, float beta)
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}; : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -213,7 +213,7 @@ class AddMM : public UnaryPrimitive {
class Arange : public UnaryPrimitive { class Arange : public UnaryPrimitive {
public: public:
explicit Arange(Stream stream, double start, double stop, double step) explicit Arange(Stream stream, double start, double stop, double step)
: UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}; : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -231,7 +231,7 @@ class Arange : public UnaryPrimitive {
class ArcCos : public UnaryPrimitive { class ArcCos : public UnaryPrimitive {
public: public:
explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {}; explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -248,7 +248,7 @@ class ArcCos : public UnaryPrimitive {
class ArcCosh : public UnaryPrimitive { class ArcCosh : public UnaryPrimitive {
public: public:
explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {}; explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -265,7 +265,7 @@ class ArcCosh : public UnaryPrimitive {
class ArcSin : public UnaryPrimitive { class ArcSin : public UnaryPrimitive {
public: public:
explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {}; explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -282,7 +282,7 @@ class ArcSin : public UnaryPrimitive {
class ArcSinh : public UnaryPrimitive { class ArcSinh : public UnaryPrimitive {
public: public:
explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {}; explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -299,7 +299,7 @@ class ArcSinh : public UnaryPrimitive {
class ArcTan : public UnaryPrimitive { class ArcTan : public UnaryPrimitive {
public: public:
explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {}; explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -316,7 +316,7 @@ class ArcTan : public UnaryPrimitive {
class ArcTan2 : public UnaryPrimitive { class ArcTan2 : public UnaryPrimitive {
public: public:
explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {}; explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -333,7 +333,7 @@ class ArcTan2 : public UnaryPrimitive {
class ArcTanh : public UnaryPrimitive { class ArcTanh : public UnaryPrimitive {
public: public:
explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {}; explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -351,7 +351,7 @@ class ArcTanh : public UnaryPrimitive {
class ArgPartition : public UnaryPrimitive { class ArgPartition : public UnaryPrimitive {
public: public:
explicit ArgPartition(Stream stream, int kth, int axis) explicit ArgPartition(Stream stream, int kth, int axis)
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {}; : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -376,7 +376,7 @@ class ArgReduce : public UnaryPrimitive {
}; };
explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis) explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}; : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -397,7 +397,7 @@ class ArgReduce : public UnaryPrimitive {
class ArgSort : public UnaryPrimitive { class ArgSort : public UnaryPrimitive {
public: public:
explicit ArgSort(Stream stream, int axis) explicit ArgSort(Stream stream, int axis)
: UnaryPrimitive(stream), axis_(axis) {}; : UnaryPrimitive(stream), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -416,7 +416,7 @@ class ArgSort : public UnaryPrimitive {
class AsType : public UnaryPrimitive { class AsType : public UnaryPrimitive {
public: public:
explicit AsType(Stream stream, Dtype dtype) explicit AsType(Stream stream, Dtype dtype)
: UnaryPrimitive(stream), dtype_(dtype) {}; : UnaryPrimitive(stream), dtype_(dtype) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -443,7 +443,7 @@ class AsStrided : public UnaryPrimitive {
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
shape_(std::move(shape)), shape_(std::move(shape)),
strides_(std::move(strides)), strides_(std::move(strides)),
offset_(offset) {}; offset_(offset) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -465,7 +465,7 @@ class BitwiseBinary : public UnaryPrimitive {
enum Op { And, Or, Xor, LeftShift, RightShift }; enum Op { And, Or, Xor, LeftShift, RightShift };
explicit BitwiseBinary(Stream stream, Op op) explicit BitwiseBinary(Stream stream, Op op)
: UnaryPrimitive(stream), op_(op) {}; : UnaryPrimitive(stream), op_(op) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -482,7 +482,7 @@ class BitwiseBinary : public UnaryPrimitive {
class BlockMaskedMM : public UnaryPrimitive { class BlockMaskedMM : public UnaryPrimitive {
public: public:
explicit BlockMaskedMM(Stream stream, int block_size) explicit BlockMaskedMM(Stream stream, int block_size)
: UnaryPrimitive(stream), block_size_(block_size) {}; : UnaryPrimitive(stream), block_size_(block_size) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -504,7 +504,7 @@ class BlockMaskedMM : public UnaryPrimitive {
class GatherMM : public UnaryPrimitive { class GatherMM : public UnaryPrimitive {
public: public:
explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}; explicit GatherMM(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -525,7 +525,7 @@ class GatherMM : public UnaryPrimitive {
class Broadcast : public UnaryPrimitive { class Broadcast : public UnaryPrimitive {
public: public:
explicit Broadcast(Stream stream, const std::vector<int>& shape) explicit Broadcast(Stream stream, const std::vector<int>& shape)
: UnaryPrimitive(stream), shape_(shape) {}; : UnaryPrimitive(stream), shape_(shape) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -543,7 +543,7 @@ class Broadcast : public UnaryPrimitive {
class Ceil : public UnaryPrimitive { class Ceil : public UnaryPrimitive {
public: public:
explicit Ceil(Stream stream) : UnaryPrimitive(stream) {}; explicit Ceil(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -604,7 +604,7 @@ class Compiled : public Primitive {
class Concatenate : public UnaryPrimitive { class Concatenate : public UnaryPrimitive {
public: public:
explicit Concatenate(Stream stream, int axis) explicit Concatenate(Stream stream, int axis)
: UnaryPrimitive(stream), axis_(axis) {}; : UnaryPrimitive(stream), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -622,7 +622,7 @@ class Concatenate : public UnaryPrimitive {
class Conjugate : public UnaryPrimitive { class Conjugate : public UnaryPrimitive {
public: public:
explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}; explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -652,7 +652,7 @@ class Convolution : public UnaryPrimitive {
kernel_dilation_(kernel_dilation), kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation), input_dilation_(input_dilation),
groups_(groups), groups_(groups),
flip_(flip) {}; flip_(flip) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -679,7 +679,7 @@ class Convolution : public UnaryPrimitive {
class Copy : public UnaryPrimitive { class Copy : public UnaryPrimitive {
public: public:
explicit Copy(Stream stream) : UnaryPrimitive(stream) {}; explicit Copy(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -696,7 +696,7 @@ class Copy : public UnaryPrimitive {
class Cos : public UnaryPrimitive { class Cos : public UnaryPrimitive {
public: public:
explicit Cos(Stream stream) : UnaryPrimitive(stream) {}; explicit Cos(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -713,7 +713,7 @@ class Cos : public UnaryPrimitive {
class Cosh : public UnaryPrimitive { class Cosh : public UnaryPrimitive {
public: public:
explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}; explicit Cosh(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -784,7 +784,7 @@ class Depends : public Primitive {
class Divide : public UnaryPrimitive { class Divide : public UnaryPrimitive {
public: public:
explicit Divide(Stream stream) : UnaryPrimitive(stream) {}; explicit Divide(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -801,7 +801,7 @@ class Divide : public UnaryPrimitive {
class DivMod : public Primitive { class DivMod : public Primitive {
public: public:
explicit DivMod(Stream stream) : Primitive(stream) {}; explicit DivMod(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -815,7 +815,7 @@ class DivMod : public Primitive {
std::vector<std::vector<int>> output_shapes( std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override { const std::vector<array>& inputs) override {
return std::vector{inputs[0].shape(), inputs[0].shape()}; return std::vector{inputs[0].shape(), inputs[0].shape()};
}; }
private: private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs); void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
@ -823,7 +823,7 @@ class DivMod : public Primitive {
class Select : public UnaryPrimitive { class Select : public UnaryPrimitive {
public: public:
explicit Select(Stream stream) : UnaryPrimitive(stream) {}; explicit Select(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -840,7 +840,7 @@ class Select : public UnaryPrimitive {
class Remainder : public UnaryPrimitive { class Remainder : public UnaryPrimitive {
public: public:
explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}; explicit Remainder(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -858,7 +858,7 @@ class Remainder : public UnaryPrimitive {
class Equal : public UnaryPrimitive { class Equal : public UnaryPrimitive {
public: public:
explicit Equal(Stream stream, bool equal_nan = false) explicit Equal(Stream stream, bool equal_nan = false)
: UnaryPrimitive(stream), equal_nan_(equal_nan) {}; : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -883,7 +883,7 @@ class Equal : public UnaryPrimitive {
class Erf : public UnaryPrimitive { class Erf : public UnaryPrimitive {
public: public:
explicit Erf(Stream stream) : UnaryPrimitive(stream) {}; explicit Erf(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -900,7 +900,7 @@ class Erf : public UnaryPrimitive {
class ErfInv : public UnaryPrimitive { class ErfInv : public UnaryPrimitive {
public: public:
explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}; explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -917,7 +917,7 @@ class ErfInv : public UnaryPrimitive {
class Exp : public UnaryPrimitive { class Exp : public UnaryPrimitive {
public: public:
explicit Exp(Stream stream) : UnaryPrimitive(stream) {}; explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -934,7 +934,7 @@ class Exp : public UnaryPrimitive {
class Expm1 : public UnaryPrimitive { class Expm1 : public UnaryPrimitive {
public: public:
explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}; explicit Expm1(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -955,7 +955,7 @@ class FFT : public UnaryPrimitive {
const std::vector<size_t>& axes, const std::vector<size_t>& axes,
bool inverse, bool inverse,
bool real) bool real)
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}; : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -976,7 +976,7 @@ class FFT : public UnaryPrimitive {
class Floor : public UnaryPrimitive { class Floor : public UnaryPrimitive {
public: public:
explicit Floor(Stream stream) : UnaryPrimitive(stream) {}; explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -993,7 +993,7 @@ class Floor : public UnaryPrimitive {
class Full : public UnaryPrimitive { class Full : public UnaryPrimitive {
public: public:
explicit Full(Stream stream) : UnaryPrimitive(stream) {}; explicit Full(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1013,7 +1013,7 @@ class Gather : public UnaryPrimitive {
Stream stream, Stream stream,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& slice_sizes) const std::vector<int>& slice_sizes)
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}; : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1031,7 +1031,7 @@ class Gather : public UnaryPrimitive {
class Greater : public UnaryPrimitive { class Greater : public UnaryPrimitive {
public: public:
explicit Greater(Stream stream) : UnaryPrimitive(stream) {}; explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1048,7 +1048,7 @@ class Greater : public UnaryPrimitive {
class GreaterEqual : public UnaryPrimitive { class GreaterEqual : public UnaryPrimitive {
public: public:
explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}; explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1065,7 +1065,7 @@ class GreaterEqual : public UnaryPrimitive {
class Less : public UnaryPrimitive { class Less : public UnaryPrimitive {
public: public:
explicit Less(Stream stream) : UnaryPrimitive(stream) {}; explicit Less(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1082,7 +1082,7 @@ class Less : public UnaryPrimitive {
class LessEqual : public UnaryPrimitive { class LessEqual : public UnaryPrimitive {
public: public:
explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}; explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1107,7 +1107,7 @@ class Load : public UnaryPrimitive {
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
reader_(reader), reader_(reader),
offset_(offset), offset_(offset),
swap_endianness_(swap_endianness) {}; swap_endianness_(swap_endianness) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1126,7 +1126,7 @@ class Log : public UnaryPrimitive {
enum Base { two, ten, e }; enum Base { two, ten, e };
explicit Log(Stream stream, Base base) explicit Log(Stream stream, Base base)
: UnaryPrimitive(stream), base_(base) {}; : UnaryPrimitive(stream), base_(base) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1157,7 +1157,7 @@ class Log : public UnaryPrimitive {
class Log1p : public UnaryPrimitive { class Log1p : public UnaryPrimitive {
public: public:
explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}; explicit Log1p(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1173,7 +1173,7 @@ class Log1p : public UnaryPrimitive {
class LogicalNot : public UnaryPrimitive { class LogicalNot : public UnaryPrimitive {
public: public:
explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}; explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1190,7 +1190,7 @@ class LogicalNot : public UnaryPrimitive {
class LogicalAnd : public UnaryPrimitive { class LogicalAnd : public UnaryPrimitive {
public: public:
explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}; explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1207,7 +1207,7 @@ class LogicalAnd : public UnaryPrimitive {
class LogicalOr : public UnaryPrimitive { class LogicalOr : public UnaryPrimitive {
public: public:
explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}; explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1224,7 +1224,7 @@ class LogicalOr : public UnaryPrimitive {
class LogAddExp : public UnaryPrimitive { class LogAddExp : public UnaryPrimitive {
public: public:
explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}; explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1241,7 +1241,7 @@ class LogAddExp : public UnaryPrimitive {
class Matmul : public UnaryPrimitive { class Matmul : public UnaryPrimitive {
public: public:
explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}; explicit Matmul(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1259,7 +1259,7 @@ class Matmul : public UnaryPrimitive {
class Maximum : public UnaryPrimitive { class Maximum : public UnaryPrimitive {
public: public:
explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}; explicit Maximum(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1276,7 +1276,7 @@ class Maximum : public UnaryPrimitive {
class Minimum : public UnaryPrimitive { class Minimum : public UnaryPrimitive {
public: public:
explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}; explicit Minimum(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1293,7 +1293,7 @@ class Minimum : public UnaryPrimitive {
class Multiply : public UnaryPrimitive { class Multiply : public UnaryPrimitive {
public: public:
explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}; explicit Multiply(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1310,7 +1310,7 @@ class Multiply : public UnaryPrimitive {
class Negative : public UnaryPrimitive { class Negative : public UnaryPrimitive {
public: public:
explicit Negative(Stream stream) : UnaryPrimitive(stream) {}; explicit Negative(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1327,7 +1327,7 @@ class Negative : public UnaryPrimitive {
class NotEqual : public UnaryPrimitive { class NotEqual : public UnaryPrimitive {
public: public:
explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {}; explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1383,7 +1383,7 @@ class Pad : public UnaryPrimitive {
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
axes_(axes), axes_(axes),
low_pad_size_(low_pad_size), low_pad_size_(low_pad_size),
high_pad_size_(high_pad_size) {}; high_pad_size_(high_pad_size) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1404,7 +1404,7 @@ class Pad : public UnaryPrimitive {
class Partition : public UnaryPrimitive { class Partition : public UnaryPrimitive {
public: public:
explicit Partition(Stream stream, int kth, int axis) explicit Partition(Stream stream, int kth, int axis)
: UnaryPrimitive(stream), kth_(kth), axis_(axis) {}; : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1424,7 +1424,7 @@ class Partition : public UnaryPrimitive {
class Power : public UnaryPrimitive { class Power : public UnaryPrimitive {
public: public:
explicit Power(Stream stream) : UnaryPrimitive(stream) {}; explicit Power(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1449,7 +1449,7 @@ class QuantizedMatmul : public UnaryPrimitive {
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
transpose_(transpose) {}; transpose_(transpose) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1473,7 +1473,7 @@ class GatherQMM : public UnaryPrimitive {
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
transpose_(transpose) {}; transpose_(transpose) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1494,7 +1494,7 @@ class GatherQMM : public UnaryPrimitive {
class RandomBits : public UnaryPrimitive { class RandomBits : public UnaryPrimitive {
public: public:
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width) explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
: UnaryPrimitive(stream), shape_(shape), width_(width) {}; : UnaryPrimitive(stream), shape_(shape), width_(width) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1513,7 +1513,7 @@ class RandomBits : public UnaryPrimitive {
class Reshape : public UnaryPrimitive { class Reshape : public UnaryPrimitive {
public: public:
explicit Reshape(Stream stream, const std::vector<int>& shape) explicit Reshape(Stream stream, const std::vector<int>& shape)
: UnaryPrimitive(stream), shape_(shape) {}; : UnaryPrimitive(stream), shape_(shape) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1545,7 +1545,7 @@ class Reduce : public UnaryPrimitive {
Stream stream, Stream stream,
ReduceType reduce_type, ReduceType reduce_type,
const std::vector<int>& axes) const std::vector<int>& axes)
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}; : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1594,7 +1594,7 @@ class Reduce : public UnaryPrimitive {
class Round : public UnaryPrimitive { class Round : public UnaryPrimitive {
public: public:
explicit Round(Stream stream) : UnaryPrimitive(stream) {}; explicit Round(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1623,7 +1623,7 @@ class Scan : public UnaryPrimitive {
reduce_type_(reduce_type), reduce_type_(reduce_type),
axis_(axis), axis_(axis),
reverse_(reverse), reverse_(reverse),
inclusive_(inclusive) {}; inclusive_(inclusive) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1667,7 +1667,7 @@ class Scatter : public UnaryPrimitive {
Stream stream, Stream stream,
ReduceType reduce_type, ReduceType reduce_type,
const std::vector<int>& axes) const std::vector<int>& axes)
: UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}; : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1702,7 +1702,7 @@ class Scatter : public UnaryPrimitive {
class Sigmoid : public UnaryPrimitive { class Sigmoid : public UnaryPrimitive {
public: public:
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}; explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1719,7 +1719,7 @@ class Sigmoid : public UnaryPrimitive {
class Sign : public UnaryPrimitive { class Sign : public UnaryPrimitive {
public: public:
explicit Sign(Stream stream) : UnaryPrimitive(stream) {}; explicit Sign(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1736,7 +1736,7 @@ class Sign : public UnaryPrimitive {
class Sin : public UnaryPrimitive { class Sin : public UnaryPrimitive {
public: public:
explicit Sin(Stream stream) : UnaryPrimitive(stream) {}; explicit Sin(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1753,7 +1753,7 @@ class Sin : public UnaryPrimitive {
class Sinh : public UnaryPrimitive { class Sinh : public UnaryPrimitive {
public: public:
explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}; explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1778,7 +1778,7 @@ class Slice : public UnaryPrimitive {
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
start_indices_(start_indices), start_indices_(start_indices),
end_indices_(end_indices), end_indices_(end_indices),
strides_(strides) {}; strides_(strides) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1806,7 +1806,7 @@ class SliceUpdate : public UnaryPrimitive {
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
start_indices_(start_indices), start_indices_(start_indices),
end_indices_(end_indices), end_indices_(end_indices),
strides_(strides) {}; strides_(strides) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1829,7 +1829,7 @@ class SliceUpdate : public UnaryPrimitive {
class Softmax : public UnaryPrimitive { class Softmax : public UnaryPrimitive {
public: public:
explicit Softmax(Stream stream, bool precise) explicit Softmax(Stream stream, bool precise)
: UnaryPrimitive(stream), precise_(precise) {}; : UnaryPrimitive(stream), precise_(precise) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1849,7 +1849,7 @@ class Softmax : public UnaryPrimitive {
class Sort : public UnaryPrimitive { class Sort : public UnaryPrimitive {
public: public:
explicit Sort(Stream stream, int axis) explicit Sort(Stream stream, int axis)
: UnaryPrimitive(stream), axis_(axis) {}; : UnaryPrimitive(stream), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1869,7 +1869,7 @@ class Sort : public UnaryPrimitive {
class Split : public Primitive { class Split : public Primitive {
public: public:
explicit Split(Stream stream, const std::vector<int>& indices, int axis) explicit Split(Stream stream, const std::vector<int>& indices, int axis)
: Primitive(stream), indices_(indices), axis_(axis) {}; : Primitive(stream), indices_(indices), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -1890,7 +1890,7 @@ class Split : public Primitive {
class Square : public UnaryPrimitive { class Square : public UnaryPrimitive {
public: public:
explicit Square(Stream stream) : UnaryPrimitive(stream) {}; explicit Square(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1908,7 +1908,7 @@ class Square : public UnaryPrimitive {
class Sqrt : public UnaryPrimitive { class Sqrt : public UnaryPrimitive {
public: public:
explicit Sqrt(Stream stream, bool recip = false) explicit Sqrt(Stream stream, bool recip = false)
: UnaryPrimitive(stream), recip_(recip) {}; : UnaryPrimitive(stream), recip_(recip) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1933,7 +1933,7 @@ class Sqrt : public UnaryPrimitive {
class StopGradient : public UnaryPrimitive { class StopGradient : public UnaryPrimitive {
public: public:
explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}; explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1949,7 +1949,7 @@ class StopGradient : public UnaryPrimitive {
class Subtract : public UnaryPrimitive { class Subtract : public UnaryPrimitive {
public: public:
explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}; explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1966,7 +1966,7 @@ class Subtract : public UnaryPrimitive {
class Tan : public UnaryPrimitive { class Tan : public UnaryPrimitive {
public: public:
explicit Tan(Stream stream) : UnaryPrimitive(stream) {}; explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1983,7 +1983,7 @@ class Tan : public UnaryPrimitive {
class Tanh : public UnaryPrimitive { class Tanh : public UnaryPrimitive {
public: public:
explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}; explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -2000,7 +2000,7 @@ class Tanh : public UnaryPrimitive {
class Uniform : public UnaryPrimitive { class Uniform : public UnaryPrimitive {
public: public:
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}; explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -2016,7 +2016,7 @@ class Uniform : public UnaryPrimitive {
class View : public UnaryPrimitive { class View : public UnaryPrimitive {
public: public:
explicit View(Stream stream, Dtype dtype) explicit View(Stream stream, Dtype dtype)
: UnaryPrimitive(stream), dtype_(dtype) {}; : UnaryPrimitive(stream), dtype_(dtype) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -2032,7 +2032,7 @@ class View : public UnaryPrimitive {
class Transpose : public UnaryPrimitive { class Transpose : public UnaryPrimitive {
public: public:
explicit Transpose(Stream stream, const std::vector<int>& axes) explicit Transpose(Stream stream, const std::vector<int>& axes)
: UnaryPrimitive(stream), axes_(axes) {}; : UnaryPrimitive(stream), axes_(axes) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -2051,7 +2051,7 @@ class Transpose : public UnaryPrimitive {
/* QR Factorization primitive. */ /* QR Factorization primitive. */
class QRF : public Primitive { class QRF : public Primitive {
public: public:
explicit QRF(Stream stream) : Primitive(stream) {}; explicit QRF(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -2067,7 +2067,7 @@ class QRF : public Primitive {
/* SVD primitive. */ /* SVD primitive. */
class SVD : public Primitive { class SVD : public Primitive {
public: public:
explicit SVD(Stream stream) : Primitive(stream) {}; explicit SVD(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override; override;
@ -2084,7 +2084,7 @@ class SVD : public Primitive {
/* Matrix inversion primitive. */ /* Matrix inversion primitive. */
class Inverse : public UnaryPrimitive { class Inverse : public UnaryPrimitive {
public: public:
explicit Inverse(Stream stream) : UnaryPrimitive(stream) {}; explicit Inverse(Stream stream) : UnaryPrimitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, array& output) override; void eval_cpu(const std::vector<array>& inputs, array& output) override;
void eval_gpu(const std::vector<array>& inputs, array& output) override; void eval_gpu(const std::vector<array>& inputs, array& output) override;
@ -2099,7 +2099,7 @@ class Inverse : public UnaryPrimitive {
class Cholesky : public UnaryPrimitive { class Cholesky : public UnaryPrimitive {
public: public:
explicit Cholesky(Stream stream, bool upper) explicit Cholesky(Stream stream, bool upper)
: UnaryPrimitive(stream), upper_(upper) {}; : UnaryPrimitive(stream), upper_(upper) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;

View File

@ -148,7 +148,7 @@ array randint(
const std::optional<array>& key = std::nullopt, const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return randint(array(low), array(high), shape, dtype, key, to_stream(s)); return randint(array(low), array(high), shape, dtype, key, to_stream(s));
}; }
/** Generate binary variables with probability to be true equal to p */ /** Generate binary variables with probability to be true equal to p */
array bernoulli( array bernoulli(
@ -167,7 +167,7 @@ array bernoulli(
const std::optional<array>& key = std::nullopt, const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return bernoulli(array(p), key, s); return bernoulli(array(p), key, s);
}; }
template <typename T> template <typename T>
array bernoulli( array bernoulli(
@ -176,7 +176,7 @@ array bernoulli(
const std::optional<array>& key = std::nullopt, const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return bernoulli(array(p), shape, key, s); return bernoulli(array(p), shape, key, s);
}; }
array bernoulli( array bernoulli(
const std::optional<array>& key = std::nullopt, const std::optional<array>& key = std::nullopt,

View File

@ -31,7 +31,7 @@ struct StreamThread {
~StreamThread() { ~StreamThread() {
synchronize(stream); synchronize(stream);
{ {
std::unique_lock<std::mutex> lk(mtx); std::lock_guard<std::mutex> lk(mtx);
stop = true; stop = true;
} }
cond.notify_one(); cond.notify_one();
@ -58,7 +58,7 @@ struct StreamThread {
template <typename F> template <typename F>
void enqueue(F&& f) { void enqueue(F&& f) {
{ {
std::unique_lock<std::mutex> lk(mtx); std::lock_guard<std::mutex> lk(mtx);
if (stop) { if (stop) {
throw std::runtime_error( throw std::runtime_error(
"Cannot enqueue work after stream is stopped."); "Cannot enqueue work after stream is stopped.");
@ -93,7 +93,7 @@ class Scheduler {
template <typename F> template <typename F>
void enqueue(const Stream& stream, F&& f); void enqueue(const Stream& stream, F&& f);
Stream get_default_stream(const Device& d) { Stream get_default_stream(const Device& d) const {
return default_streams_.at(d.type); return default_streams_.at(d.type);
} }
@ -103,7 +103,7 @@ class Scheduler {
void notify_new_task(const Stream& stream) { void notify_new_task(const Stream& stream) {
{ {
std::unique_lock<std::mutex> lk(mtx); std::lock_guard<std::mutex> lk(mtx);
n_active_tasks_++; n_active_tasks_++;
} }
completion_cv.notify_all(); completion_cv.notify_all();
@ -111,7 +111,7 @@ class Scheduler {
void notify_task_completion(const Stream& stream) { void notify_task_completion(const Stream& stream) {
{ {
std::unique_lock<std::mutex> lk(mtx); std::lock_guard<std::mutex> lk(mtx);
n_active_tasks_--; n_active_tasks_--;
} }
completion_cv.notify_all(); completion_cv.notify_all();

View File

@ -22,10 +22,10 @@ namespace mlx::core {
* for synchronizing with the main thread. */ * for synchronizing with the main thread. */
class Synchronizer : public Primitive { class Synchronizer : public Primitive {
public: public:
explicit Synchronizer(Stream stream) : Primitive(stream) {}; explicit Synchronizer(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}; void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}; void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
DEFINE_PRINT(Synchronize); DEFINE_PRINT(Synchronize);
}; };