MLX
Loading...
Searching...
No Matches
Namespaces | Classes | Typedefs | Enumerations | Functions | Variables
mlx::core Namespace Reference

Namespaces

namespace  allocator
 
namespace  detail
 
namespace  distributed
 
namespace  fast
 
namespace  fft
 
namespace  io
 
namespace  linalg
 
namespace  metal
 
namespace  random
 
namespace  scheduler
 

Classes

struct  _MLX_BFloat16
 
struct  _MLX_Float16
 
class  Abs
 
class  Add
 
class  AddMM
 
class  Arange
 
class  ArcCos
 
class  ArcCosh
 
class  ArcSin
 
class  ArcSinh
 
class  ArcTan
 
class  ArcTan2
 
class  ArcTanh
 
class  ArgPartition
 
class  ArgReduce
 
class  ArgSort
 
class  array
 
class  AsStrided
 
class  AsType
 
class  BitwiseBinary
 
class  BlockMaskedMM
 
class  Broadcast
 
class  Ceil
 
class  Cholesky
 
class  Compiled
 
struct  complex128_t
 
struct  complex64_t
 
class  Concatenate
 
class  Conjugate
 
class  Convolution
 
class  Copy
 
class  Cos
 
class  Cosh
 
class  CustomVJP
 
class  Depends
 
struct  Device
 
class  Divide
 
class  DivMod
 
struct  Dtype
 
class  Equal
 
class  Erf
 
class  ErfInv
 
class  Event
 
class  Exp
 
class  Expm1
 
class  FFT
 
class  Floor
 
class  Full
 
class  Gather
 
class  GatherMM
 
class  GatherQMM
 
class  Greater
 
class  GreaterEqual
 
class  Inverse
 
class  Less
 
class  LessEqual
 
class  Load
 
class  Log
 
class  Log1p
 
class  LogAddExp
 
class  LogicalAnd
 
class  LogicalNot
 
class  LogicalOr
 
class  Matmul
 
class  Maximum
 
class  Minimum
 
class  Multiply
 
class  Negative
 
struct  NodeNamer
 
class  NotEqual
 
class  NumberOfElements
 
class  Pad
 
class  Partition
 
class  Power
 
class  Primitive
 
struct  PrintFormatter
 
class  QRF
 
class  QuantizedMatmul
 
class  RandomBits
 
class  Reduce
 
struct  ReductionPlan
 
class  Remainder
 
class  Reshape
 
class  Round
 
class  Scan
 
class  Scatter
 
class  Select
 
class  Sigmoid
 
class  Sign
 
class  Sin
 
class  Sinh
 
class  Slice
 
class  SliceUpdate
 
class  Softmax
 
class  Sort
 
class  Split
 
class  Sqrt
 
class  Square
 
class  StopGradient
 
struct  Stream
 
struct  StreamContext
 
class  Subtract
 
class  SVD
 
class  Tan
 
class  Tanh
 
class  Transpose
 
struct  TypeToDtype
 
class  UnaryPrimitive
 
class  Uniform
 
class  View
 

Typedefs

using deleter_t = std::function<void(allocator::Buffer)>
 
template<typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>
 
using GGUFMetaData
 
using GGUFLoad
 
using SafetensorsLoad
 
using ValueAndGradFn
 
using SimpleValueAndGradFn
 
typedef struct _MLX_Float16 float16_t
 
typedef struct _MLX_BFloat16 bfloat16_t
 
using StreamOrDevice = std::variant<std::monostate, Stream, Device>
 

Enumerations

enum class  CopyType { Scalar , Vector , General , GeneralGeneral }
 
enum  ReductionOpType {
  ContiguousAllReduce , ContiguousReduce , ContiguousStridedReduce , GeneralContiguousReduce ,
  GeneralStridedReduce , GeneralReduce
}
 
enum class  CompileMode { disabled , no_simplify , no_fuse , enabled }
 

Functions

BNNSDataType to_bnns_dtype (Dtype mlx_dtype)
 
void arange (const std::vector< array > &inputs, array &out, double start, double step)
 
bool is_static_cast (const Primitive &p)
 
std::string build_lib_name (const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &tape, const std::unordered_set< uintptr_t > &constant_ids)
 
std::string get_type_string (Dtype d)
 
template<typename T >
void print_float_constant (std::ostream &os, const array &x)
 
template<typename T >
void print_int_constant (std::ostream &os, const array &x)
 
template<typename T >
void print_complex_constant (std::ostream &os, const array &x)
 
void print_constant (std::ostream &os, const array &x)
 
bool is_scalar (const array &x)
 
bool compiled_check_contiguity (const std::vector< array > &inputs, const std::vector< int > &shape)
 
void compiled_allocate_outputs (const std::vector< array > &inputs, std::vector< array > &outputs, const std::vector< array > &inputs_, const std::unordered_set< uintptr_t > &constant_ids_, bool contiguous, bool move_buffers=false)
 
void copy (const array &src, array &dst, CopyType ctype)
 
void copy_inplace (const array &src, array &dst, CopyType ctype)
 
template<typename stride_t >
void copy_inplace (const array &src, array &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)
 
std::tuple< bool, int64_t, std::vector< int64_t > > prepare_slice (const array &in, std::vector< int > &start_indices, std::vector< int > &strides)
 
void shared_buffer_slice (const array &in, const std::vector< size_t > &out_strides, size_t data_offset, array &out)
 
