// Copyright © 2023 Apple Inc. #pragma once #include "array.h" #include "device.h" #include "load.h" #include "stream.h" #define DEFINE_GRADS() \ array jvp( \ const std::vector& primals, \ const std::vector& tangents, \ const std::vector& argnums) override; \ \ std::vector vjp( \ const std::vector& primals, \ const array& cotan, \ const std::vector& argnums) override; #define DEFINE_PRINT(PRIMITIVE) \ void print(std::ostream& os) override { \ os << #PRIMITIVE; \ } #define DEFINE_DEFAULT_IS_EQUIVALENT() \ bool is_equivalent(const Primitive& other) const override { \ return true; \ } namespace mlx::core { // Abstract base class class Primitive { public: explicit Primitive(Stream stream) : stream_(stream) {} /** The device the primitive will run on. */ const Device& device() { return stream().device; } /** The stream the primitive will run on. */ const Stream& stream() { return stream_; } /** * A primitive must know how to evaluate itself on * the CPU/GPU for the given inputs and populate the output array. * * To avoid unecessary allocations, the evaluation function * is responsible for allocating space for the array. */ virtual void eval_cpu(const std::vector& inputs, array& out) = 0; virtual void eval_gpu(const std::vector& inputs, array& out) = 0; /** * The Jacobian-vector product. */ virtual array jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums); /** * The vector-Jacobian product. */ virtual std::vector vjp( const std::vector& primals, const array& cotan, const std::vector& argnums); /** * The primitive must know how to vectorize itself accross * the given axes. The output is a pair containing the array * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ virtual std::pair vmap( const std::vector& inputs, const std::vector& axes); /** Print the primitive. */ virtual void print(std::ostream& os) = 0; /** Equivalence check defaults to false unless overriden by the primitive */ virtual bool is_equivalent(const Primitive& other) const { return false; } virtual ~Primitive() = default; Primitive(const Primitive& other) = delete; Primitive(Primitive&& other) = delete; Primitive& operator=(const Primitive& other) = delete; Primitive& operator=(Primitive&& other) = delete; private: // Every primitive stores the stream it should run in Stream stream_; }; class Abs : public Primitive { public: explicit Abs(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Abs) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Add : public Primitive { public: explicit Add(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Add) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Arange : public Primitive { public: explicit Arange(Stream stream, double start, double stop, double step) : Primitive(stream), start_(start), stop_(stop), step_(step){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_PRINT(Arange) bool is_equivalent(const Primitive& other) const override; private: double start_; double stop_; double step_; void eval(const std::vector& inputs, array& out); }; class ArcCos : public Primitive { public: explicit ArcCos(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(ArcCos) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class ArcCosh : public Primitive { public: explicit ArcCosh(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(ArcCosh) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class ArcSin : public Primitive { public: explicit ArcSin(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(ArcSin) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class ArcSinh : public Primitive { public: explicit ArcSinh(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(ArcSinh) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class ArcTan : public Primitive { public: explicit ArcTan(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(ArcTan) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class ArcTanh : public Primitive { public: explicit ArcTanh(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(ArcTanh) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class ArgPartition : public Primitive { public: explicit ArgPartition(Stream stream, int kth, int axis) : Primitive(stream), kth_(kth), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_PRINT(ArgPartition) bool is_equivalent(const Primitive& other) const override; private: int kth_; int axis_; void eval(const std::vector& inputs, array& out); }; class ArgReduce : public Primitive { public: enum ReduceType { ArgMin, ArgMax, }; explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis) : Primitive(stream), reduce_type_(reduce_type), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_PRINT(ArgReduce) bool is_equivalent(const Primitive& other) const override; private: ReduceType reduce_type_; int axis_; void eval(const std::vector& inputs, array& out); }; class ArgSort : public Primitive { public: explicit ArgSort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_PRINT(ArgSort) bool is_equivalent(const Primitive& other) const override; private: int axis_; void eval(const std::vector& inputs, array& out); }; class AsType : public Primitive { public: explicit AsType(Stream stream, Dtype dtype) : Primitive(stream), dtype_(dtype){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(AsType) bool is_equivalent(const Primitive& other) const override; private: Dtype dtype_; void eval(const std::vector& inputs, array& out); }; class AsStrided : public Primitive { public: explicit AsStrided( Stream stream, const std::vector& shape, const std::vector& strides, size_t offset) : Primitive(stream), shape_(shape), strides_(strides), offset_(offset){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_GRADS() DEFINE_PRINT(AsStrided) bool is_equivalent(const Primitive& other) const override; private: std::vector shape_; std::vector strides_; size_t offset_; void eval(const std::vector& inputs, array& out); }; class Broadcast : public Primitive { public: explicit Broadcast(Stream stream, const std::vector& shape) : Primitive(stream), shape_(shape){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Broadcast) bool is_equivalent(const Primitive& other) const override; private: std::vector shape_; void eval(const std::vector& inputs, array& out); }; class Concatenate : public Primitive { public: explicit Concatenate(Stream stream, int axis) : Primitive(stream), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Concatenate) bool is_equivalent(const Primitive& other) const override; private: int axis_; void eval(const std::vector& inputs, array& out); }; class Convolution : public Primitive { public: explicit Convolution( Stream stream, const std::vector& padding, const std::vector& kernel_strides, const std::vector& kernel_dilation, const std::vector& input_dilation) : Primitive(stream), padding_(padding), kernel_strides_(kernel_strides), kernel_dilation_(kernel_dilation), input_dilation_(input_dilation){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::vector vjp( const std::vector& primals, const array& cotan, const std::vector& argnums) override; DEFINE_PRINT(Convolution) bool is_equivalent(const Primitive& other) const override; private: std::vector padding_; std::vector kernel_strides_; std::vector kernel_dilation_; std::vector input_dilation_; void eval(const std::vector& inputs, array& out); }; class Copy : public Primitive { public: explicit Copy(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Copy) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Cos : public Primitive { public: explicit Cos(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Cos) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Cosh : public Primitive { public: explicit Cosh(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Cosh) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Divide : public Primitive { public: explicit Divide(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Divide) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Equal : public Primitive { public: explicit Equal(Stream stream, bool equal_nan = false) : Primitive(stream), equal_nan_(equal_nan){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Equal) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); bool equal_nan_; }; class Erf : public Primitive { public: explicit Erf(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Erf) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class ErfInv : public Primitive { public: explicit ErfInv(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(ErfInv) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Exp : public Primitive { public: explicit Exp(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Exp) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class FFT : public Primitive { public: explicit FFT( Stream stream, const std::vector& axes, bool inverse, bool real) : Primitive(stream), axes_(axes), inverse_(inverse), real_(real){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(FFT) bool is_equivalent(const Primitive& other) const override; private: std::vector axes_; bool inverse_; bool real_; void eval(const std::vector& inputs, array& out); }; class Full : public Primitive { public: explicit Full(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Full) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Gather : public Primitive { public: explicit Gather( Stream stream, const std::vector& axes, const std::vector& slice_sizes) : Primitive(stream), axes_(axes), slice_sizes_(slice_sizes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Gather) bool is_equivalent(const Primitive& other) const override; private: void eval(const std::vector& inputs, array& out); std::vector axes_; std::vector slice_sizes_; }; class Greater : public Primitive { public: explicit Greater(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Greater) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class GreaterEqual : public Primitive { public: explicit GreaterEqual(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(GreaterEqual) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Less : public Primitive { public: explicit Less(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Less) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class LessEqual : public Primitive { public: explicit LessEqual(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(LessEqual) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Load : public Primitive { public: explicit Load( Stream stream, std::shared_ptr reader, size_t offset, bool swap_endianness = false) : Primitive(stream), reader_(reader), offset_(offset), swap_endianness_(swap_endianness){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_PRINT(Load) private: void eval(const std::vector& inputs, array& out); std::shared_ptr reader_; size_t offset_; bool swap_endianness_; }; class Log : public Primitive { public: enum Base { two, ten, e }; explicit Log(Stream stream, Base base) : Primitive(stream), base_(base){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Log) DEFINE_DEFAULT_IS_EQUIVALENT() private: Base base_; void eval(const std::vector& inputs, array& out); }; class Log1p : public Primitive { public: explicit Log1p(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Log1p) private: void eval(const std::vector& inputs, array& out); }; class LogicalNot : public Primitive { public: explicit LogicalNot(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(LogicalNot) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class LogAddExp : public Primitive { public: explicit LogAddExp(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(LogAddExp) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Matmul : public Primitive { public: explicit Matmul(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::vector vjp( const std::vector& primals, const array& cotan, const std::vector& argnums) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_PRINT(Matmul) DEFINE_DEFAULT_IS_EQUIVALENT() }; class Maximum : public Primitive { public: explicit Maximum(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Maximum) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Minimum : public Primitive { public: explicit Minimum(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Minimum) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Multiply : public Primitive { public: explicit Multiply(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Multiply) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Negative : public Primitive { public: explicit Negative(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Negative) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class NotEqual : public Primitive { public: explicit NotEqual(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(NotEqual) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Pad : public Primitive { public: explicit Pad( Stream stream, const std::vector& axes, const std::vector& low_pad_size, const std::vector& high_pad_size) : Primitive(stream), axes_(axes), low_pad_size_(low_pad_size), high_pad_size_(high_pad_size){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Pad) bool is_equivalent(const Primitive& other) const override; private: std::vector axes_; std::vector low_pad_size_; std::vector high_pad_size_; void eval(const std::vector& inputs, array& out); }; class Partition : public Primitive { public: explicit Partition(Stream stream, int kth, int axis) : Primitive(stream), kth_(kth), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Partition) bool is_equivalent(const Primitive& other) const override; private: int kth_; int axis_; void eval(const std::vector& inputs, array& out); }; class Power : public Primitive { public: explicit Power(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Power) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class RandomBits : public Primitive { public: explicit RandomBits(Stream stream, const std::vector& shape, int width) : Primitive(stream), shape_(shape), width_(width){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_PRINT(RandomBits) bool is_equivalent(const Primitive& other) const override; private: std::vector shape_; int width_; void eval(const std::vector& inputs, array& out); }; class Reshape : public Primitive { public: explicit Reshape(Stream stream, const std::vector& shape) : Primitive(stream), shape_(shape){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Reshape) bool is_equivalent(const Primitive& other) const override; private: std::vector shape_; void eval(const std::vector& inputs, array& out); }; class Reduce : public Primitive { public: enum ReduceType { And, Or, Sum, Prod, Min, Max }; explicit Reduce( Stream stream, ReduceType reduce_type, const std::vector& axes) : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; std::vector vjp( const std::vector& primals, const array& cotan, const std::vector& argnums) override; void print(std::ostream& os) override { switch (reduce_type_) { case And: os << "And"; case Or: os << "And"; break; case Sum: os << "Sum"; break; case Prod: os << "Prod"; break; case Min: os << "Min"; break; case Max: os << "Max"; break; } os << " Reduce"; } bool is_equivalent(const Primitive& other) const override; private: ReduceType reduce_type_; std::vector axes_; void eval(const std::vector& inputs, array& out); }; class Scan : public Primitive { public: enum ReduceType { Max, Min, Sum, Prod }; explicit Scan( Stream stream, ReduceType reduce_type, int axis, bool reverse, bool inclusive) : Primitive(stream), reduce_type_(reduce_type), axis_(axis), reverse_(reverse), inclusive_(inclusive){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS(); void print(std::ostream& os) override { os << "Cum"; switch (reduce_type_) { case Sum: os << "Sum"; break; case Prod: os << "Prod"; break; case Min: os << "Min"; break; case Max: os << "Max"; break; } os << " Reduce"; } bool is_equivalent(const Primitive& other) const override; private: ReduceType reduce_type_; int axis_; bool reverse_; bool inclusive_; void eval(const std::vector& inputs, array& out); }; class Scatter : public Primitive { public: enum ReduceType { Max, Min, Sum, Prod, None }; explicit Scatter( Stream stream, ReduceType reduce_type, const std::vector& axes) : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_PRINT(Scatter) bool is_equivalent(const Primitive& other) const override; private: void eval(const std::vector& inputs, array& out); ReduceType reduce_type_; std::vector axes_; }; class Sigmoid : public Primitive { public: explicit Sigmoid(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Sigmoid) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Sign : public Primitive { public: explicit Sign(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Sign) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Sin : public Primitive { public: explicit Sin(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Sin) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Sinh : public Primitive { public: explicit Sinh(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Sinh) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Slice : public Primitive { public: explicit Slice( Stream stream, const std::vector& start_indices, const std::vector& end_indices, const std::vector& strides) : Primitive(stream), start_indices_(start_indices), end_indices_(end_indices), strides_(strides){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Slice) bool is_equivalent(const Primitive& other) const override; private: std::vector start_indices_; std::vector end_indices_; std::vector strides_; void eval(const std::vector& inputs, array& out); }; class Softmax : public Primitive { public: explicit Softmax(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Softmax) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Sort : public Primitive { public: explicit Sort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Sort) bool is_equivalent(const Primitive& other) const override; private: int axis_; void eval(const std::vector& inputs, array& out); }; class Square : public Primitive { public: explicit Square(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Square) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Sqrt : public Primitive { public: explicit Sqrt(Stream stream, bool recip = false) : Primitive(stream), recip_(recip){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Sqrt) bool is_equivalent(const Primitive& other) const override; private: void eval(const std::vector& inputs, array& out); bool recip_; }; class StopGradient : public Primitive { public: explicit StopGradient(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_PRINT(StopGradient) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Subtract : public Primitive { public: explicit Subtract(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Subtract) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Tan : public Primitive { public: explicit Tan(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Tan) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Tanh : public Primitive { public: explicit Tanh(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Tanh) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Uniform : public Primitive { public: explicit Uniform(Stream stream) : Primitive(stream){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_PRINT(Uniform) DEFINE_DEFAULT_IS_EQUIVALENT() private: void eval(const std::vector& inputs, array& out); }; class Transpose : public Primitive { public: explicit Transpose(Stream stream, const std::vector& axes) : Primitive(stream), axes_(axes){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; std::pair vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_GRADS() DEFINE_PRINT(Transpose) bool is_equivalent(const Primitive& other) const override; private: std::vector axes_; void eval(const std::vector& inputs, array& out); }; } // namespace mlx::core