MLX
 
Loading...
Searching...
No Matches
primitives.h File Reference
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/io/load.h"
#include "mlx/stream.h"

Go to the source code of this file.

Classes

class  mlx::core::Primitive
 
class  mlx::core::UnaryPrimitive
 
class  mlx::core::Abs
 
class  mlx::core::Add
 
class  mlx::core::AddMM
 
class  mlx::core::Arange
 
class  mlx::core::ArcCos
 
class  mlx::core::ArcCosh
 
class  mlx::core::ArcSin
 
class  mlx::core::ArcSinh
 
class  mlx::core::ArcTan
 
class  mlx::core::ArcTan2
 
class  mlx::core::ArcTanh
 
class  mlx::core::ArgPartition
 
class  mlx::core::ArgReduce
 
class  mlx::core::ArgSort
 
class  mlx::core::AsType
 
class  mlx::core::AsStrided
 
class  mlx::core::BitwiseBinary
 
class  mlx::core::BitwiseInvert
 
class  mlx::core::BlockMaskedMM
 
class  mlx::core::GatherMM
 
class  mlx::core::BroadcastAxes
 
class  mlx::core::Broadcast
 
class  mlx::core::Ceil
 
class  mlx::core::Compiled
 
class  mlx::core::Concatenate
 
class  mlx::core::Conjugate
 
class  mlx::core::Contiguous
 
class  mlx::core::Convolution
 
class  mlx::core::Copy
 
class  mlx::core::Cos
 
class  mlx::core::Cosh
 
class  mlx::core::CustomTransforms
 
class  mlx::core::Depends
 
class  mlx::core::Divide
 
class  mlx::core::DivMod
 
class  mlx::core::Select
 
class  mlx::core::Remainder
 
class  mlx::core::Equal
 
class  mlx::core::Erf
 
class  mlx::core::ErfInv
 
class  mlx::core::Exp
 
class  mlx::core::Expm1
 
class  mlx::core::ExpandDims
 
class  mlx::core::FFT
 
class  mlx::core::Flatten
 
class  mlx::core::Floor
 
class  mlx::core::Full
 
class  mlx::core::Gather
 
class  mlx::core::GatherAxis
 
class  mlx::core::Greater
 
class  mlx::core::GreaterEqual
 
class  mlx::core::Hadamard
 
class  mlx::core::Imag
 
class  mlx::core::Less
 
class  mlx::core::LessEqual
 
class  mlx::core::Load
 
class  mlx::core::Log
 
class  mlx::core::Log1p
 
class  mlx::core::LogicalNot
 
class  mlx::core::LogicalAnd
 
class  mlx::core::LogicalOr
 
class  mlx::core::LogAddExp
 
class  mlx::core::Matmul
 
class  mlx::core::Maximum
 
class  mlx::core::Minimum
 
class  mlx::core::Multiply
 
class  mlx::core::Negative
 
class  mlx::core::NotEqual
 
class  mlx::core::NumberOfElements
 
class  mlx::core::Pad
 
class  mlx::core::Partition
 
class  mlx::core::Power
 
class  mlx::core::QuantizedMatmul
 
class  mlx::core::GatherQMM
 
class  mlx::core::RandomBits
 
class  mlx::core::Real
 
class  mlx::core::Reshape
 
class  mlx::core::Reduce
 
class  mlx::core::Round
 
class  mlx::core::Scan
 
class  mlx::core::Scatter
 
class  mlx::core::ScatterAxis
 
class  mlx::core::Sigmoid
 
class  mlx::core::Sign
 
class  mlx::core::Sin
 
class  mlx::core::Sinh
 
class  mlx::core::Slice
 
class  mlx::core::SliceUpdate
 
class  mlx::core::DynamicSlice
 
class  mlx::core::DynamicSliceUpdate
 
class  mlx::core::Softmax
 
class  mlx::core::Sort
 
class  mlx::core::Split
 
class  mlx::core::Square
 
class  mlx::core::Sqrt
 
class  mlx::core::StopGradient
 
class  mlx::core::Subtract
 
class  mlx::core::Squeeze
 
class  mlx::core::Tan
 
class  mlx::core::Tanh
 
class  mlx::core::Unflatten
 
class  mlx::core::View
 
class  mlx::core::Transpose
 
class  mlx::core::QRF
 
class  mlx::core::SVD
 
class  mlx::core::Inverse
 
class  mlx::core::Cholesky
 
class  mlx::core::Eigh
 
class  mlx::core::LUF
 

Namespaces

namespace  mlx
 
namespace  mlx::core
 

Macros

#define DEFINE_VMAP()
 
#define DEFINE_GRADS()
 
#define DEFINE_PRINT(PRIMITIVE)
 
#define DEFINE_DEFAULT_IS_EQUIVALENT()
 
#define DEFINE_INPUT_OUTPUT_SHAPE()
 

Macro Definition Documentation

◆ DEFINE_DEFAULT_IS_EQUIVALENT

#define DEFINE_DEFAULT_IS_EQUIVALENT ( )
Value:
bool is_equivalent(const Primitive& other) const override { \
return true; \
}

◆ DEFINE_GRADS

#define DEFINE_GRADS ( )
Value:
std::vector<array> jvp( \
const std::vector<array>& primals, \
const std::vector<array>& tangents, \
const std::vector<int>& argnums) override; \
\
std::vector<array> vjp( \
const std::vector<array>& primals, \
const std::vector<array>& cotangents, \
const std::vector<int>& argnums, \
const std::vector<array>& outputs) override;

◆ DEFINE_INPUT_OUTPUT_SHAPE

#define DEFINE_INPUT_OUTPUT_SHAPE ( )
Value:
std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
override { \
return {inputs[0].shape()}; \
}

◆ DEFINE_PRINT

#define DEFINE_PRINT ( PRIMITIVE)
Value:
void print(std::ostream& os) override { \
os << #PRIMITIVE; \
}

◆ DEFINE_VMAP

#define DEFINE_VMAP ( )
Value:
virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
const std::vector<array>& inputs, const std::vector<int>& axes) \
override;