template<typename stride_t >
stride_t elem_to_loc (int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
 
size_t elem_to_loc (int elem, const array &a)
 
template<typename stride_t >
std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > collapse_contiguous_dims (const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)
 
std::tuple< std::vector< int >, std::vector< std::vector< size_t > > > collapse_contiguous_dims (const std::vector< array > &xs)
 
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
auto collapse_contiguous_dims (Arrays &&... xs)
 
template<typename stride_t >
auto check_contiguity (const std::vector< int > &shape, const std::vector< stride_t > &strides)
 
void binary_op_gpu (const std::vector< array > &inputs, std::vector< array > &outputs, const std::string op, const Stream &s)
 
void binary_op_gpu (const std::vector< array > &inputs, array &out, const std::string op, const Stream &s)
 
void binary_op_gpu_inplace (const std::vector< array > &inputs, std::vector< array > &outputs, const std::string op, const Stream &s)
 
void binary_op_gpu_inplace (const std::vector< array > &inputs, array &out, const std::string op, const Stream &s)
 
template<typename stride_t >
void copy_gpu_inplace (const array &in, array &out, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &s)
 
void copy_gpu (const array &src, array &out, CopyType ctype, const Stream &s)
 
void copy_gpu (const array &src, array &out, CopyType ctype)
 
void copy_gpu_inplace (const array &src, array &out, CopyType ctype, const Stream &s)
 
void copy_gpu_inplace (const array &in, array &out, const std::vector< int64_t > &istride, int64_t ioffset, CopyType ctype, const Stream &s)
 
MTL::ComputePipelineState * get_arange_kernel (metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * get_unary_kernel (metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * get_binary_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
 
MTL::ComputePipelineState * get_binary_two_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
 
MTL::ComputePipelineState * get_ternary_kernel (metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * get_copy_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
 
MTL::ComputePipelineState * get_softmax_kernel (metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
 
MTL::ComputePipelineState * get_scan_kernel (metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
 
MTL::ComputePipelineState * get_sort_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
 
MTL::ComputePipelineState * get_mb_sort_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
 
MTL::ComputePipelineState * get_reduce_init_kernel (metal::Device &d, const std::string &kernel_name, const array &out)
 
MTL::ComputePipelineState * get_reduce_kernel (metal::Device &d, const std::string &kernel_name, const std::string &op_name, const array &in, const array &out)
 
MTL::ComputePipelineState * get_steel_gemm_fused_kernel (metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
 
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
 
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
 
MTL::ComputePipelineState * get_steel_gemm_masked_kernel (metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
 
MTL::ComputePipelineState * get_steel_conv_kernel (metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
 
MTL::ComputePipelineState * get_steel_conv_general_kernel (metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
 
MTL::ComputePipelineState * get_fft_kernel (metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const int tg_mem_size, const std::string &in_type, const std::string &out_type, int step, bool real, const metal::MTLFCList &func_consts)
 
void steel_matmul_conv_groups (const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, int groups, std::vector< array > &copies)
 
void steel_matmul (const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})
 
void all_reduce_dispatch (const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
 
void row_reduce_general_dispatch (const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
 
void strided_reduce_general_dispatch (const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
 
void slice_gpu (const array &in, array &out, std::vector< int > start_indices, std::vector< int > strides, const Stream &s)
 
void concatenate_gpu (const std::vector< array > &inputs, array &out, int axis, const Stream &s)
 
void pad_gpu (const array &in, const array &val, array &out, std::vector< int > axes, std::vector< int > low_pad_size, const Stream &s)
 
void ternary_op_gpu (const std::vector< array > &inputs, array &out, const std::string op, const Stream &s)
 
void ternary_op_gpu_inplace (const std::vector< array > &inputs, array &out, const std::string op, const Stream &s)
 
void unary_op_gpu (const std::vector< array > &inputs, array &out, const std::string op, const Stream &s)
 
void unary_op_gpu_inplace (const std::vector< array > &inputs, array &out, const std::string op, const Stream &s)
 
void disable_compile ()
 Globally disable compilation.
 
void enable_compile ()
 Globally enable compilation.
 
void set_compile_mode (CompileMode mode)
 Set the compiler mode to the given value.
 
const Device & default_device ()
 
void set_default_device (const Device &d)
 
bool operator== (const Device &lhs, const Device &rhs)
 
bool operator!= (const Device &lhs, const Device &rhs)
 
bool issubdtype (const Dtype &a, const Dtype &b)
 
bool issubdtype (const Dtype::Category &a, const Dtype &b)
 
bool issubdtype (const Dtype &a, const Dtype::Category &b)
 
bool issubdtype (const Dtype::Category &a, const Dtype::Category &b)
 
Dtype promote_types (const Dtype &t1, const Dtype &t2)
 
uint8_t size_of (const Dtype &t)
 
Dtype::Kind kindof (const Dtype &t)
 
std::string dtype_to_array_protocol (const Dtype &t)
 
Dtype dtype_from_array_protocol (std::string_view t)
 
void print_graph (std::ostream &os, const std::vector< array > &outputs)
 
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void print_graph (std::ostream &os, Arrays &&... outputs)
 
void export_to_dot (std::ostream &os, const std::vector< array > &outputs)
 
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void export_to_dot (std::ostream &os, Arrays &&... outputs)
 
void save (std::shared_ptr< io::Writer > out_stream, array a)
 Save array to out stream in .npy format.
 
void save (std::string file, array a)
 Save array to file in .npy format.
 
array load (std::shared_ptr< io::Reader > in_stream, StreamOrDevice s={})
 Load array from reader in .npy format.
 
array load (std::string file, StreamOrDevice s={})
 Load array from file in .npy format.
 
SafetensorsLoad load_safetensors (std::shared_ptr< io::Reader > in_stream, StreamOrDevice s={})
 Load array map from .safetensors file format.
 
SafetensorsLoad load_safetensors (const std::string &file, StreamOrDevice s={})
 
void save_safetensors (std::shared_ptr< io::Writer > in_stream, std::unordered_map< std::string, array >, std::unordered_map< std::string, std::string > metadata={})
 
void save_safetensors (std::string file, std::unordered_map< std::string, array >, std::unordered_map< std::string, std::string > metadata={})
 
GGUFLoad load_gguf (const std::string &file, StreamOrDevice s={})
 Load array map and metadata from .gguf file format.
 
void save_gguf (std::string file, std::unordered_map< std::string, array > array_map, std::unordered_map< std::string, GGUFMetaData > meta_data={})
 
std::vector< int > get_shape (const gguf_tensor &tensor)
 
void gguf_load_quantized (std::unordered_map< std::string, array > &a, const gguf_tensor &tensor)
 
array arange (double start, double stop, double step, Dtype dtype, StreamOrDevice s={})
 A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).
 
array arange (double start, double stop, double step, StreamOrDevice s={})
 
array arange (double start, double stop, Dtype dtype, StreamOrDevice s={})
 
array arange (double start, double stop, StreamOrDevice s={})
 
array arange (double stop, Dtype dtype, StreamOrDevice s={})
 
array arange (double stop, StreamOrDevice s={})
 
array arange (int start, int stop, int step, StreamOrDevice s={})
 
array arange (int start, int stop, StreamOrDevice s={})
 
array arange (int stop, StreamOrDevice s={})
 
array linspace (double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
 A 1D array of num evenly spaced numbers in the range [start, stop].
 
array astype (array a, Dtype dtype, StreamOrDevice s={})
 Convert an array to the given data type.
 
array as_strided (array a, std::vector< int > shape, std::vector< size_t > strides, size_t offset, StreamOrDevice s={})
 Create a view of an array with the given shape and strides.
 
array copy (array a, StreamOrDevice s={})
 Copy another array.
 
array full (std::vector< int > shape, array vals, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with the given value(s).
 
array full (std::vector< int > shape, array vals, StreamOrDevice s={})
 
template<typename T >
array full (std::vector< int > shape, T val, Dtype dtype, StreamOrDevice s={})
 
template<typename T >
array full (std::vector< int > shape, T val, StreamOrDevice s={})
 
array zeros (const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with zeros.
 
array zeros (const std::vector< int > &shape, StreamOrDevice s={})
 
array zeros_like (const array &a, StreamOrDevice s={})
 
array ones (const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with ones.
 
array ones (const std::vector< int > &shape, StreamOrDevice s={})
 
array ones_like (const array &a, StreamOrDevice s={})
 
array eye (int n, int m, int k, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
 
array eye (int n, Dtype dtype, StreamOrDevice s={})
 
array eye (int n, int m, StreamOrDevice s={})
 
array eye (int n, int m, int k, StreamOrDevice s={})
 
array eye (int n, StreamOrDevice s={})
 
array identity (int n, Dtype dtype, StreamOrDevice s={})
 Create a square matrix of shape (n,n) with ones on the main diagonal and zeros everywhere else.
 
array identity (int n, StreamOrDevice s={})
 
array tri (int n, int m, int k, Dtype type, StreamOrDevice s={})
 
array tri (int n, Dtype type, StreamOrDevice s={})
 
array tril (array x, int k=0, StreamOrDevice s={})
 
array triu (array x, int k=0, StreamOrDevice s={})
 
array reshape (const array &a, std::vector< int > shape, StreamOrDevice s={})
 Reshape an array to the given shape.
 
array flatten (const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
 Flatten the dimensions in the range [start_axis, end_axis].
 
array flatten (const array &a, StreamOrDevice s={})
 Flatten the array to 1D.
 
array squeeze (const array &a, const std::vector< int > &axes, StreamOrDevice s={})
 Remove singleton dimensions at the given axes.
 
array squeeze (const array &a, int axis, StreamOrDevice s={})
 Remove singleton dimensions at the given axis.
 
array squeeze (const array &a, StreamOrDevice s={})
 Remove all singleton dimensions.
 
array expand_dims (const array &a, const std::vector< int > &axes, StreamOrDevice s={})
 Add a singleton dimension at the given axes.
 
array expand_dims (const array &a, int axis, StreamOrDevice s={})
 Add a singleton dimension at the given axis.
 
array slice (const array &a, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
 Slice an array.
 
array slice (const array &a, const std::vector< int > &start, const std::vector< int > &stop, StreamOrDevice s={})
 Slice an array with a stride of 1 in each dimension.
 
array slice_update (const array &src, const array &update, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
 Update a slice from the source array.
 
array slice_update (const array &src, const array &update, std::vector< int > start, std::vector< int > stop, StreamOrDevice s={})
 Update a slice from the source array with stride 1 in each dimension.
 
std::vector< array > split (const array &a, int num_splits, int axis, StreamOrDevice s={})
 Split an array into sub-arrays along a given axis.
 
std::vector< array > split (const array &a, int num_splits, StreamOrDevice s={})
 
std::vector< array > split (const array &a, const std::vector< int > &indices, int axis, StreamOrDevice s={})
 
std::vector< array > split (const array &a, const std::vector< int > &indices, StreamOrDevice s={})
 
std::vector< array > meshgrid (const std::vector< array > &arrays, bool sparse=false, std::string indexing="xy", StreamOrDevice s={})
 A vector of coordinate arrays from coordinate vectors.
 
array clip (const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
 Clip (limit) the values in an array.
 
array concatenate (const std::vector< array > &arrays, int axis, StreamOrDevice s={})
 Concatenate arrays along a given axis.
 
array concatenate (const std::vector< array > &arrays, StreamOrDevice s={})
 
array stack (const std::vector< array > &arrays, int axis, StreamOrDevice s={})
 Stack arrays along a new axis.
 
array stack (const std::vector< array > &arrays, StreamOrDevice s={})
 
array repeat (const array &arr, int repeats, int axis, StreamOrDevice s={})
 Repeat an array along an axis.
 
array repeat (const array &arr, int repeats, StreamOrDevice s={})
 
array tile (const array &arr, std::vector< int > reps, StreamOrDevice s={})
 
array transpose (const array &a, std::vector< int > axes, StreamOrDevice s={})
 Permutes the dimensions according to the given axes.
 
array transpose (const array &a, std::initializer_list< int > axes, StreamOrDevice s={})
 
array swapaxes (const array &a, int axis1, int axis2, StreamOrDevice s={})
 Swap two axes of an array.
 
array moveaxis (const array &a, int source, int destination, StreamOrDevice s={})
 Move an axis of an array.
 
array pad (const array &a, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size, const array &pad_value=array(0), StreamOrDevice s={})
 Pad an array with a constant value.
 
array pad (const array &a, const std::vector< std::pair< int, int > > &pad_width, const array &pad_value=array(0), StreamOrDevice s={})
 Pad an array with a constant value along all axes.
 
array pad (const array &a, const std::pair< int, int > &pad_width, const array &pad_value=array(0), StreamOrDevice s={})
 
array pad (const array &a, int pad_width, const array &pad_value=array(0), StreamOrDevice s={})
 
array transpose (const array &a, StreamOrDevice s={})
 Permutes the dimensions in reverse order.
 
array broadcast_to (const array &a, const std::vector< int > &shape, StreamOrDevice s={})
 Broadcast an array to a given shape.
 
std::vector< array > broadcast_arrays (const std::vector< array > &inputs, StreamOrDevice s={})
 Broadcast a vector of arrays against one another.
 
array equal (const array &a, const array &b, StreamOrDevice s={})
 Returns the bool array with (a == b) element-wise.
 
array operator== (const array &a, const array &b)
 
template<typename T >
array operator== (T a, const array &b)
 
template<typename T >
array operator== (const array &a, T b)
 
array not_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns the bool array with (a != b) element-wise.
 
array operator!= (const array &a, const array &b)
 
template<typename T >
array operator!= (T a, const array &b)
 
template<typename T >
array operator!= (const array &a, T b)
 
array greater (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a > b) element-wise.
 
array operator> (const array &a, const array &b)
 
template<typename T >
array operator> (T a, const array &b)
 
template<typename T >
array operator> (const array &a, T b)
 
array greater_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a >= b) element-wise.
 
array operator>= (const array &a, const array &b)
 
template<typename T >
array operator>= (T a, const array &b)
 
template<typename T >
array operator>= (const array &a, T b)
 
array less (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a < b) element-wise.
 
array operator< (const array &a, const array &b)
 
template<typename T >
array operator< (T a, const array &b)
 
template<typename T >
array operator< (const array &a, T b)
 
array less_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a <= b) element-wise.
 
array operator<= (const array &a, const array &b)
 
template<typename T >
array operator<= (T a, const array &b)
 
template<typename T >
array operator<= (const array &a, T b)
 
array array_equal (const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
 True if two arrays have the same shape and elements.
 
array array_equal (const array &a, const array &b, StreamOrDevice s={})
 
array isnan (const array &a, StreamOrDevice s={})
 
array isinf (const array &a, StreamOrDevice s={})
 
array isposinf (const array &a, StreamOrDevice s={})
 
array isneginf (const array &a, StreamOrDevice s={})
 
array where (const array &condition, const array &x, const array &y, StreamOrDevice s={})
 Select from x or y depending on condition.
 
array all (const array &a, bool keepdims, StreamOrDevice s={})
 True if all elements in the array are true (or non-zero).
 
array all (const array &a, StreamOrDevice s={})
 
array allclose (const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
 True if the two arrays are equal within the specified tolerance.
 
array isclose (const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
 Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
 
array all (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axes: an output value is true if all the corresponding input elements are true (or non-zero).
 
array all (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axis: an output value is true if all the corresponding input elements are true (or non-zero).
 
array any (const array &a, bool keepdims, StreamOrDevice s={})
 True if any elements in the array are true (or non-zero).
 
array any (const array &a, StreamOrDevice s={})
 
array any (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axes: an output value is true if any of the corresponding input elements are true (or non-zero).
 
array any (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axis: an output value is true if any of the corresponding input elements are true (or non-zero).
 
array sum (const array &a, bool keepdims, StreamOrDevice s={})
 Sums the elements of an array.
 
array sum (const array &a, StreamOrDevice s={})
 
array sum (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Sums the elements of an array along the given axes.
 
array sum (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Sums the elements of an array along the given axis.
 
array mean (const array &a, bool keepdims, StreamOrDevice s={})
 Computes the mean of the elements of an array.
 
array mean (const array &a, StreamOrDevice s={})
 
array mean (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Computes the mean of the elements of an array along the given axes.
 
array mean (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Computes the mean of the elements of an array along the given axis.
 
array var (const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array.
 
array var (const array &a, StreamOrDevice s={})
 
array var (const array &a, const std::vector< int > &axes, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array along the given axes.
 
array var (const array &a, int axis, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array along the given axis.
 
array std (const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array.
 
array std (const array &a, StreamOrDevice s={})
 
array std (const array &a, const std::vector< int > &axes, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array along the given axes.
 
array std (const array &a, int axis, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array along the given axis.
 
array prod (const array &a, bool keepdims, StreamOrDevice s={})
 The product of all elements of the array.
 
array prod (const array &a, StreamOrDevice s={})
 
array prod (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The product of the elements of an array along the given axes.
 
array prod (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The product of the elements of an array along the given axis.
 
array max (const array &a, bool keepdims, StreamOrDevice s={})
 The maximum of all elements of the array.
 
array max (const array &a, StreamOrDevice s={})
 
array max (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The maximum of the elements of an array along the given axes.
 
array max (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The maximum of the elements of an array along the given axis.
 
array min (const array &a, bool keepdims, StreamOrDevice s={})
 The minimum of all elements of the array.
 
array min (const array &a, StreamOrDevice s={})
 
array min (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The minimum of the elements of an array along the given axes.
 
array min (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The minimum of the elements of an array along the given axis.
 
array argmin (const array &a, bool keepdims, StreamOrDevice s={})
 Returns the index of the minimum value in the array.
 
array argmin (const array &a, StreamOrDevice s={})
 
array argmin (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Returns the indices of the minimum values along a given axis.
 
array argmax (const array &a, bool keepdims, StreamOrDevice s={})
 Returns the index of the maximum value in the array.
 
array argmax (const array &a, StreamOrDevice s={})
 
array argmax (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Returns the indices of the maximum values along a given axis.
 
array sort (const array &a, StreamOrDevice s={})
 Returns a sorted copy of the flattened array.
 
array sort (const array &a, int axis, StreamOrDevice s={})
 Returns a sorted copy of the array along a given axis.
 
array argsort (const array &a, StreamOrDevice s={})
 Returns indices that sort the flattened array.
 
array argsort (const array &a, int axis, StreamOrDevice s={})
 Returns indices that sort the array along a given axis.
 
array partition (const array &a, int kth, StreamOrDevice s={})
 Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
 
array partition (const array &a, int kth, int axis, StreamOrDevice s={})
 Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.
 
array argpartition (const array &a, int kth, StreamOrDevice s={})
 Returns indices that partition the flattened array such that the smaller kth elements are first.
 
array argpartition (const array &a, int kth, int axis, StreamOrDevice s={})
 Returns indices that partition the array along a given axis such that the smaller kth elements are first.
 
array topk (const array &a, int k, StreamOrDevice s={})
 Returns topk elements of the flattened array.
 
array topk (const array &a, int k, int axis, StreamOrDevice s={})
 Returns topk elements of the array along a given axis.
 
array logsumexp (const array &a, bool keepdims, StreamOrDevice s={})
 The logsumexp of all elements of the array.
 
array logsumexp (const array &a, StreamOrDevice s={})
 
array logsumexp (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The logsumexp of the elements of an array along the given axes.
 
array logsumexp (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The logsumexp of the elements of an array along the given axis.
 
array abs (const array &a, StreamOrDevice s={})
 Absolute value of elements in an array.
 
array negative (const array &a, StreamOrDevice s={})
 Negate an array.
 
array operator- (const array &a)
 
array sign (const array &a, StreamOrDevice s={})
 The sign of the elements in an array.
 
array logical_not (const array &a, StreamOrDevice s={})
 Logical not of an array.
 
array logical_and (const array &a, const array &b, StreamOrDevice s={})
 Logical and of two arrays.
 
array operator&& (const array &a, const array &b)
 
array logical_or (const array &a, const array &b, StreamOrDevice s={})
 Logical or of two arrays.
 
array operator|| (const array &a, const array &b)
 
array reciprocal (const array &a, StreamOrDevice s={})
 The reciprocal (1/x) of the elements in an array.
 
array add (const array &a, const array &b, StreamOrDevice s={})
 Add two arrays.
 
array operator+ (const array &a, const array &b)
 
template<typename T >
array operator+ (T a, const array &b)
 
template<typename T >
array operator+ (const array &a, T b)
 
array subtract (const array &a, const array &b, StreamOrDevice s={})
 Subtract two arrays.
 
array operator- (const array &a, const array &b)
 
template<typename T >
array operator- (T a, const array &b)
 
template<typename T >
array operator- (const array &a, T b)
 
array multiply (const array &a, const array &b, StreamOrDevice s={})
 Multiply two arrays.
 
array operator* (const array &a, const array &b)
 
template<typename T >
array operator* (T a, const array &b)
 
template<typename T >
array operator* (const array &a, T b)
 
array divide (const array &a, const array &b, StreamOrDevice s={})
 Divide two arrays.
 
array operator/ (const array &a, const array &b)
 
array operator/ (double a, const array &b)
 
array operator/ (const array &a, double b)
 
std::vector< array > divmod (const array &a, const array &b, StreamOrDevice s={})
 Compute the element-wise quotient and remainder.
 
array floor_divide (const array &a, const array &b, StreamOrDevice s={})
 Compute integer division.
 
array remainder (const array &a, const array &b, StreamOrDevice s={})
 Compute the element-wise remainder of division.
 
array operator% (const array &a, const array &b)
 
template<typename T >
array operator% (T a, const array &b)
 
template<typename T >
array operator% (const array &a, T b)
 
array maximum (const array &a, const array &b, StreamOrDevice s={})
 Element-wise maximum between two arrays.
 
array minimum (const array &a, const array &b, StreamOrDevice s={})
 Element-wise minimum between two arrays.
 
array floor (const array &a, StreamOrDevice s={})
 Floor the elements of an array.
 
array ceil (const array &a, StreamOrDevice s={})
 Ceil the elements of an array.
 
array square (const array &a, StreamOrDevice s={})
 Square the elements of an array.
 
array exp (const array &a, StreamOrDevice s={})
 Exponential of the elements of an array.
 
array sin (const array &a, StreamOrDevice s={})
 Sine of the elements of an array.
 
array cos (const array &a, StreamOrDevice s={})
 Cosine of the elements of an array.
 
array tan (const array &a, StreamOrDevice s={})
 Tangent of the elements of an array.
 
array arcsin (const array &a, StreamOrDevice s={})
 Arc Sine of the elements of an array.
 
array arccos (const array &a, StreamOrDevice s={})
 Arc Cosine of the elements of an array.
 
array arctan (const array &a, StreamOrDevice s={})
 Arc Tangent of the elements of an array.
 
array arctan2 (const array &a, const array &b, StreamOrDevice s={})
 Inverse tangent of the ratio of two arrays.
 
array sinh (const array &a, StreamOrDevice s={})
 Hyperbolic Sine of the elements of an array.
 
array cosh (const array &a, StreamOrDevice s={})
 Hyperbolic Cosine of the elements of an array.
 
array tanh (const array &a, StreamOrDevice s={})
 Hyperbolic Tangent of the elements of an array.
 
array arcsinh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Sine of the elements of an array.
 
array arccosh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Cosine of the elements of an array.
 
array arctanh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Tangent of the elements of an array.
 
array degrees (const array &a, StreamOrDevice s={})
 Convert the elements of an array from Radians to Degrees.
 
array radians (const array &a, StreamOrDevice s={})
 Convert the elements of an array from Degrees to Radians.
 
array log (const array &a, StreamOrDevice s={})
 Natural logarithm of the elements of an array.
 
array log2 (const array &a, StreamOrDevice s={})
 Log base 2 of the elements of an array.
 
array log10 (const array &a, StreamOrDevice s={})
 Log base 10 of the elements of an array.
 
array log1p (const array &a, StreamOrDevice s={})
 Natural logarithm of one plus elements in the array: log(1 + a).
 
array logaddexp (const array &a, const array &b, StreamOrDevice s={})
 Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).
 
array sigmoid (const array &a, StreamOrDevice s={})
 Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
 
array erf (const array &a, StreamOrDevice s={})
 Computes the error function of the elements of an array.
 
array erfinv (const array &a, StreamOrDevice s={})
 Computes the inverse error function of the elements of an array.
 
array expm1 (const array &a, StreamOrDevice s={})
 Computes the expm1 function of the elements of an array.
 
array stop_gradient (const array &a, StreamOrDevice s={})
 Stop the flow of gradients.
 
array round (const array &a, int decimals, StreamOrDevice s={})
 Round the elements of an array to the given number of decimals.
 
array round (const array &a, StreamOrDevice s={})
 
array matmul (const array &a, const array &b, StreamOrDevice s={})
 Matrix-matrix multiplication.
 
array gather (const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const std::vector< int > &slice_sizes, StreamOrDevice s={})
 Gather array entries given indices and slices.
 
array gather (const array &a, const array &indices, int axis, const std::vector< int > &slice_sizes, StreamOrDevice s={})
 
array take (const array &a, const array &indices, int axis, StreamOrDevice s={})
 Take array slices at the given indices of the specified axis.
 
array take (const array &a, const array &indices, StreamOrDevice s={})
 Take array entries at the given indices treating the array as flattened.
 
array take_along_axis (const array &a, const array &indices, int axis, StreamOrDevice s={})
 Take array entries given indices along the axis.
 
array scatter (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter updates to the given indices.
 
array scatter (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array scatter_add (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and add updates to given indices.
 
array scatter_add (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array scatter_prod (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and multiply updates to the given indices.
 
array scatter_prod (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array scatter_max (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and max updates to the given indices.
 
array scatter_max (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array scatter_min (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and min updates to given linear indices.
 
array scatter_min (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array sqrt (const array &a, StreamOrDevice s={})
 Square root of the elements of an array.
 
array rsqrt (const array &a, StreamOrDevice s={})
 Reciprocal of the square root of the elements of an array.
 
array softmax (const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
 Softmax of an array.
 
array softmax (const array &a, bool precise=false, StreamOrDevice s={})
 Softmax of an array.
 
array softmax (const array &a, int axis, bool precise=false, StreamOrDevice s={})
 Softmax of an array.
 
array power (const array &a, const array &b, StreamOrDevice s={})
 Raise elements of a to the power of b element-wise.
 
array cumsum (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative sum of an array.
 
array cumprod (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative product of an array.
 
array cummax (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative max of an array.
 
array cummin (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative min of an array.
 
array conv_general (array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
 General convolution with a filter.
 
array conv_general (const array &input, const array &weight, std::vector< int > stride={}, std::vector< int > padding={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
 General convolution with a filter.
 
array conv1d (const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
 1D convolution with a filter
 
array conv2d (const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
 2D convolution with a filter
 
array conv3d (const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
 3D convolution with a filter
 
array quantized_matmul (const array &x, const array &w, const array &scales, const array &biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
 Quantized matmul multiplies x with a quantized matrix w.
 
std::tuple< array, array, array > quantize (const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
 Quantize a matrix along its last axis.
 
array dequantize (const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
 Dequantize a matrix produced by quantize()
 
array gather_qmm (const array &x, const array &w, const array &scales, const array &biases, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
 Compute matrix products with matrix-level gather.
 
array tensordot (const array &a, const array &b, const int axis=2, StreamOrDevice s={})
 Returns a contraction of a and b over multiple dimensions.
 
array tensordot (const array &a, const array &b, const std::vector< int > &axes_a, const std::vector< int > &axes_b, StreamOrDevice s={})
 
array outer (const array &a, const array &b, StreamOrDevice s={})
 Compute the outer product of two vectors.
 
array inner (const array &a, const array &b, StreamOrDevice s={})
 Compute the inner product of two vectors.
 
array addmm (array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
 Compute D = beta * C + alpha * (A @ B)
 
array block_masked_mm (array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
 Compute matrix product with block masking.
 
array gather_mm (array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
 Compute matrix product with matrix-level gather.
 
array diagonal (const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
 Extract a diagonal or construct a diagonal array.
 
array diag (const array &a, int k=0, StreamOrDevice s={})
 Extract diagonal from a 2d array or create a diagonal matrix.
 
array trace (const array &a, int offset, int axis1, int axis2, Dtype dtype, StreamOrDevice s={})
 Return the sum along a specified diagonal in the given array.
 
array trace (const array &a, int offset, int axis1, int axis2, StreamOrDevice s={})
 
array trace (const array &a, StreamOrDevice s={})
 
std::vector< array > depends (const std::vector< array > &inputs, const std::vector< array > &dependencies)
 Implements the identity function but allows injecting dependencies to other arrays.
 
array atleast_1d (const array &a, StreamOrDevice s={})
 Convert an array to an array with at least one dimension.
 
std::vector< array > atleast_1d (const std::vector< array > &a, StreamOrDevice s={})
 
array atleast_2d (const array &a, StreamOrDevice s={})
 
std::vector< array > atleast_2d (const std::vector< array > &a, StreamOrDevice s={})
 
array atleast_3d (const array &a, StreamOrDevice s={})
 
std::vector< array > atleast_3d (const std::vector< array > &a, StreamOrDevice s={})
 
array number_of_elements (const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
 Extract the number of elements along some axes as a scalar array.
 
array conjugate (const array &a, StreamOrDevice s={})
 
array bitwise_and (const array &a, const array &b, StreamOrDevice s={})
 Bitwise and.
 
array operator& (const array &a, const array &b)
 
array bitwise_or (const array &a, const array &b, StreamOrDevice s={})
 Bitwise inclusive or.
 
array operator| (const array &a, const array &b)
 
array bitwise_xor (const array &a, const array &b, StreamOrDevice s={})
 Bitwise exclusive or.
 
array operator^ (const array &a, const array &b)
 
array left_shift (const array &a, const array &b, StreamOrDevice s={})
 Shift bits to the left.
 
array operator<< (const array &a, const array &b)
 
array right_shift (const array &a, const array &b, StreamOrDevice s={})
 Shift bits to the right.
 
array operator>> (const array &a, const array &b)
 
array view (const array &a, const Dtype &dtype, StreamOrDevice s={})
 
Stream default_stream (Device d)
 Get the default stream for the given device.
 
void set_default_stream (Stream s)
 Make the stream the default for its device.
 
Stream new_stream (Device d)
 Make a new stream on the given device.
 
bool operator== (const Stream &lhs, const Stream &rhs)
 
bool operator!= (const Stream &lhs, const Stream &rhs)
 
void synchronize ()
 
void synchronize (Stream)
 
void async_eval (std::vector< array > outputs)
 
void eval (std::vector< array > outputs)
 
template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void eval (Arrays &&... outputs)
 
std::pair< std::vector< array >, std::vector< array > > vjp (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &cotangents)
 Computes the output and vector-Jacobian product (VJP) of a function.
 
std::pair< array, array > vjp (const std::function< array(const array &)> &fun, const array &primal, const array &cotangent)
 Computes the output and vector-Jacobian product (VJP) of a unary function.
 
std::pair< std::vector< array >, std::vector< array > > jvp (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &primals, const std::vector< array > &tangents)
 Computes the output and Jacobian-vector product (JVP) of a function.
 
std::pair< array, array > jvp (const std::function< array(const array &)> &fun, const array &primal, const array &tangent)
 Computes the output and Jacobian-vector product (JVP) of a unary function.
 
ValueAndGradFn value_and_grad (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
 Returns a function which computes the value and gradient of the input function with respect to a vector of input arrays.
 
ValueAndGradFn value_and_grad (const std::function< std::vector< array >(const std::vector< array > &)> &fun, int argnum=0)
 Returns a function which computes the value and gradient of the input function with respect to a single input array.
 
SimpleValueAndGradFn value_and_grad (const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
 
SimpleValueAndGradFn value_and_grad (const std::function< array(const std::vector< array > &)> &fun, int argnum=0)
 
std::function< std::vector< array >(const std::vector< array > &)> grad (const std::function< array(const std::vector< array > &)> &fun, int argnum=0)
 Returns a function which computes the gradient of the input function with respect to a single input array.
 
std::function< array(const array &)> grad (const std::function< array(const array &)> &fun)
 Returns a function which computes the gradient of the unary input function.
 
std::function< array(const array &, const array &)> vmap (const std::function< array(const array &, const array &)> &fun, int in_axis_a=0, int in_axis_b=0, int out_axis=0)
 Automatically vectorize a binary function over the requested axes.
 
std::function< std::vector< array >(const std::vector< array > &)> vmap (const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< int > &in_axes={}, const std::vector< int > &out_axes={})
 Automatically vectorize a function over the requested axes.
 
_MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
float operator+ (_MLX_BFloat16 lhs, float rhs)
 
float operator+ (float lhs, _MLX_BFloat16 rhs)
 
double operator+ (_MLX_BFloat16 lhs, double rhs)
 
double operator+ (double lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, bool rhs)
 
_MLX_BFloat16 operator+ (bool lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, int32_t rhs)
 
_MLX_BFloat16 operator+ (int32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, uint32_t rhs)
 
_MLX_BFloat16 operator+ (uint32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, int64_t rhs)
 
_MLX_BFloat16 operator+ (int64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator+ (_MLX_BFloat16 lhs, uint64_t rhs)
 
_MLX_BFloat16 operator+ (uint64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator- (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
float operator- (_MLX_BFloat16 lhs, float rhs)
 
float operator- (float lhs, _MLX_BFloat16 rhs)
 
double operator- (_MLX_BFloat16 lhs, double rhs)
 
double operator- (double lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator- (_MLX_BFloat16 lhs, bool rhs)
 
_MLX_BFloat16 operator- (bool lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator- (_MLX_BFloat16 lhs, int32_t rhs)
 
_MLX_BFloat16 operator- (int32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator- (_MLX_BFloat16 lhs, uint32_t rhs)
 
_MLX_BFloat16 operator- (uint32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator- (_MLX_BFloat16 lhs, int64_t rhs)
 
_MLX_BFloat16 operator- (int64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator- (_MLX_BFloat16 lhs, uint64_t rhs)
 
_MLX_BFloat16 operator- (uint64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator* (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
float operator* (_MLX_BFloat16 lhs, float rhs)
 
float operator* (float lhs, _MLX_BFloat16 rhs)
 
double operator* (_MLX_BFloat16 lhs, double rhs)
 
double operator* (double lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator* (_MLX_BFloat16 lhs, bool rhs)
 
_MLX_BFloat16 operator* (bool lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator* (_MLX_BFloat16 lhs, int32_t rhs)
 
_MLX_BFloat16 operator* (int32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator* (_MLX_BFloat16 lhs, uint32_t rhs)
 
_MLX_BFloat16 operator* (uint32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator* (_MLX_BFloat16 lhs, int64_t rhs)
 
_MLX_BFloat16 operator* (int64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator* (_MLX_BFloat16 lhs, uint64_t rhs)
 
_MLX_BFloat16 operator* (uint64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
float operator/ (_MLX_BFloat16 lhs, float rhs)
 
float operator/ (float lhs, _MLX_BFloat16 rhs)
 
double operator/ (_MLX_BFloat16 lhs, double rhs)
 
double operator/ (double lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, bool rhs)
 
_MLX_BFloat16 operator/ (bool lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, int32_t rhs)
 
_MLX_BFloat16 operator/ (int32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, uint32_t rhs)
 
_MLX_BFloat16 operator/ (uint32_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, int64_t rhs)
 
_MLX_BFloat16 operator/ (int64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator/ (_MLX_BFloat16 lhs, uint64_t rhs)
 
_MLX_BFloat16 operator/ (uint64_t lhs, _MLX_BFloat16 rhs)
 
bool operator> (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
bool operator> (_MLX_BFloat16 lhs, float rhs)
 
bool operator> (float lhs, _MLX_BFloat16 rhs)
 
bool operator> (_MLX_BFloat16 lhs, double rhs)
 
bool operator> (double lhs, _MLX_BFloat16 rhs)
 
bool operator> (_MLX_BFloat16 lhs, int32_t rhs)
 
bool operator> (int32_t lhs, _MLX_BFloat16 rhs)
 
bool operator> (_MLX_BFloat16 lhs, uint32_t rhs)
 
bool operator> (uint32_t lhs, _MLX_BFloat16 rhs)
 
bool operator> (_MLX_BFloat16 lhs, int64_t rhs)
 
bool operator> (int64_t lhs, _MLX_BFloat16 rhs)
 
bool operator> (_MLX_BFloat16 lhs, uint64_t rhs)
 
bool operator> (uint64_t lhs, _MLX_BFloat16 rhs)
 
bool operator< (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
bool operator< (_MLX_BFloat16 lhs, float rhs)
 
bool operator< (float lhs, _MLX_BFloat16 rhs)
 
bool operator< (_MLX_BFloat16 lhs, double rhs)
 
bool operator< (double lhs, _MLX_BFloat16 rhs)
 
bool operator< (_MLX_BFloat16 lhs, int32_t rhs)
 
bool operator< (int32_t lhs, _MLX_BFloat16 rhs)
 
bool operator< (_MLX_BFloat16 lhs, uint32_t rhs)
 
bool operator< (uint32_t lhs, _MLX_BFloat16 rhs)
 
bool operator< (_MLX_BFloat16 lhs, int64_t rhs)
 
bool operator< (int64_t lhs, _MLX_BFloat16 rhs)
 
bool operator< (_MLX_BFloat16 lhs, uint64_t rhs)
 
bool operator< (uint64_t lhs, _MLX_BFloat16 rhs)
 
bool operator>= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
bool operator>= (_MLX_BFloat16 lhs, float rhs)
 
bool operator>= (float lhs, _MLX_BFloat16 rhs)
 
bool operator>= (_MLX_BFloat16 lhs, double rhs)
 
bool operator>= (double lhs, _MLX_BFloat16 rhs)
 
bool operator>= (_MLX_BFloat16 lhs, int32_t rhs)
 
bool operator>= (int32_t lhs, _MLX_BFloat16 rhs)
 
bool operator>= (_MLX_BFloat16 lhs, uint32_t rhs)
 
bool operator>= (uint32_t lhs, _MLX_BFloat16 rhs)
 
bool operator>= (_MLX_BFloat16 lhs, int64_t rhs)
 
bool operator>= (int64_t lhs, _MLX_BFloat16 rhs)
 
bool operator>= (_MLX_BFloat16 lhs, uint64_t rhs)
 
bool operator>= (uint64_t lhs, _MLX_BFloat16 rhs)
 
bool operator<= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
bool operator<= (_MLX_BFloat16 lhs, float rhs)
 
bool operator<= (float lhs, _MLX_BFloat16 rhs)
 
bool operator<= (_MLX_BFloat16 lhs, double rhs)
 
bool operator<= (double lhs, _MLX_BFloat16 rhs)
 
bool operator<= (_MLX_BFloat16 lhs, int32_t rhs)
 
bool operator<= (int32_t lhs, _MLX_BFloat16 rhs)
 
bool operator<= (_MLX_BFloat16 lhs, uint32_t rhs)
 
bool operator<= (uint32_t lhs, _MLX_BFloat16 rhs)
 
bool operator<= (_MLX_BFloat16 lhs, int64_t rhs)
 
bool operator<= (int64_t lhs, _MLX_BFloat16 rhs)
 
bool operator<= (_MLX_BFloat16 lhs, uint64_t rhs)
 
bool operator<= (uint64_t lhs, _MLX_BFloat16 rhs)
 
bool operator== (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
bool operator== (_MLX_BFloat16 lhs, float rhs)
 
bool operator== (float lhs, _MLX_BFloat16 rhs)
 
bool operator== (_MLX_BFloat16 lhs, double rhs)
 
bool operator== (double lhs, _MLX_BFloat16 rhs)
 
bool operator== (_MLX_BFloat16 lhs, int32_t rhs)
 
bool operator== (int32_t lhs, _MLX_BFloat16 rhs)
 
bool operator== (_MLX_BFloat16 lhs, uint32_t rhs)
 
bool operator== (uint32_t lhs, _MLX_BFloat16 rhs)
 
bool operator== (_MLX_BFloat16 lhs, int64_t rhs)
 
bool operator== (int64_t lhs, _MLX_BFloat16 rhs)
 
bool operator== (_MLX_BFloat16 lhs, uint64_t rhs)
 
bool operator== (uint64_t lhs, _MLX_BFloat16 rhs)
 
bool operator!= (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
bool operator!= (_MLX_BFloat16 lhs, float rhs)
 
bool operator!= (float lhs, _MLX_BFloat16 rhs)
 
bool operator!= (_MLX_BFloat16 lhs, double rhs)
 
bool operator!= (double lhs, _MLX_BFloat16 rhs)
 
bool operator!= (_MLX_BFloat16 lhs, int32_t rhs)
 
bool operator!= (int32_t lhs, _MLX_BFloat16 rhs)
 
bool operator!= (_MLX_BFloat16 lhs, uint32_t rhs)
 
bool operator!= (uint32_t lhs, _MLX_BFloat16 rhs)
 
bool operator!= (_MLX_BFloat16 lhs, int64_t rhs)
 
bool operator!= (int64_t lhs, _MLX_BFloat16 rhs)
 
bool operator!= (_MLX_BFloat16 lhs, uint64_t rhs)
 
bool operator!= (uint64_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator- (_MLX_BFloat16 lhs)
 
_MLX_BFloat16 & operator+= (_MLX_BFloat16 &lhs, const float &rhs)
 
float & operator+= (float &lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 & operator-= (_MLX_BFloat16 &lhs, const float &rhs)
 
float & operator-= (float &lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 & operator*= (_MLX_BFloat16 &lhs, const float &rhs)
 
float & operator*= (float &lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 & operator/= (_MLX_BFloat16 &lhs, const float &rhs)
 
float & operator/= (float &lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator| (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator| (_MLX_BFloat16 lhs, uint16_t rhs)
 
_MLX_BFloat16 operator| (uint16_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator& (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator& (_MLX_BFloat16 lhs, uint16_t rhs)
 
_MLX_BFloat16 operator& (uint16_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator^ (_MLX_BFloat16 lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 operator^ (_MLX_BFloat16 lhs, uint16_t rhs)
 
_MLX_BFloat16 operator^ (uint16_t lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 & operator|= (_MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 & operator|= (_MLX_BFloat16 &lhs, uint16_t rhs)
 
_MLX_BFloat16 & operator&= (_MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 & operator&= (_MLX_BFloat16 &lhs, uint16_t rhs)
 
_MLX_BFloat16 & operator^= (_MLX_BFloat16 &lhs, _MLX_BFloat16 rhs)
 
_MLX_BFloat16 & operator^= (_MLX_BFloat16 &lhs, uint16_t rhs)
 
bool operator>= (const complex64_t &a, const complex64_t &b)
 
bool operator> (const complex64_t &a, const complex64_t &b)
 
complex64_t operator% (complex64_t a, complex64_t b)
 
bool operator<= (const complex64_t &a, const complex64_t &b)
 
bool operator< (const complex64_t &a, const complex64_t &b)
 
complex64_t operator- (const complex64_t &v)
 
complex64_t operator+ (const std::complex< float > &x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, const std::complex< float > &y)
 
complex64_t operator+ (const complex64_t &x, const complex64_t &y)
 
complex64_t operator+ (bool x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, bool y)
 
complex64_t operator+ (uint32_t x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, uint32_t y)
 
complex64_t operator+ (uint64_t x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, uint64_t y)
 
complex64_t operator+ (int32_t x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, int32_t y)
 
complex64_t operator+ (int64_t x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, int64_t y)
 
complex64_t operator+ (float16_t x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, float16_t y)
 
complex64_t operator+ (bfloat16_t x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, bfloat16_t y)
 
complex64_t operator+ (float x, const complex64_t &y)
 
complex64_t operator+ (const complex64_t &x, float y)
 
_MLX_Float16 operator+ (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
float operator+ (_MLX_Float16 lhs, float rhs)
 
float operator+ (float lhs, _MLX_Float16 rhs)
 
double operator+ (_MLX_Float16 lhs, double rhs)
 
double operator+ (double lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator+ (_MLX_Float16 lhs, bool rhs)
 
_MLX_Float16 operator+ (bool lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator+ (_MLX_Float16 lhs, int32_t rhs)
 
_MLX_Float16 operator+ (int32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator+ (_MLX_Float16 lhs, uint32_t rhs)
 
_MLX_Float16 operator+ (uint32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator+ (_MLX_Float16 lhs, int64_t rhs)
 
_MLX_Float16 operator+ (int64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator+ (_MLX_Float16 lhs, uint64_t rhs)
 
_MLX_Float16 operator+ (uint64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator- (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
float operator- (_MLX_Float16 lhs, float rhs)
 
float operator- (float lhs, _MLX_Float16 rhs)
 
double operator- (_MLX_Float16 lhs, double rhs)
 
double operator- (double lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator- (_MLX_Float16 lhs, bool rhs)
 
_MLX_Float16 operator- (bool lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator- (_MLX_Float16 lhs, int32_t rhs)
 
_MLX_Float16 operator- (int32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator- (_MLX_Float16 lhs, uint32_t rhs)
 
_MLX_Float16 operator- (uint32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator- (_MLX_Float16 lhs, int64_t rhs)
 
_MLX_Float16 operator- (int64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator- (_MLX_Float16 lhs, uint64_t rhs)
 
_MLX_Float16 operator- (uint64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator* (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
float operator* (_MLX_Float16 lhs, float rhs)
 
float operator* (float lhs, _MLX_Float16 rhs)
 
double operator* (_MLX_Float16 lhs, double rhs)
 
double operator* (double lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator* (_MLX_Float16 lhs, bool rhs)
 
_MLX_Float16 operator* (bool lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator* (_MLX_Float16 lhs, int32_t rhs)
 
_MLX_Float16 operator* (int32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator* (_MLX_Float16 lhs, uint32_t rhs)
 
_MLX_Float16 operator* (uint32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator* (_MLX_Float16 lhs, int64_t rhs)
 
_MLX_Float16 operator* (int64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator* (_MLX_Float16 lhs, uint64_t rhs)
 
_MLX_Float16 operator* (uint64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator/ (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
float operator/ (_MLX_Float16 lhs, float rhs)
 
float operator/ (float lhs, _MLX_Float16 rhs)
 
double operator/ (_MLX_Float16 lhs, double rhs)
 
double operator/ (double lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator/ (_MLX_Float16 lhs, bool rhs)
 
_MLX_Float16 operator/ (bool lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator/ (_MLX_Float16 lhs, int32_t rhs)
 
_MLX_Float16 operator/ (int32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator/ (_MLX_Float16 lhs, uint32_t rhs)
 
_MLX_Float16 operator/ (uint32_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator/ (_MLX_Float16 lhs, int64_t rhs)
 
_MLX_Float16 operator/ (int64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator/ (_MLX_Float16 lhs, uint64_t rhs)
 
_MLX_Float16 operator/ (uint64_t lhs, _MLX_Float16 rhs)
 
bool operator> (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
bool operator> (_MLX_Float16 lhs, float rhs)
 
bool operator> (float lhs, _MLX_Float16 rhs)
 
bool operator> (_MLX_Float16 lhs, double rhs)
 
bool operator> (double lhs, _MLX_Float16 rhs)
 
bool operator> (_MLX_Float16 lhs, int32_t rhs)
 
bool operator> (int32_t lhs, _MLX_Float16 rhs)
 
bool operator> (_MLX_Float16 lhs, uint32_t rhs)
 
bool operator> (uint32_t lhs, _MLX_Float16 rhs)
 
bool operator> (_MLX_Float16 lhs, int64_t rhs)
 
bool operator> (int64_t lhs, _MLX_Float16 rhs)
 
bool operator> (_MLX_Float16 lhs, uint64_t rhs)
 
bool operator> (uint64_t lhs, _MLX_Float16 rhs)
 
bool operator< (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
bool operator< (_MLX_Float16 lhs, float rhs)
 
bool operator< (float lhs, _MLX_Float16 rhs)
 
bool operator< (_MLX_Float16 lhs, double rhs)
 
bool operator< (double lhs, _MLX_Float16 rhs)
 
bool operator< (_MLX_Float16 lhs, int32_t rhs)
 
bool operator< (int32_t lhs, _MLX_Float16 rhs)
 
bool operator< (_MLX_Float16 lhs, uint32_t rhs)
 
bool operator< (uint32_t lhs, _MLX_Float16 rhs)
 
bool operator< (_MLX_Float16 lhs, int64_t rhs)
 
bool operator< (int64_t lhs, _MLX_Float16 rhs)
 
bool operator< (_MLX_Float16 lhs, uint64_t rhs)
 
bool operator< (uint64_t lhs, _MLX_Float16 rhs)
 
bool operator>= (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
bool operator>= (_MLX_Float16 lhs, float rhs)
 
bool operator>= (float lhs, _MLX_Float16 rhs)
 
bool operator>= (_MLX_Float16 lhs, double rhs)
 
bool operator>= (double lhs, _MLX_Float16 rhs)
 
bool operator>= (_MLX_Float16 lhs, int32_t rhs)
 
bool operator>= (int32_t lhs, _MLX_Float16 rhs)
 
bool operator>= (_MLX_Float16 lhs, uint32_t rhs)
 
bool operator>= (uint32_t lhs, _MLX_Float16 rhs)
 
bool operator>= (_MLX_Float16 lhs, int64_t rhs)
 
bool operator>= (int64_t lhs, _MLX_Float16 rhs)
 
bool operator>= (_MLX_Float16 lhs, uint64_t rhs)
 
bool operator>= (uint64_t lhs, _MLX_Float16 rhs)
 
bool operator<= (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
bool operator<= (_MLX_Float16 lhs, float rhs)
 
bool operator<= (float lhs, _MLX_Float16 rhs)
 
bool operator<= (_MLX_Float16 lhs, double rhs)
 
bool operator<= (double lhs, _MLX_Float16 rhs)
 
bool operator<= (_MLX_Float16 lhs, int32_t rhs)
 
bool operator<= (int32_t lhs, _MLX_Float16 rhs)
 
bool operator<= (_MLX_Float16 lhs, uint32_t rhs)
 
bool operator<= (uint32_t lhs, _MLX_Float16 rhs)
 
bool operator<= (_MLX_Float16 lhs, int64_t rhs)
 
bool operator<= (int64_t lhs, _MLX_Float16 rhs)
 
bool operator<= (_MLX_Float16 lhs, uint64_t rhs)
 
bool operator<= (uint64_t lhs, _MLX_Float16 rhs)
 
bool operator== (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
bool operator== (_MLX_Float16 lhs, float rhs)
 
bool operator== (float lhs, _MLX_Float16 rhs)
 
bool operator== (_MLX_Float16 lhs, double rhs)
 
bool operator== (double lhs, _MLX_Float16 rhs)
 
bool operator== (_MLX_Float16 lhs, int32_t rhs)
 
bool operator== (int32_t lhs, _MLX_Float16 rhs)
 
bool operator== (_MLX_Float16 lhs, uint32_t rhs)
 
bool operator== (uint32_t lhs, _MLX_Float16 rhs)
 
bool operator== (_MLX_Float16 lhs, int64_t rhs)
 
bool operator== (int64_t lhs, _MLX_Float16 rhs)
 
bool operator== (_MLX_Float16 lhs, uint64_t rhs)
 
bool operator== (uint64_t lhs, _MLX_Float16 rhs)
 
bool operator!= (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
bool operator!= (_MLX_Float16 lhs, float rhs)
 
bool operator!= (float lhs, _MLX_Float16 rhs)
 
bool operator!= (_MLX_Float16 lhs, double rhs)
 
bool operator!= (double lhs, _MLX_Float16 rhs)
 
bool operator!= (_MLX_Float16 lhs, int32_t rhs)
 
bool operator!= (int32_t lhs, _MLX_Float16 rhs)
 
bool operator!= (_MLX_Float16 lhs, uint32_t rhs)
 
bool operator!= (uint32_t lhs, _MLX_Float16 rhs)
 
bool operator!= (_MLX_Float16 lhs, int64_t rhs)
 
bool operator!= (int64_t lhs, _MLX_Float16 rhs)
 
bool operator!= (_MLX_Float16 lhs, uint64_t rhs)
 
bool operator!= (uint64_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator- (_MLX_Float16 lhs)
 
_MLX_Float16 & operator+= (_MLX_Float16 &lhs, const float &rhs)
 
float & operator+= (float &lhs, _MLX_Float16 rhs)
 
_MLX_Float16 & operator-= (_MLX_Float16 &lhs, const float &rhs)
 
float & operator-= (float &lhs, _MLX_Float16 rhs)
 
_MLX_Float16 & operator*= (_MLX_Float16 &lhs, const float &rhs)
 
float & operator*= (float &lhs, _MLX_Float16 rhs)
 
_MLX_Float16 & operator/= (_MLX_Float16 &lhs, const float &rhs)
 
float & operator/= (float &lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator| (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator| (_MLX_Float16 lhs, uint16_t rhs)
 
_MLX_Float16 operator| (uint16_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator& (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator& (_MLX_Float16 lhs, uint16_t rhs)
 
_MLX_Float16 operator& (uint16_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator^ (_MLX_Float16 lhs, _MLX_Float16 rhs)
 
_MLX_Float16 operator^ (_MLX_Float16 lhs, uint16_t rhs)
 
_MLX_Float16 operator^ (uint16_t lhs, _MLX_Float16 rhs)
 
_MLX_Float16 & operator|= (_MLX_Float16 &lhs, _MLX_Float16 rhs)
 
_MLX_Float16 & operator|= (_MLX_Float16 &lhs, uint16_t rhs)
 
_MLX_Float16 & operator&= (_MLX_Float16 &lhs, _MLX_Float16 rhs)
 
_MLX_Float16 & operator&= (_MLX_Float16 &lhs, uint16_t rhs)
 
_MLX_Float16 & operator^= (_MLX_Float16 &lhs, _MLX_Float16 rhs)
 
_MLX_Float16 & operator^= (_MLX_Float16 &lhs, uint16_t rhs)
 
float operator+ (float16_t lhs, bfloat16_t rhs)
 
float operator+ (bfloat16_t lhs, float16_t rhs)
 
float operator- (float16_t lhs, bfloat16_t rhs)
 
float operator- (bfloat16_t lhs, float16_t rhs)
 
float operator* (float16_t lhs, bfloat16_t rhs)
 
float operator* (bfloat16_t lhs, float16_t rhs)
 
float operator/ (float16_t lhs, bfloat16_t rhs)
 
float operator/ (bfloat16_t lhs, float16_t rhs)
 
Stream to_stream (StreamOrDevice s)
 
Dtype result_type (const array &a, const array &b)
 The type from promoting the arrays' types with one another.
 
Dtype result_type (const array &a, const array &b, const array &c)
 
Dtype result_type (const std::vector< array > &arrays)
 
std::vector< int > broadcast_shapes (const std::vector< int > &s1, const std::vector< int > &s2)
 
bool is_same_shape (const std::vector< array > &arrays)
 
template<typename T >
int check_shape_dim (const T dim)
 Returns the shape dimension if it's within allowed range.
 
bool is_big_endian ()
 
int normalize_axis (int axis, int ndim)
 Returns the axis normalized to be in the range [0, ndim).
 
std::ostream & operator<< (std::ostream &os, const Device &d)
 
std::ostream & operator<< (std::ostream &os, const Stream &s)
 
std::ostream & operator<< (std::ostream &os, const Dtype &d)
 
std::ostream & operator<< (std::ostream &os, const Dtype::Kind &k)
 
std::ostream & operator<< (std::ostream &os, array a)
 
std::ostream & operator<< (std::ostream &os, const std::vector< int > &v)
 
std::ostream & operator<< (std::ostream &os, const std::vector< size_t > &v)
 
std::ostream & operator<< (std::ostream &os, const std::vector< int64_t > &v)
 
std::ostream & operator<< (std::ostream &os, const complex64_t &v)
 
std::ostream & operator<< (std::ostream &os, const float16_t &v)
 
std::ostream & operator<< (std::ostream &os, const bfloat16_t &v)
 

Variables

template<typename T >
constexpr bool is_array_v
 
template<typename... T>
constexpr bool is_arrays_v = (is_array_v<T> && ...)
 
std::function< std::vector< array >(const std::vector< array > &)> compile (const std::function< std::vector< array >(const std::vector< array > &)> &fun, bool shapeless=false)
 Compile takes a function and returns a compiled function.
 
constexpr Dtype bool_ {Dtype::Val::bool_, sizeof(bool)}
 
constexpr Dtype uint8 {Dtype::Val::uint8, sizeof(uint8_t)}
 
constexpr Dtype uint16 {Dtype::Val::uint16, sizeof(uint16_t)}
 
constexpr Dtype uint32 {Dtype::Val::uint32, sizeof(uint32_t)}
 
constexpr Dtype uint64 {Dtype::Val::uint64, sizeof(uint64_t)}
 
constexpr Dtype int8 {Dtype::Val::int8, sizeof(int8_t)}
 
constexpr Dtype int16 {Dtype::Val::int16, sizeof(int16_t)}
 
constexpr Dtype int32 {Dtype::Val::int32, sizeof(int32_t)}
 
constexpr Dtype int64 {Dtype::Val::int64, sizeof(int64_t)}
 
constexpr Dtype float16 {Dtype::Val::float16, sizeof(uint16_t)}
 
constexpr Dtype float32 {Dtype::Val::float32, sizeof(float)}
 
constexpr Dtype bfloat16 {Dtype::Val::bfloat16, sizeof(uint16_t)}
 
constexpr Dtype complex64 {Dtype::Val::complex64, sizeof(complex64_t)}
 
constexpr Dtype::Category complexfloating = Dtype::Category::complexfloating
 
constexpr Dtype::Category floating = Dtype::Category::floating
 
constexpr Dtype::Category inexact = Dtype::Category::inexact
 
constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger
 
constexpr Dtype::Category unsignedinteger = Dtype::Category::unsignedinteger
 
constexpr Dtype::Category integer = Dtype::Category::integer
 
constexpr Dtype::Category number = Dtype::Category::number
 
constexpr Dtype::Category generic = Dtype::Category::generic
 
std::function< std::pair< array, array >(const array &)> value_and_grad (const std::function< array(const array &)> &fun)
 Returns a function which computes the value and gradient of the unary input function.
 
std::function< std::vector< array >(const std::vector< array > &)> grad (const std::function< array(const std::vector< array > &)> &fun, const std::vector< int > &argnums)
 Returns a function which computes the gradient of the input function with respect to a vector of input arrays.
 
std::function< array(const array &)> vmap (const std::function< array(const array &)> &fun, int in_axis=0, int out_axis=0)
 Automatically vectorize a unary function over the requested axes.
 
std::function< std::vector< array >(const std::vector< array > &)> custom_vjp (std::function< std::vector< array >(const std::vector< array > &)> fun, std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp)
 Returns a function that computes the results of calling fun on its arguments; if the vector-Jacobian product (VJP) of those results is computed, it is computed by fun_vjp instead.
 
std::function< std::vector< array >(const std::vector< array > &)> checkpoint (std::function< std::vector< array >(const std::vector< array > &)> fun)
 Checkpoint the gradient of a function.
 
template<typename T >
constexpr bool can_convert_to_complex128
 
template<typename T >
constexpr bool can_convert_to_complex64
 
PrintFormatter global_formatter
 

Typedef Documentation

◆ bfloat16_t

◆ deleter_t

using mlx::core::deleter_t = std::function<void(allocator::Buffer)>

◆ enable_for_arrays_t

template<typename... T>
using mlx::core::enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>

◆ float16_t

◆ GGUFLoad

Initial value:
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, GGUFMetaData>>

◆ GGUFMetaData

Initial value:
std::variant<std::monostate, array, std::string, std::vector<std::string>>

◆ SafetensorsLoad

Initial value:
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string>>

◆ SimpleValueAndGradFn

Initial value:
std::function<std::pair<array, std::vector<array>>(
const std::vector<array>&)>

◆ StreamOrDevice

using mlx::core::StreamOrDevice = std::variant<std::monostate, Stream, Device>

◆ ValueAndGradFn

Initial value:
std::function<std::pair<std::vector<array>, std::vector<array>>(
const std::vector<array>&)>

Enumeration Type Documentation

◆ CompileMode

enum class mlx::core::CompileMode
strong
Enumerator
disabled 
no_simplify 
no_fuse 
enabled 

◆ CopyType

enum class mlx::core::CopyType
strong
Enumerator
Scalar 
Vector 
General 
GeneralGeneral 

◆ ReductionOpType

Enumerator
ContiguousAllReduce 
ContiguousReduce 
ContiguousStridedReduce 
GeneralContiguousReduce 
GeneralStridedReduce 
GeneralReduce 

Function Documentation

◆ all_reduce_dispatch()

void mlx::core::all_reduce_dispatch ( const array & in,
array & out,
const std::string & op_name,
CommandEncoder & compute_encoder,
metal::Device & d,
const Stream & s )

◆ arange()

void mlx::core::arange ( const std::vector< array > & inputs,
array & out,
double start,
double step )

◆ async_eval()

void mlx::core::async_eval ( std::vector< array > outputs)

◆ binary_op_gpu() [1/2]

void mlx::core::binary_op_gpu ( const std::vector< array > & inputs,
array & out,
const std::string op,
const Stream & s )

◆ binary_op_gpu() [2/2]

void mlx::core::binary_op_gpu ( const std::vector< array > & inputs,
std::vector< array > & outputs,
const std::string op,
const Stream & s )

◆ binary_op_gpu_inplace() [1/2]

void mlx::core::binary_op_gpu_inplace ( const std::vector< array > & inputs,
array & out,
const std::string op,
const Stream & s )

◆ binary_op_gpu_inplace() [2/2]

void mlx::core::binary_op_gpu_inplace ( const std::vector< array > & inputs,
std::vector< array > & outputs,
const std::string op,
const Stream & s )

◆ broadcast_shapes()

std::vector< int > mlx::core::broadcast_shapes ( const std::vector< int > & s1,
const std::vector< int > & s2 )

◆ build_lib_name()

std::string mlx::core::build_lib_name ( const std::vector< array > & inputs,
const std::vector< array > & outputs,
const std::vector< array > & tape,
const std::unordered_set< uintptr_t > & constant_ids )

◆ check_contiguity()

template<typename stride_t >
auto mlx::core::check_contiguity ( const std::vector< int > & shape,
const std::vector< stride_t > & strides )
inline

◆ check_shape_dim()

template<typename T >
int mlx::core::check_shape_dim ( const T dim)

Returns the shape dimension if it's within allowed range.

◆ collapse_contiguous_dims() [1/3]

template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
auto mlx::core::collapse_contiguous_dims ( Arrays &&... xs)
inline

◆ collapse_contiguous_dims() [2/3]

std::tuple< std::vector< int >, std::vector< std::vector< size_t > > > mlx::core::collapse_contiguous_dims ( const std::vector< array > & xs)
inline

◆ collapse_contiguous_dims() [3/3]

template<typename stride_t >
std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > mlx::core::collapse_contiguous_dims ( const std::vector< int > & shape,
const std::vector< std::vector< stride_t > > strides )
inline

◆ compiled_allocate_outputs()

void mlx::core::compiled_allocate_outputs ( const std::vector< array > & inputs,
std::vector< array > & outputs,
const std::vector< array > & inputs_,
const std::unordered_set< uintptr_t > & constant_ids_,
bool contiguous,
bool move_buffers = false )

◆ compiled_check_contiguity()

bool mlx::core::compiled_check_contiguity ( const std::vector< array > & inputs,
const std::vector< int > & shape )

◆ concatenate_gpu()

void mlx::core::concatenate_gpu ( const std::vector< array > & inputs,
array & out,
int axis,
const Stream & s )

◆ copy()

void mlx::core::copy ( const array & src,
array & dst,
CopyType ctype )

◆ copy_gpu() [1/2]

void mlx::core::copy_gpu ( const array & src,
array & out,
CopyType ctype )

◆ copy_gpu() [2/2]

void mlx::core::copy_gpu ( const array & src,
array & out,
CopyType ctype,
const Stream & s )

◆ copy_gpu_inplace() [1/3]

template<typename stride_t >
void mlx::core::copy_gpu_inplace ( const array & in,
array & out,
const std::vector< int > & data_shape,
const std::vector< stride_t > & i_strides,
const std::vector< stride_t > & o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
const Stream & s )

◆ copy_gpu_inplace() [2/3]

void mlx::core::copy_gpu_inplace ( const array & in,
array & out,
const std::vector< int64_t > & istride,
int64_t ioffset,
CopyType ctype,
const Stream & s )

◆ copy_gpu_inplace() [3/3]

void mlx::core::copy_gpu_inplace ( const array & src,
array & out,
CopyType ctype,
const Stream & s )

◆ copy_inplace() [1/2]

template<typename stride_t >
void mlx::core::copy_inplace ( const array & src,
array & dst,
const std::vector< int > & data_shape,
const std::vector< stride_t > & i_strides,
const std::vector< stride_t > & o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype )

◆ copy_inplace() [2/2]

void mlx::core::copy_inplace ( const array & src,
array & dst,
CopyType ctype )

◆ default_device()

const Device & mlx::core::default_device ( )

◆ default_stream()

Stream mlx::core::default_stream ( Device d)

Get the default stream for the given device.

◆ disable_compile()

void mlx::core::disable_compile ( )

Globally disable compilation.

Setting the environment variable MLX_DISABLE_COMPILE can also be used to disable compilation.

◆ dtype_from_array_protocol()

Dtype mlx::core::dtype_from_array_protocol ( std::string_view t)

◆ dtype_to_array_protocol()

std::string mlx::core::dtype_to_array_protocol ( const Dtype & t)

◆ elem_to_loc() [1/2]

size_t mlx::core::elem_to_loc ( int elem,
const array & a )
inline

◆ elem_to_loc() [2/2]

template<typename stride_t >
stride_t mlx::core::elem_to_loc ( int elem,
const std::vector< int > & shape,
const std::vector< stride_t > & strides )
inline

◆ enable_compile()

void mlx::core::enable_compile ( )

Globally enable compilation.

This will override the environment variable MLX_DISABLE_COMPILE.

◆ eval() [1/2]

template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void mlx::core::eval ( Arrays &&... outputs)

◆ eval() [2/2]

void mlx::core::eval ( std::vector< array > outputs)

◆ export_to_dot() [1/2]

template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void mlx::core::export_to_dot ( std::ostream & os,
Arrays &&... outputs )

◆ export_to_dot() [2/2]

void mlx::core::export_to_dot ( std::ostream & os,
const std::vector< array > & outputs )

◆ get_arange_kernel()

MTL::ComputePipelineState * mlx::core::get_arange_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & out )

◆ get_binary_kernel()

MTL::ComputePipelineState * mlx::core::get_binary_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & in,
const array & out )

◆ get_binary_two_kernel()

MTL::ComputePipelineState * mlx::core::get_binary_two_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & in,
const array & out )

◆ get_copy_kernel()

MTL::ComputePipelineState * mlx::core::get_copy_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & in,
const array & out )

◆ get_fft_kernel()

MTL::ComputePipelineState * mlx::core::get_fft_kernel ( metal::Device & d,
const std::string & kernel_name,
const std::string & hash_name,
const int tg_mem_size,
const std::string & in_type,
const std::string & out_type,
int step,
bool real,
const metal::MTLFCList & func_consts )

◆ get_mb_sort_kernel()

MTL::ComputePipelineState * mlx::core::get_mb_sort_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & in,
const array & idx,
int bn,
int tn )

◆ get_reduce_init_kernel()

MTL::ComputePipelineState * mlx::core::get_reduce_init_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & out )

◆ get_reduce_kernel()

MTL::ComputePipelineState * mlx::core::get_reduce_kernel ( metal::Device & d,
const std::string & kernel_name,
const std::string & op_name,
const array & in,
const array & out )

◆ get_scan_kernel()

MTL::ComputePipelineState * mlx::core::get_scan_kernel ( metal::Device & d,
const std::string & kernel_name,
bool reverse,
bool inclusive,
const std::string & reduce_type,
const array & in,
const array & out )

◆ get_shape()

std::vector< int > mlx::core::get_shape ( const gguf_tensor & tensor)

◆ get_softmax_kernel()

MTL::ComputePipelineState * mlx::core::get_softmax_kernel ( metal::Device & d,
const std::string & kernel_name,
bool precise,
const array & out )

◆ get_sort_kernel()

MTL::ComputePipelineState * mlx::core::get_sort_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & in,
const array & out,
int bn,
int tn )

◆ get_steel_conv_general_kernel()

MTL::ComputePipelineState * mlx::core::get_steel_conv_general_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & out,
int bm,
int bn,
int bk,
int wm,
int wn )

◆ get_steel_conv_kernel()

MTL::ComputePipelineState * mlx::core::get_steel_conv_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter )

◆ get_steel_gemm_fused_kernel()

MTL::ComputePipelineState * mlx::core::get_steel_gemm_fused_kernel ( metal::Device & d,
const std::string & kernel_name,
const std::string & hash_name,
const metal::MTLFCList & func_consts,
const array & out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn )

◆ get_steel_gemm_masked_kernel()

MTL::ComputePipelineState * mlx::core::get_steel_gemm_masked_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & out,
const std::optional< array > & mask_out,
const std::optional< array > & mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned )

◆ get_steel_gemm_splitk_accum_kernel()

MTL::ComputePipelineState * mlx::core::get_steel_gemm_splitk_accum_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & in,
const array & out,
bool axbpy )

◆ get_steel_gemm_splitk_kernel()

MTL::ComputePipelineState * mlx::core::get_steel_gemm_splitk_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & in,
const array & out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned )

◆ get_ternary_kernel()

MTL::ComputePipelineState * mlx::core::get_ternary_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & out )

◆ get_type_string()

std::string mlx::core::get_type_string ( Dtype d)

◆ get_unary_kernel()

MTL::ComputePipelineState * mlx::core::get_unary_kernel ( metal::Device & d,
const std::string & kernel_name,
const array & out )

◆ gguf_load_quantized()

void mlx::core::gguf_load_quantized ( std::unordered_map< std::string, array > & a,
const gguf_tensor & tensor )

◆ grad() [1/2]

std::function< array(const array &)> mlx::core::grad ( const std::function< array(const array &)> & fun)
inline

Returns a function which computes the gradient of the unary input function.

◆ grad() [2/2]

std::function< std::vector< array >(const std::vector< array > &)> mlx::core::grad ( const std::function< array(const std::vector< array > &)> & fun,
int argnum = 0 )
inline

Returns a function which computes the gradient of the input function with respect to a single input array.

The function being differentiated takes a vector of arrays and returns an array. The optional argnum index specifies which argument to compute the gradient with respect to; it defaults to 0.

◆ is_big_endian()

bool mlx::core::is_big_endian ( )
inline

◆ is_same_shape()

bool mlx::core::is_same_shape ( const std::vector< array > & arrays)

◆ is_scalar()

bool mlx::core::is_scalar ( const array & x)
inline

◆ is_static_cast()

bool mlx::core::is_static_cast ( const Primitive & p)
inline

◆ issubdtype() [1/4]

bool mlx::core::issubdtype ( const Dtype & a,
const Dtype & b )

◆ issubdtype() [2/4]

bool mlx::core::issubdtype ( const Dtype & a,
const Dtype::Category & b )

◆ issubdtype() [3/4]

bool mlx::core::issubdtype ( const Dtype::Category & a,
const Dtype & b )

◆ issubdtype() [4/4]

bool mlx::core::issubdtype ( const Dtype::Category & a,
const Dtype::Category & b )

◆ jvp() [1/2]

std::pair< array, array > mlx::core::jvp ( const std::function< array(const array &)> & fun,
const array & primal,
const array & tangent )

Computes the output and Jacobian-vector product (JVP) of a unary function.

◆ jvp() [2/2]

std::pair< std::vector< array >, std::vector< array > > mlx::core::jvp ( const std::function< std::vector< array >(const std::vector< array > &)> & fun,
const std::vector< array > & primals,
const std::vector< array > & tangents )

Computes the output and Jacobian-vector product (JVP) of a function.

Computes the Jacobian-vector product of the Jacobian of the function evaluated at the primals with the vector of tangents. Returns a pair of vectors of output arrays and JVP arrays.

◆ kindof()

Dtype::Kind mlx::core::kindof ( const Dtype & t)

◆ load() [1/2]

array mlx::core::load ( std::shared_ptr< io::Reader > in_stream,
StreamOrDevice s = {} )

Load array from reader in .npy format.

◆ load() [2/2]

array mlx::core::load ( std::string file,
StreamOrDevice s = {} )

Load array from file in .npy format.

◆ load_gguf()

GGUFLoad mlx::core::load_gguf ( const std::string & file,
StreamOrDevice s = {} )

Load array map and metadata from .gguf file format.

◆ load_safetensors() [1/2]

SafetensorsLoad mlx::core::load_safetensors ( const std::string & file,
StreamOrDevice s = {} )

◆ load_safetensors() [2/2]

SafetensorsLoad mlx::core::load_safetensors ( std::shared_ptr< io::Reader > in_stream,
StreamOrDevice s = {} )

Load array map from .safetensors file format.

◆ new_stream()

Stream mlx::core::new_stream ( Device d)

Make a new stream on the given device.

◆ normalize_axis()

int mlx::core::normalize_axis ( int axis,
int ndim )

Returns the axis normalized to be in the range [0, ndim).

Based on numpy's normalize_axis_index. See https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html

◆ operator!=() [1/28]

bool mlx::core::operator!= ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator!=() [2/28]

bool mlx::core::operator!= ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator!=() [3/28]

bool mlx::core::operator!= ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator!=() [4/28]

bool mlx::core::operator!= ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator!=() [5/28]

bool mlx::core::operator!= ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator!=() [6/28]

bool mlx::core::operator!= ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator!=() [7/28]

bool mlx::core::operator!= ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator!=() [8/28]

bool mlx::core::operator!= ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator!=() [9/28]

bool mlx::core::operator!= ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator!=() [10/28]

bool mlx::core::operator!= ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator!=() [11/28]

bool mlx::core::operator!= ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator!=() [12/28]

bool mlx::core::operator!= ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator!=() [13/28]

bool mlx::core::operator!= ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator!=() [14/28]

bool mlx::core::operator!= ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator!=() [15/28]

bool mlx::core::operator!= ( const Device & lhs,
const Device & rhs )

◆ operator!=() [16/28]

bool mlx::core::operator!= ( const Stream & lhs,
const Stream & rhs )
inline

◆ operator!=() [17/28]

bool mlx::core::operator!= ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator!=() [18/28]

bool mlx::core::operator!= ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator!=() [19/28]

bool mlx::core::operator!= ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator!=() [20/28]

bool mlx::core::operator!= ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator!=() [21/28]

bool mlx::core::operator!= ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator!=() [22/28]

bool mlx::core::operator!= ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator!=() [23/28]

bool mlx::core::operator!= ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator!=() [24/28]

bool mlx::core::operator!= ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator!=() [25/28]

bool mlx::core::operator!= ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator!=() [26/28]

bool mlx::core::operator!= ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator!=() [27/28]

bool mlx::core::operator!= ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator!=() [28/28]

bool mlx::core::operator!= ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator%()

complex64_t mlx::core::operator% ( complex64_t a,
complex64_t b )
inline

◆ operator&() [1/6]

_MLX_BFloat16 mlx::core::operator& ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator&() [2/6]

_MLX_BFloat16 mlx::core::operator& ( _MLX_BFloat16 lhs,
uint16_t rhs )
inline

◆ operator&() [3/6]

_MLX_Float16 mlx::core::operator& ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator&() [4/6]

_MLX_Float16 mlx::core::operator& ( _MLX_Float16 lhs,
uint16_t rhs )
inline

◆ operator&() [5/6]

_MLX_BFloat16 mlx::core::operator& ( uint16_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator&() [6/6]

_MLX_Float16 mlx::core::operator& ( uint16_t lhs,
_MLX_Float16 rhs )
inline

◆ operator&=() [1/4]

_MLX_BFloat16 & mlx::core::operator&= ( _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
inline

◆ operator&=() [2/4]

_MLX_BFloat16 & mlx::core::operator&= ( _MLX_BFloat16 & lhs,
uint16_t rhs )
inline

◆ operator&=() [3/4]

_MLX_Float16 & mlx::core::operator&= ( _MLX_Float16 & lhs,
_MLX_Float16 rhs )
inline

◆ operator&=() [4/4]

_MLX_Float16 & mlx::core::operator&= ( _MLX_Float16 & lhs,
uint16_t rhs )
inline

◆ operator*() [1/32]

_MLX_BFloat16 mlx::core::operator* ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [2/32]

_MLX_BFloat16 mlx::core::operator* ( _MLX_BFloat16 lhs,
bool rhs )
inline

◆ operator*() [3/32]

double mlx::core::operator* ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator*() [4/32]

float mlx::core::operator* ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator*() [5/32]

_MLX_BFloat16 mlx::core::operator* ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator*() [6/32]

_MLX_BFloat16 mlx::core::operator* ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator*() [7/32]

_MLX_BFloat16 mlx::core::operator* ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator*() [8/32]

_MLX_BFloat16 mlx::core::operator* ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator*() [9/32]

_MLX_Float16 mlx::core::operator* ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator*() [10/32]

_MLX_Float16 mlx::core::operator* ( _MLX_Float16 lhs,
bool rhs )
inline

◆ operator*() [11/32]

double mlx::core::operator* ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator*() [12/32]

float mlx::core::operator* ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator*() [13/32]

_MLX_Float16 mlx::core::operator* ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator*() [14/32]

_MLX_Float16 mlx::core::operator* ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator*() [15/32]

_MLX_Float16 mlx::core::operator* ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator*() [16/32]

_MLX_Float16 mlx::core::operator* ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator*() [17/32]

float mlx::core::operator* ( bfloat16_t lhs,
float16_t rhs )
inline

◆ operator*() [18/32]

_MLX_BFloat16 mlx::core::operator* ( bool lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [19/32]

_MLX_Float16 mlx::core::operator* ( bool lhs,
_MLX_Float16 rhs )
inline

◆ operator*() [20/32]

double mlx::core::operator* ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [21/32]

double mlx::core::operator* ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator*() [22/32]

float mlx::core::operator* ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [23/32]

float mlx::core::operator* ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator*() [24/32]

float mlx::core::operator* ( float16_t lhs,
bfloat16_t rhs )
inline

◆ operator*() [25/32]

_MLX_BFloat16 mlx::core::operator* ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [26/32]

_MLX_Float16 mlx::core::operator* ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator*() [27/32]

_MLX_BFloat16 mlx::core::operator* ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [28/32]

_MLX_Float16 mlx::core::operator* ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator*() [29/32]

_MLX_BFloat16 mlx::core::operator* ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [30/32]

_MLX_Float16 mlx::core::operator* ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator*() [31/32]

_MLX_BFloat16 mlx::core::operator* ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*() [32/32]

_MLX_Float16 mlx::core::operator* ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator*=() [1/4]

_MLX_BFloat16 & mlx::core::operator*= ( _MLX_BFloat16 & lhs,
const float & rhs )
inline

◆ operator*=() [2/4]

_MLX_Float16 & mlx::core::operator*= ( _MLX_Float16 & lhs,
const float & rhs )
inline

◆ operator*=() [3/4]

float & mlx::core::operator*= ( float & lhs,
_MLX_BFloat16 rhs )
inline

◆ operator*=() [4/4]

float & mlx::core::operator*= ( float & lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [1/51]

_MLX_BFloat16 mlx::core::operator+ ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [2/51]

_MLX_BFloat16 mlx::core::operator+ ( _MLX_BFloat16 lhs,
bool rhs )
inline

◆ operator+() [3/51]

double mlx::core::operator+ ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator+() [4/51]

float mlx::core::operator+ ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator+() [5/51]

_MLX_BFloat16 mlx::core::operator+ ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator+() [6/51]

_MLX_BFloat16 mlx::core::operator+ ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator+() [7/51]

_MLX_BFloat16 mlx::core::operator+ ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator+() [8/51]

_MLX_BFloat16 mlx::core::operator+ ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator+() [9/51]

_MLX_Float16 mlx::core::operator+ ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [10/51]

_MLX_Float16 mlx::core::operator+ ( _MLX_Float16 lhs,
bool rhs )
inline

◆ operator+() [11/51]

double mlx::core::operator+ ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator+() [12/51]

float mlx::core::operator+ ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator+() [13/51]

_MLX_Float16 mlx::core::operator+ ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator+() [14/51]

_MLX_Float16 mlx::core::operator+ ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator+() [15/51]

_MLX_Float16 mlx::core::operator+ ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator+() [16/51]

_MLX_Float16 mlx::core::operator+ ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator+() [17/51]

float mlx::core::operator+ ( bfloat16_t lhs,
float16_t rhs )
inline

◆ operator+() [18/51]

complex64_t mlx::core::operator+ ( bfloat16_t x,
const complex64_t & y )
inline

◆ operator+() [19/51]

_MLX_BFloat16 mlx::core::operator+ ( bool lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [20/51]

_MLX_Float16 mlx::core::operator+ ( bool lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [21/51]

complex64_t mlx::core::operator+ ( bool x,
const complex64_t & y )
inline

◆ operator+() [22/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
bfloat16_t y )
inline

◆ operator+() [23/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
bool y )
inline

◆ operator+() [24/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
const complex64_t & y )
inline

◆ operator+() [25/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
const std::complex< float > & y )
inline

◆ operator+() [26/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
float y )
inline

◆ operator+() [27/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
float16_t y )
inline

◆ operator+() [28/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
int32_t y )
inline

◆ operator+() [29/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
int64_t y )
inline

◆ operator+() [30/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
uint32_t y )
inline

◆ operator+() [31/51]

complex64_t mlx::core::operator+ ( const complex64_t & x,
uint64_t y )
inline

◆ operator+() [32/51]

complex64_t mlx::core::operator+ ( const std::complex< float > & x,
const complex64_t & y )
inline

◆ operator+() [33/51]

double mlx::core::operator+ ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [34/51]

double mlx::core::operator+ ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [35/51]

float mlx::core::operator+ ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [36/51]

float mlx::core::operator+ ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [37/51]

complex64_t mlx::core::operator+ ( float x,
const complex64_t & y )
inline

◆ operator+() [38/51]

float mlx::core::operator+ ( float16_t lhs,
bfloat16_t rhs )
inline

◆ operator+() [39/51]

complex64_t mlx::core::operator+ ( float16_t x,
const complex64_t & y )
inline

◆ operator+() [40/51]

_MLX_BFloat16 mlx::core::operator+ ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [41/51]

_MLX_Float16 mlx::core::operator+ ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [42/51]

complex64_t mlx::core::operator+ ( int32_t x,
const complex64_t & y )
inline

◆ operator+() [43/51]

_MLX_BFloat16 mlx::core::operator+ ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [44/51]

_MLX_Float16 mlx::core::operator+ ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [45/51]

complex64_t mlx::core::operator+ ( int64_t x,
const complex64_t & y )
inline

◆ operator+() [46/51]

_MLX_BFloat16 mlx::core::operator+ ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [47/51]

_MLX_Float16 mlx::core::operator+ ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [48/51]

complex64_t mlx::core::operator+ ( uint32_t x,
const complex64_t & y )
inline

◆ operator+() [49/51]

_MLX_BFloat16 mlx::core::operator+ ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+() [50/51]

_MLX_Float16 mlx::core::operator+ ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator+() [51/51]

complex64_t mlx::core::operator+ ( uint64_t x,
const complex64_t & y )
inline

◆ operator+=() [1/4]

_MLX_BFloat16 & mlx::core::operator+= ( _MLX_BFloat16 & lhs,
const float & rhs )
inline

◆ operator+=() [2/4]

_MLX_Float16 & mlx::core::operator+= ( _MLX_Float16 & lhs,
const float & rhs )
inline

◆ operator+=() [3/4]

float & mlx::core::operator+= ( float & lhs,
_MLX_BFloat16 rhs )
inline

◆ operator+=() [4/4]

float & mlx::core::operator+= ( float & lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [1/35]

_MLX_BFloat16 mlx::core::operator- ( _MLX_BFloat16 lhs)
inline

◆ operator-() [2/35]

_MLX_BFloat16 mlx::core::operator- ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [3/35]

_MLX_BFloat16 mlx::core::operator- ( _MLX_BFloat16 lhs,
bool rhs )
inline

◆ operator-() [4/35]

double mlx::core::operator- ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator-() [5/35]

float mlx::core::operator- ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator-() [6/35]

_MLX_BFloat16 mlx::core::operator- ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator-() [7/35]

_MLX_BFloat16 mlx::core::operator- ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator-() [8/35]

_MLX_BFloat16 mlx::core::operator- ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator-() [9/35]

_MLX_BFloat16 mlx::core::operator- ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator-() [10/35]

_MLX_Float16 mlx::core::operator- ( _MLX_Float16 lhs)
inline

◆ operator-() [11/35]

_MLX_Float16 mlx::core::operator- ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [12/35]

_MLX_Float16 mlx::core::operator- ( _MLX_Float16 lhs,
bool rhs )
inline

◆ operator-() [13/35]

double mlx::core::operator- ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator-() [14/35]

float mlx::core::operator- ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator-() [15/35]

_MLX_Float16 mlx::core::operator- ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator-() [16/35]

_MLX_Float16 mlx::core::operator- ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator-() [17/35]

_MLX_Float16 mlx::core::operator- ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator-() [18/35]

_MLX_Float16 mlx::core::operator- ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator-() [19/35]

float mlx::core::operator- ( bfloat16_t lhs,
float16_t rhs )
inline

◆ operator-() [20/35]

_MLX_BFloat16 mlx::core::operator- ( bool lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [21/35]

_MLX_Float16 mlx::core::operator- ( bool lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [22/35]

complex64_t mlx::core::operator- ( const complex64_t & v)
inline

◆ operator-() [23/35]

double mlx::core::operator- ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [24/35]

double mlx::core::operator- ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [25/35]

float mlx::core::operator- ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [26/35]

float mlx::core::operator- ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [27/35]

float mlx::core::operator- ( float16_t lhs,
bfloat16_t rhs )
inline

◆ operator-() [28/35]

_MLX_BFloat16 mlx::core::operator- ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [29/35]

_MLX_Float16 mlx::core::operator- ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [30/35]

_MLX_BFloat16 mlx::core::operator- ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [31/35]

_MLX_Float16 mlx::core::operator- ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [32/35]

_MLX_BFloat16 mlx::core::operator- ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [33/35]

_MLX_Float16 mlx::core::operator- ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator-() [34/35]

_MLX_BFloat16 mlx::core::operator- ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-() [35/35]

_MLX_Float16 mlx::core::operator- ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator-=() [1/4]

_MLX_BFloat16 & mlx::core::operator-= ( _MLX_BFloat16 & lhs,
const float & rhs )
inline

◆ operator-=() [2/4]

_MLX_Float16 & mlx::core::operator-= ( _MLX_Float16 & lhs,
const float & rhs )
inline

◆ operator-=() [3/4]

float & mlx::core::operator-= ( float & lhs,
_MLX_BFloat16 rhs )
inline

◆ operator-=() [4/4]

float & mlx::core::operator-= ( float & lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [1/32]

_MLX_BFloat16 mlx::core::operator/ ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [2/32]

_MLX_BFloat16 mlx::core::operator/ ( _MLX_BFloat16 lhs,
bool rhs )
inline

◆ operator/() [3/32]

double mlx::core::operator/ ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator/() [4/32]

float mlx::core::operator/ ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator/() [5/32]

_MLX_BFloat16 mlx::core::operator/ ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator/() [6/32]

_MLX_BFloat16 mlx::core::operator/ ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator/() [7/32]

_MLX_BFloat16 mlx::core::operator/ ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator/() [8/32]

_MLX_BFloat16 mlx::core::operator/ ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator/() [9/32]

_MLX_Float16 mlx::core::operator/ ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [10/32]

_MLX_Float16 mlx::core::operator/ ( _MLX_Float16 lhs,
bool rhs )
inline

◆ operator/() [11/32]

double mlx::core::operator/ ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator/() [12/32]

float mlx::core::operator/ ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator/() [13/32]

_MLX_Float16 mlx::core::operator/ ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator/() [14/32]

_MLX_Float16 mlx::core::operator/ ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator/() [15/32]

_MLX_Float16 mlx::core::operator/ ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator/() [16/32]

_MLX_Float16 mlx::core::operator/ ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator/() [17/32]

float mlx::core::operator/ ( bfloat16_t lhs,
float16_t rhs )
inline

◆ operator/() [18/32]

_MLX_BFloat16 mlx::core::operator/ ( bool lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [19/32]

_MLX_Float16 mlx::core::operator/ ( bool lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [20/32]

double mlx::core::operator/ ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [21/32]

double mlx::core::operator/ ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [22/32]

float mlx::core::operator/ ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [23/32]

float mlx::core::operator/ ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [24/32]

float mlx::core::operator/ ( float16_t lhs,
bfloat16_t rhs )
inline

◆ operator/() [25/32]

_MLX_BFloat16 mlx::core::operator/ ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [26/32]

_MLX_Float16 mlx::core::operator/ ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [27/32]

_MLX_BFloat16 mlx::core::operator/ ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [28/32]

_MLX_Float16 mlx::core::operator/ ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [29/32]

_MLX_BFloat16 mlx::core::operator/ ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [30/32]

_MLX_Float16 mlx::core::operator/ ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator/() [31/32]

_MLX_BFloat16 mlx::core::operator/ ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/() [32/32]

_MLX_Float16 mlx::core::operator/ ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator/=() [1/4]

_MLX_BFloat16 & mlx::core::operator/= ( _MLX_BFloat16 & lhs,
const float & rhs )
inline

◆ operator/=() [2/4]

_MLX_Float16 & mlx::core::operator/= ( _MLX_Float16 & lhs,
const float & rhs )
inline

◆ operator/=() [3/4]

float & mlx::core::operator/= ( float & lhs,
_MLX_BFloat16 rhs )
inline

◆ operator/=() [4/4]

float & mlx::core::operator/= ( float & lhs,
_MLX_Float16 rhs )
inline

◆ operator<() [1/27]

bool mlx::core::operator< ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<() [2/27]

bool mlx::core::operator< ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator<() [3/27]

bool mlx::core::operator< ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator<() [4/27]

bool mlx::core::operator< ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator<() [5/27]

bool mlx::core::operator< ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator<() [6/27]

bool mlx::core::operator< ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator<() [7/27]

bool mlx::core::operator< ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator<() [8/27]

bool mlx::core::operator< ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator<() [9/27]

bool mlx::core::operator< ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator<() [10/27]

bool mlx::core::operator< ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator<() [11/27]

bool mlx::core::operator< ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator<() [12/27]

bool mlx::core::operator< ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator<() [13/27]

bool mlx::core::operator< ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator<() [14/27]

bool mlx::core::operator< ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator<() [15/27]

bool mlx::core::operator< ( const complex64_t & a,
const complex64_t & b )
inline

◆ operator<() [16/27]

bool mlx::core::operator< ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<() [17/27]

bool mlx::core::operator< ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator<() [18/27]

bool mlx::core::operator< ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<() [19/27]

bool mlx::core::operator< ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator<() [20/27]

bool mlx::core::operator< ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<() [21/27]

bool mlx::core::operator< ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator<() [22/27]

bool mlx::core::operator< ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<() [23/27]

bool mlx::core::operator< ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator<() [24/27]

bool mlx::core::operator< ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<() [25/27]

bool mlx::core::operator< ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator<() [26/27]

bool mlx::core::operator< ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<() [27/27]

bool mlx::core::operator< ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator<<() [1/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
array a )

◆ operator<<() [2/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const bfloat16_t & v )
inline

◆ operator<<() [3/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const complex64_t & v )
inline

◆ operator<<() [4/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const Device & d )

◆ operator<<() [5/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const Dtype & d )

◆ operator<<() [6/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const Dtype::Kind & k )

◆ operator<<() [7/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const float16_t & v )
inline

◆ operator<<() [8/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const std::vector< int > & v )

◆ operator<<() [9/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const std::vector< int64_t > & v )

◆ operator<<() [10/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const std::vector< size_t > & v )

◆ operator<<() [11/11]

std::ostream & mlx::core::operator<< ( std::ostream & os,
const Stream & s )

◆ operator<=() [1/27]

bool mlx::core::operator<= ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<=() [2/27]

bool mlx::core::operator<= ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator<=() [3/27]

bool mlx::core::operator<= ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator<=() [4/27]

bool mlx::core::operator<= ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator<=() [5/27]

bool mlx::core::operator<= ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator<=() [6/27]

bool mlx::core::operator<= ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator<=() [7/27]

bool mlx::core::operator<= ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator<=() [8/27]

bool mlx::core::operator<= ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator<=() [9/27]

bool mlx::core::operator<= ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator<=() [10/27]

bool mlx::core::operator<= ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator<=() [11/27]

bool mlx::core::operator<= ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator<=() [12/27]

bool mlx::core::operator<= ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator<=() [13/27]

bool mlx::core::operator<= ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator<=() [14/27]

bool mlx::core::operator<= ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator<=() [15/27]

bool mlx::core::operator<= ( const complex64_t & a,
const complex64_t & b )
inline

◆ operator<=() [16/27]

bool mlx::core::operator<= ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<=() [17/27]

bool mlx::core::operator<= ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator<=() [18/27]

bool mlx::core::operator<= ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<=() [19/27]

bool mlx::core::operator<= ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator<=() [20/27]

bool mlx::core::operator<= ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<=() [21/27]

bool mlx::core::operator<= ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator<=() [22/27]

bool mlx::core::operator<= ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<=() [23/27]

bool mlx::core::operator<= ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator<=() [24/27]

bool mlx::core::operator<= ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<=() [25/27]

bool mlx::core::operator<= ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator<=() [26/27]

bool mlx::core::operator<= ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator<=() [27/27]

bool mlx::core::operator<= ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator==() [1/28]

bool mlx::core::operator== ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator==() [2/28]

bool mlx::core::operator== ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator==() [3/28]

bool mlx::core::operator== ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator==() [4/28]

bool mlx::core::operator== ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator==() [5/28]

bool mlx::core::operator== ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator==() [6/28]

bool mlx::core::operator== ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator==() [7/28]

bool mlx::core::operator== ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator==() [8/28]

bool mlx::core::operator== ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator==() [9/28]

bool mlx::core::operator== ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator==() [10/28]

bool mlx::core::operator== ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator==() [11/28]

bool mlx::core::operator== ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator==() [12/28]

bool mlx::core::operator== ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator==() [13/28]

bool mlx::core::operator== ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator==() [14/28]

bool mlx::core::operator== ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator==() [15/28]

bool mlx::core::operator== ( const Device & lhs,
const Device & rhs )

◆ operator==() [16/28]

bool mlx::core::operator== ( const Stream & lhs,
const Stream & rhs )
inline

◆ operator==() [17/28]

bool mlx::core::operator== ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator==() [18/28]

bool mlx::core::operator== ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator==() [19/28]

bool mlx::core::operator== ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator==() [20/28]

bool mlx::core::operator== ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator==() [21/28]

bool mlx::core::operator== ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator==() [22/28]

bool mlx::core::operator== ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator==() [23/28]

bool mlx::core::operator== ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator==() [24/28]

bool mlx::core::operator== ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator==() [25/28]

bool mlx::core::operator== ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator==() [26/28]

bool mlx::core::operator== ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator==() [27/28]

bool mlx::core::operator== ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator==() [28/28]

bool mlx::core::operator== ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>() [1/27]

bool mlx::core::operator> ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>() [2/27]

bool mlx::core::operator> ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator>() [3/27]

bool mlx::core::operator> ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator>() [4/27]

bool mlx::core::operator> ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator>() [5/27]

bool mlx::core::operator> ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator>() [6/27]

bool mlx::core::operator> ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator>() [7/27]

bool mlx::core::operator> ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator>() [8/27]

bool mlx::core::operator> ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator>() [9/27]

bool mlx::core::operator> ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator>() [10/27]

bool mlx::core::operator> ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator>() [11/27]

bool mlx::core::operator> ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator>() [12/27]

bool mlx::core::operator> ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator>() [13/27]

bool mlx::core::operator> ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator>() [14/27]

bool mlx::core::operator> ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator>() [15/27]

bool mlx::core::operator> ( const complex64_t & a,
const complex64_t & b )
inline

◆ operator>() [16/27]

bool mlx::core::operator> ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>() [17/27]

bool mlx::core::operator> ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator>() [18/27]

bool mlx::core::operator> ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>() [19/27]

bool mlx::core::operator> ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator>() [20/27]

bool mlx::core::operator> ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>() [21/27]

bool mlx::core::operator> ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>() [22/27]

bool mlx::core::operator> ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>() [23/27]

bool mlx::core::operator> ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>() [24/27]

bool mlx::core::operator> ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>() [25/27]

bool mlx::core::operator> ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>() [26/27]

bool mlx::core::operator> ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>() [27/27]

bool mlx::core::operator> ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>=() [1/27]

bool mlx::core::operator>= ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>=() [2/27]

bool mlx::core::operator>= ( _MLX_BFloat16 lhs,
double rhs )
inline

◆ operator>=() [3/27]

bool mlx::core::operator>= ( _MLX_BFloat16 lhs,
float rhs )
inline

◆ operator>=() [4/27]

bool mlx::core::operator>= ( _MLX_BFloat16 lhs,
int32_t rhs )
inline

◆ operator>=() [5/27]

bool mlx::core::operator>= ( _MLX_BFloat16 lhs,
int64_t rhs )
inline

◆ operator>=() [6/27]

bool mlx::core::operator>= ( _MLX_BFloat16 lhs,
uint32_t rhs )
inline

◆ operator>=() [7/27]

bool mlx::core::operator>= ( _MLX_BFloat16 lhs,
uint64_t rhs )
inline

◆ operator>=() [8/27]

bool mlx::core::operator>= ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator>=() [9/27]

bool mlx::core::operator>= ( _MLX_Float16 lhs,
double rhs )
inline

◆ operator>=() [10/27]

bool mlx::core::operator>= ( _MLX_Float16 lhs,
float rhs )
inline

◆ operator>=() [11/27]

bool mlx::core::operator>= ( _MLX_Float16 lhs,
int32_t rhs )
inline

◆ operator>=() [12/27]

bool mlx::core::operator>= ( _MLX_Float16 lhs,
int64_t rhs )
inline

◆ operator>=() [13/27]

bool mlx::core::operator>= ( _MLX_Float16 lhs,
uint32_t rhs )
inline

◆ operator>=() [14/27]

bool mlx::core::operator>= ( _MLX_Float16 lhs,
uint64_t rhs )
inline

◆ operator>=() [15/27]

bool mlx::core::operator>= ( const complex64_t & a,
const complex64_t & b )
inline

◆ operator>=() [16/27]

bool mlx::core::operator>= ( double lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>=() [17/27]

bool mlx::core::operator>= ( double lhs,
_MLX_Float16 rhs )
inline

◆ operator>=() [18/27]

bool mlx::core::operator>= ( float lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>=() [19/27]

bool mlx::core::operator>= ( float lhs,
_MLX_Float16 rhs )
inline

◆ operator>=() [20/27]

bool mlx::core::operator>= ( int32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>=() [21/27]

bool mlx::core::operator>= ( int32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>=() [22/27]

bool mlx::core::operator>= ( int64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>=() [23/27]

bool mlx::core::operator>= ( int64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>=() [24/27]

bool mlx::core::operator>= ( uint32_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>=() [25/27]

bool mlx::core::operator>= ( uint32_t lhs,
_MLX_Float16 rhs )
inline

◆ operator>=() [26/27]

bool mlx::core::operator>= ( uint64_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator>=() [27/27]

bool mlx::core::operator>= ( uint64_t lhs,
_MLX_Float16 rhs )
inline

◆ operator^() [1/6]

_MLX_BFloat16 mlx::core::operator^ ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator^() [2/6]

_MLX_BFloat16 mlx::core::operator^ ( _MLX_BFloat16 lhs,
uint16_t rhs )
inline

◆ operator^() [3/6]

_MLX_Float16 mlx::core::operator^ ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator^() [4/6]

_MLX_Float16 mlx::core::operator^ ( _MLX_Float16 lhs,
uint16_t rhs )
inline

◆ operator^() [5/6]

_MLX_BFloat16 mlx::core::operator^ ( uint16_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator^() [6/6]

_MLX_Float16 mlx::core::operator^ ( uint16_t lhs,
_MLX_Float16 rhs )
inline

◆ operator^=() [1/4]

_MLX_BFloat16 & mlx::core::operator^= ( _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
inline

◆ operator^=() [2/4]

_MLX_BFloat16 & mlx::core::operator^= ( _MLX_BFloat16 & lhs,
uint16_t rhs )
inline

◆ operator^=() [3/4]

_MLX_Float16 & mlx::core::operator^= ( _MLX_Float16 & lhs,
_MLX_Float16 rhs )
inline

◆ operator^=() [4/4]

_MLX_Float16 & mlx::core::operator^= ( _MLX_Float16 & lhs,
uint16_t rhs )
inline

◆ operator|() [1/6]

_MLX_BFloat16 mlx::core::operator| ( _MLX_BFloat16 lhs,
_MLX_BFloat16 rhs )
inline

◆ operator|() [2/6]

_MLX_BFloat16 mlx::core::operator| ( _MLX_BFloat16 lhs,
uint16_t rhs )
inline

◆ operator|() [3/6]

_MLX_Float16 mlx::core::operator| ( _MLX_Float16 lhs,
_MLX_Float16 rhs )
inline

◆ operator|() [4/6]

_MLX_Float16 mlx::core::operator| ( _MLX_Float16 lhs,
uint16_t rhs )
inline

◆ operator|() [5/6]

_MLX_BFloat16 mlx::core::operator| ( uint16_t lhs,
_MLX_BFloat16 rhs )
inline

◆ operator|() [6/6]

_MLX_Float16 mlx::core::operator| ( uint16_t lhs,
_MLX_Float16 rhs )
inline

◆ operator|=() [1/4]

_MLX_BFloat16 & mlx::core::operator|= ( _MLX_BFloat16 & lhs,
_MLX_BFloat16 rhs )
inline

◆ operator|=() [2/4]

_MLX_BFloat16 & mlx::core::operator|= ( _MLX_BFloat16 & lhs,
uint16_t rhs )
inline

◆ operator|=() [3/4]

_MLX_Float16 & mlx::core::operator|= ( _MLX_Float16 & lhs,
_MLX_Float16 rhs )
inline

◆ operator|=() [4/4]

_MLX_Float16 & mlx::core::operator|= ( _MLX_Float16 & lhs,
uint16_t rhs )
inline

◆ pad_gpu()

void mlx::core::pad_gpu ( const array & in,
const array & val,
array & out,
std::vector< int > axes,
std::vector< int > low_pad_size,
const Stream & s )

◆ prepare_slice()

std::tuple< bool, int64_t, std::vector< int64_t > > mlx::core::prepare_slice ( const array & in,
std::vector< int > & start_indices,
std::vector< int > & strides )

◆ print_complex_constant()

template<typename T >
void mlx::core::print_complex_constant ( std::ostream & os,
const array & x )

◆ print_constant()

void mlx::core::print_constant ( std::ostream & os,
const array & x )

◆ print_float_constant()

template<typename T >
void mlx::core::print_float_constant ( std::ostream & os,
const array & x )

◆ print_graph() [1/2]

template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void mlx::core::print_graph ( std::ostream & os,
Arrays &&... outputs )

◆ print_graph() [2/2]

void mlx::core::print_graph ( std::ostream & os,
const std::vector< array > & outputs )

◆ print_int_constant()

template<typename T >
void mlx::core::print_int_constant ( std::ostream & os,
const array & x )

◆ promote_types()

Dtype mlx::core::promote_types ( const Dtype & t1,
const Dtype & t2 )

◆ result_type() [1/3]

Dtype mlx::core::result_type ( const array & a,
const array & b )
inline

Returns the type obtained by promoting the arrays' types with one another.

◆ result_type() [2/3]

Dtype mlx::core::result_type ( const array & a,
const array & b,
const array & c )
inline

◆ result_type() [3/3]

Dtype mlx::core::result_type ( const std::vector< array > & arrays)

◆ row_reduce_general_dispatch()

void mlx::core::row_reduce_general_dispatch ( const array & in,
array & out,
const std::string & op_name,
const ReductionPlan & plan,
const std::vector< int > & axes,
CommandEncoder & compute_encoder,
metal::Device & d,
const Stream & s )

◆ save() [1/2]

void mlx::core::save ( std::shared_ptr< io::Writer > out_stream,
array a )

Save an array to an output stream in .npy format.

◆ save() [2/2]

void mlx::core::save ( std::string file,
array a )

Save an array to a file in .npy format.

◆ save_gguf()

void mlx::core::save_gguf ( std::string file,
std::unordered_map< std::string, array > array_map,
std::unordered_map< std::string, GGUFMetaData > meta_data = {} )

◆ save_safetensors() [1/2]

void mlx::core::save_safetensors ( std::shared_ptr< io::Writer > in_stream,
std::unordered_map< std::string, array > ,
std::unordered_map< std::string, std::string > metadata = {} )

◆ save_safetensors() [2/2]

void mlx::core::save_safetensors ( std::string file,
std::unordered_map< std::string, array > ,
std::unordered_map< std::string, std::string > metadata = {} )

◆ set_compile_mode()

void mlx::core::set_compile_mode ( CompileMode mode)

Set the compile mode to the given value.

◆ set_default_device()

void mlx::core::set_default_device ( const Device & d)

◆ set_default_stream()

void mlx::core::set_default_stream ( Stream s)

Make the stream the default for its device.

◆ shared_buffer_slice()

void mlx::core::shared_buffer_slice ( const array & in,
const std::vector< size_t > & out_strides,
size_t data_offset,
array & out )

◆ size_of()

uint8_t mlx::core::size_of ( const Dtype & t)
inline

◆ slice_gpu()

void mlx::core::slice_gpu ( const array & in,
array & out,
std::vector< int > start_indices,
std::vector< int > strides,
const Stream & s )

◆ steel_matmul()

void mlx::core::steel_matmul ( const Stream & s,
metal::Device & d,
const array & a,
const array & b,
array & out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector< array > & copies,
std::vector< int > batch_shape = {},
std::vector< size_t > A_batch_stride = {},
std::vector< size_t > B_batch_stride = {} )

◆ steel_matmul_conv_groups()

void mlx::core::steel_matmul_conv_groups ( const Stream & s,
metal::Device & d,
const array & a,
const array & b,
array & out,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector< array > & copies )

◆ strided_reduce_general_dispatch()

void mlx::core::strided_reduce_general_dispatch ( const array & in,
array & out,
const std::string & op_name,
const ReductionPlan & plan,
const std::vector< int > & axes,
CommandEncoder & compute_encoder,
metal::Device & d,
const Stream & s )

◆ synchronize() [1/2]

void mlx::core::synchronize ( )

◆ synchronize() [2/2]

void mlx::core::synchronize ( Stream )

◆ ternary_op_gpu()

void mlx::core::ternary_op_gpu ( const std::vector< array > & inputs,
array & out,
const std::string op,
const Stream & s )

◆ ternary_op_gpu_inplace()

void mlx::core::ternary_op_gpu_inplace ( const std::vector< array > & inputs,
array & out,
const std::string op,
const Stream & s )

◆ to_bnns_dtype()

BNNSDataType mlx::core::to_bnns_dtype ( Dtype mlx_dtype)

◆ to_stream()

Stream mlx::core::to_stream ( StreamOrDevice s)

◆ unary_op_gpu()

void mlx::core::unary_op_gpu ( const std::vector< array > & inputs,
array & out,
const std::string op,
const Stream & s )

◆ unary_op_gpu_inplace()

void mlx::core::unary_op_gpu_inplace ( const std::vector< array > & inputs,
array & out,
const std::string op,
const Stream & s )

◆ value_and_grad() [1/4]

SimpleValueAndGradFn mlx::core::value_and_grad ( const std::function< array(const std::vector< array > &)> & fun,
const std::vector< int > & argnums )
inline

◆ value_and_grad() [2/4]

SimpleValueAndGradFn mlx::core::value_and_grad ( const std::function< array(const std::vector< array > &)> & fun,
int argnum = 0 )
inline

◆ value_and_grad() [3/4]

ValueAndGradFn mlx::core::value_and_grad ( const std::function< std::vector< array >(const std::vector< array > &)> & fun,
const std::vector< int > & argnums )

Returns a function which computes the value and gradient of the input function with respect to a vector of input arrays.

◆ value_and_grad() [4/4]

ValueAndGradFn mlx::core::value_and_grad ( const std::function< std::vector< array >(const std::vector< array > &)> & fun,
int argnum = 0 )
inline

Returns a function which computes the value and gradient of the input function with respect to a single input array.

◆ vjp() [1/2]

std::pair< array, array > mlx::core::vjp ( const std::function< array(const array &)> & fun,
const array & primal,
const array & cotangent )

Computes the output and vector-Jacobian product (VJP) of a unary function.

◆ vjp() [2/2]

std::pair< std::vector< array >, std::vector< array > > mlx::core::vjp ( const std::function< std::vector< array >(const std::vector< array > &)> & fun,
const std::vector< array > & primals,
const std::vector< array > & cotangents )

Computes the output and vector-Jacobian product (VJP) of a function.

Computes the vector-Jacobian product of the vector of cotangents with the Jacobian of the function evaluated at the primals. Returns a pair of vectors of output arrays and VJP arrays.

◆ vmap() [1/2]

std::function< array(const array &, const array &)> mlx::core::vmap ( const std::function< array(const array &, const array &)> & fun,
int in_axis_a = 0,
int in_axis_b = 0,
int out_axis = 0 )

Automatically vectorize a binary function over the requested axes.

◆ vmap() [2/2]

std::function< std::vector< array >(const std::vector< array > &)> mlx::core::vmap ( const std::function< std::vector< array >(const std::vector< array > &)> & fun,
const std::vector< int > & in_axes = {},
const std::vector< int > & out_axes = {} )

Automatically vectorize a function over the requested axes.

The input function to vmap takes as an argument a vector of arrays and returns a vector of arrays. Optionally specify the axes to vectorize over with in_axes and out_axes, otherwise a default of 0 is used. Returns a vectorized function with the same signature as the input function.

Variable Documentation

◆ bfloat16

constexpr Dtype mlx::core::bfloat16 {Dtype::Val::bfloat16, sizeof(uint16_t)}
inlineconstexpr

◆ bool_

constexpr Dtype mlx::core::bool_ {Dtype::Val::bool_, sizeof(bool)}
inlineconstexpr

◆ can_convert_to_complex128

template<typename T >
constexpr bool mlx::core::can_convert_to_complex128
inlineconstexpr
Initial value:
=
!std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>

◆ can_convert_to_complex64

template<typename T >
constexpr bool mlx::core::can_convert_to_complex64
inlineconstexpr
Initial value:
=
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>

◆ checkpoint

std::function< std::vector< array >(const std::vector< array > &)> mlx::core::checkpoint ( std::function< std::vector< array >(const std::vector< array > &)> fun )

Checkpoint the gradient of a function.

That is, discard all intermediate state and recompute it when the gradient needs to be computed.

◆ compile

std::function< std::vector< array >(const std::vector< array > &)> mlx::core::compile ( const std::function< std::vector< array >(const std::vector< array > &)> & fun,
bool shapeless = false )

Compile takes a function and returns a compiled function.

◆ complex64

constexpr Dtype mlx::core::complex64 {Dtype::Val::complex64, sizeof(complex64_t)}
inlineconstexpr

◆ complexfloating

constexpr Dtype::Category mlx::core::complexfloating
inlineconstexpr
Initial value:
=
Dtype::Category::complexfloating

◆ custom_vjp

std::function< std::vector< array >(const std::vector< array > &)> mlx::core::custom_vjp ( std::function< std::vector< array >(const std::vector< array > &)> fun,
std::function< std::vector< array >(const std::vector< array > &, const std::vector< array > &, const std::vector< array > &)> fun_vjp )

Returns the results of calling fun with the given arguments; if their VJP is computed, it is computed by fun_vjp instead.

◆ float16

constexpr Dtype mlx::core::float16 {Dtype::Val::float16, sizeof(uint16_t)}
inlineconstexpr

◆ float32

constexpr Dtype mlx::core::float32 {Dtype::Val::float32, sizeof(float)}
inlineconstexpr

◆ floating

constexpr Dtype::Category mlx::core::floating = Dtype::Category::floating
inlineconstexpr

◆ generic

constexpr Dtype::Category mlx::core::generic = Dtype::Category::generic
inlineconstexpr

◆ global_formatter

PrintFormatter mlx::core::global_formatter
extern

◆ grad

std::function< array(const array &) mlx::core::grad) (const std::function< array(const array &)> &fun) ( const std::function< array(const std::vector< array > &)> & fun,
const std::vector< int > & argnums )
inline

Returns a function which computes the gradient of the input function with respect to a vector of input arrays.

Returns a function which computes the gradient of the unary input function.

Returns a function which computes the gradient of the input function with respect to a single input array.

The function being differentiated takes a vector of arrays and returns an array. The vector argnums specifies which arguments to compute the gradient with respect to. At least one argument must be specified.

The function being differentiated takes a vector of arrays and returns an array. The optional argnum index specifies which argument to compute the gradient with respect to; it defaults to 0.

◆ inexact

constexpr Dtype::Category mlx::core::inexact = Dtype::Category::inexact
inlineconstexpr

◆ int16

constexpr Dtype mlx::core::int16 {Dtype::Val::int16, sizeof(int16_t)}
inlineconstexpr

◆ int32

constexpr Dtype mlx::core::int32 {Dtype::Val::int32, sizeof(int32_t)}
inlineconstexpr

◆ int64

constexpr Dtype mlx::core::int64 {Dtype::Val::int64, sizeof(int64_t)}
inlineconstexpr

◆ int8

constexpr Dtype mlx::core::int8 {Dtype::Val::int8, sizeof(int8_t)}
inlineconstexpr

◆ integer

constexpr Dtype::Category mlx::core::integer = Dtype::Category::integer
inlineconstexpr

◆ is_array_v

template<typename T >
constexpr bool mlx::core::is_array_v
inlineconstexpr
Initial value:
=
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>

◆ is_arrays_v

template<typename... T>
constexpr bool mlx::core::is_arrays_v = (is_array_v<T> && ...)
inlineconstexpr

◆ number

constexpr Dtype::Category mlx::core::number = Dtype::Category::number
inlineconstexpr

◆ signedinteger

constexpr Dtype::Category mlx::core::signedinteger = Dtype::Category::signedinteger
inlineconstexpr

◆ uint16

constexpr Dtype mlx::core::uint16 {Dtype::Val::uint16, sizeof(uint16_t)}
inlineconstexpr

◆ uint32

constexpr Dtype mlx::core::uint32 {Dtype::Val::uint32, sizeof(uint32_t)}
inlineconstexpr

◆ uint64

constexpr Dtype mlx::core::uint64 {Dtype::Val::uint64, sizeof(uint64_t)}
inlineconstexpr

◆ uint8

constexpr Dtype mlx::core::uint8 {Dtype::Val::uint8, sizeof(uint8_t)}
inlineconstexpr

◆ unsignedinteger

constexpr Dtype::Category mlx::core::unsignedinteger
inlineconstexpr
Initial value:
=
Dtype::Category::unsignedinteger

◆ value_and_grad

std::function< std::pair< array, array >(const array &)> mlx::core::value_and_grad ( const std::function< array(const array &)> & fun )
inline

Returns a function which computes the value and gradient of the unary input function.

◆ vmap

std::function< array(const array &)> mlx::core::vmap ( const std::function< array(const array &)> & fun,
int in_axis = 0,
int out_axis = 0 )

Automatically vectorize a unary function over the requested axes.

Automatically vectorize a function over the requested axes.

Automatically vectorize a binary function over the requested axes.

The input function to vmap takes as an argument a vector of arrays and returns a vector of arrays. Optionally specify the axes to vectorize over with in_axes and out_axes, otherwise a default of 0 is used. Returns a vectorized function with the same signature as the input function.