MLX
Loading...
Searching...
No Matches
Functions
Core array operations

Functions

array mlx::core::arange (double start, double stop, double step, Dtype dtype, StreamOrDevice s={})
 A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).
 
array mlx::core::arange (double start, double stop, double step, StreamOrDevice s={})
 
array mlx::core::arange (double start, double stop, Dtype dtype, StreamOrDevice s={})
 
array mlx::core::arange (double start, double stop, StreamOrDevice s={})
 
array mlx::core::arange (double stop, Dtype dtype, StreamOrDevice s={})
 
array mlx::core::arange (double stop, StreamOrDevice s={})
 
array mlx::core::arange (int start, int stop, int step, StreamOrDevice s={})
 
array mlx::core::arange (int start, int stop, StreamOrDevice s={})
 
array mlx::core::arange (int stop, StreamOrDevice s={})
 
array mlx::core::linspace (double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
 A 1D array of num evenly spaced numbers in the range [start, stop]
 
array mlx::core::astype (array a, Dtype dtype, StreamOrDevice s={})
 Convert an array to the given data type.
 
array mlx::core::as_strided (array a, std::vector< int > shape, std::vector< size_t > strides, size_t offset, StreamOrDevice s={})
 Create a view of an array with the given shape and strides.
 
array mlx::core::copy (array a, StreamOrDevice s={})
 Copy another array.
 
array mlx::core::full (std::vector< int > shape, array vals, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with the given value(s).
 
array mlx::core::full (std::vector< int > shape, array vals, StreamOrDevice s={})
 
template<typename T >
array mlx::core::full (std::vector< int > shape, T val, Dtype dtype, StreamOrDevice s={})
 
template<typename T >
array mlx::core::full (std::vector< int > shape, T val, StreamOrDevice s={})
 
array mlx::core::zeros (const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with zeros.
 
array mlx::core::zeros (const std::vector< int > &shape, StreamOrDevice s={})
 
array mlx::core::zeros_like (const array &a, StreamOrDevice s={})
 
array mlx::core::ones (const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with ones.
 
array mlx::core::ones (const std::vector< int > &shape, StreamOrDevice s={})
 
array mlx::core::ones_like (const array &a, StreamOrDevice s={})
 
array mlx::core::eye (int n, int m, int k, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
 
array mlx::core::eye (int n, Dtype dtype, StreamOrDevice s={})
 
array mlx::core::eye (int n, int m, StreamOrDevice s={})
 
array mlx::core::eye (int n, int m, int k, StreamOrDevice s={})
 
array mlx::core::eye (int n, StreamOrDevice s={})
 
array mlx::core::identity (int n, Dtype dtype, StreamOrDevice s={})
 Create a square matrix of shape (n,n) with ones on the main diagonal and zeros elsewhere.
 
array mlx::core::identity (int n, StreamOrDevice s={})
 
array mlx::core::tri (int n, int m, int k, Dtype type, StreamOrDevice s={})
 
array mlx::core::tri (int n, Dtype type, StreamOrDevice s={})
 
array mlx::core::tril (array x, int k=0, StreamOrDevice s={})
 
array mlx::core::triu (array x, int k=0, StreamOrDevice s={})
 
array mlx::core::reshape (const array &a, std::vector< int > shape, StreamOrDevice s={})
 Reshape an array to the given shape.
 
array mlx::core::flatten (const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
 Flatten the dimensions in the range [start_axis, end_axis].
 
array mlx::core::flatten (const array &a, StreamOrDevice s={})
 Flatten the array to 1D.
 
array mlx::core::squeeze (const array &a, const std::vector< int > &axes, StreamOrDevice s={})
 Remove singleton dimensions at the given axes.
 
array mlx::core::squeeze (const array &a, int axis, StreamOrDevice s={})
 Remove singleton dimensions at the given axis.
 
array mlx::core::squeeze (const array &a, StreamOrDevice s={})
 Remove all singleton dimensions.
 
array mlx::core::expand_dims (const array &a, const std::vector< int > &axes, StreamOrDevice s={})
 Add a singleton dimension at the given axes.
 
array mlx::core::expand_dims (const array &a, int axis, StreamOrDevice s={})
 Add a singleton dimension at the given axis.
 
array mlx::core::slice (const array &a, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
 Slice an array.
 
array mlx::core::slice (const array &a, const std::vector< int > &start, const std::vector< int > &stop, StreamOrDevice s={})
 Slice an array with a stride of 1 in each dimension.
 
array mlx::core::slice_update (const array &src, const array &update, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
 Update a slice from the source array.
 
array mlx::core::slice_update (const array &src, const array &update, std::vector< int > start, std::vector< int > stop, StreamOrDevice s={})
 Update a slice from the source array with stride 1 in each dimension.
 
std::vector< array > mlx::core::split (const array &a, int num_splits, int axis, StreamOrDevice s={})
 Split an array into sub-arrays along a given axis.
 
std::vector< array > mlx::core::split (const array &a, int num_splits, StreamOrDevice s={})
 
std::vector< array > mlx::core::split (const array &a, const std::vector< int > &indices, int axis, StreamOrDevice s={})
 
std::vector< array > mlx::core::split (const array &a, const std::vector< int > &indices, StreamOrDevice s={})
 
std::vector< array > mlx::core::meshgrid (const std::vector< array > &arrays, bool sparse=false, std::string indexing="xy", StreamOrDevice s={})
 A vector of coordinate arrays from coordinate vectors.
 
array mlx::core::clip (const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
 Clip (limit) the values in an array.
 
array mlx::core::concatenate (const std::vector< array > &arrays, int axis, StreamOrDevice s={})
 Concatenate arrays along a given axis.
 
array mlx::core::concatenate (const std::vector< array > &arrays, StreamOrDevice s={})
 
array mlx::core::stack (const std::vector< array > &arrays, int axis, StreamOrDevice s={})
 Stack arrays along a new axis.
 
array mlx::core::stack (const std::vector< array > &arrays, StreamOrDevice s={})
 
array mlx::core::repeat (const array &arr, int repeats, int axis, StreamOrDevice s={})
 Repeat an array along an axis.
 
array mlx::core::repeat (const array &arr, int repeats, StreamOrDevice s={})
 
array mlx::core::tile (const array &arr, std::vector< int > reps, StreamOrDevice s={})
 
array mlx::core::transpose (const array &a, std::vector< int > axes, StreamOrDevice s={})
 Permutes the dimensions according to the given axes.
 
array mlx::core::transpose (const array &a, std::initializer_list< int > axes, StreamOrDevice s={})
 
array mlx::core::swapaxes (const array &a, int axis1, int axis2, StreamOrDevice s={})
 Swap two axes of an array.
 
array mlx::core::moveaxis (const array &a, int source, int destination, StreamOrDevice s={})
 Move an axis of an array.
 
array mlx::core::pad (const array &a, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size, const array &pad_value=array(0), StreamOrDevice s={})
 Pad an array with a constant value.
 
array mlx::core::pad (const array &a, const std::vector< std::pair< int, int > > &pad_width, const array &pad_value=array(0), StreamOrDevice s={})
 Pad an array with a constant value along all axes.
 
array mlx::core::pad (const array &a, const std::pair< int, int > &pad_width, const array &pad_value=array(0), StreamOrDevice s={})
 
array mlx::core::pad (const array &a, int pad_width, const array &pad_value=array(0), StreamOrDevice s={})
 
array mlx::core::transpose (const array &a, StreamOrDevice s={})
 Permutes the dimensions in reverse order.
 
array mlx::core::broadcast_to (const array &a, const std::vector< int > &shape, StreamOrDevice s={})
 Broadcast an array to a given shape.
 
std::vector< array > mlx::core::broadcast_arrays (const std::vector< array > &inputs, StreamOrDevice s={})
 Broadcast a vector of arrays against one another.
 
array mlx::core::equal (const array &a, const array &b, StreamOrDevice s={})
 Returns the bool array with (a == b) element-wise.
 
array mlx::core::operator== (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator== (T a, const array &b)
 
template<typename T >
array mlx::core::operator== (const array &a, T b)
 
array mlx::core::not_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns the bool array with (a != b) element-wise.
 
array mlx::core::operator!= (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator!= (T a, const array &b)
 
template<typename T >
array mlx::core::operator!= (const array &a, T b)
 
array mlx::core::greater (const array &a, const array &b, StreamOrDevice s={})
 Returns a bool array with (a > b) element-wise.
 
array mlx::core::operator> (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator> (T a, const array &b)
 
template<typename T >
array mlx::core::operator> (const array &a, T b)
 
array mlx::core::greater_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns a bool array with (a >= b) element-wise.
 
array mlx::core::operator>= (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator>= (T a, const array &b)
 
template<typename T >
array mlx::core::operator>= (const array &a, T b)
 
array mlx::core::less (const array &a, const array &b, StreamOrDevice s={})
 Returns a bool array with (a < b) element-wise.
 
array mlx::core::operator< (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator< (T a, const array &b)
 
template<typename T >
array mlx::core::operator< (const array &a, T b)
 
array mlx::core::less_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns a bool array with (a <= b) element-wise.
 
array mlx::core::operator<= (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator<= (T a, const array &b)
 
template<typename T >
array mlx::core::operator<= (const array &a, T b)
 
array mlx::core::array_equal (const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
 True if two arrays have the same shape and elements.
 
array mlx::core::array_equal (const array &a, const array &b, StreamOrDevice s={})
 
array mlx::core::isnan (const array &a, StreamOrDevice s={})
 
array mlx::core::isinf (const array &a, StreamOrDevice s={})
 
array mlx::core::isposinf (const array &a, StreamOrDevice s={})
 
array mlx::core::isneginf (const array &a, StreamOrDevice s={})
 
array mlx::core::where (const array &condition, const array &x, const array &y, StreamOrDevice s={})
 Select from x or y depending on condition.
 
array mlx::core::all (const array &a, bool keepdims, StreamOrDevice s={})
 True if all elements in the array are true (or non-zero).
 
array mlx::core::all (const array &a, StreamOrDevice s={})
 
array mlx::core::allclose (const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
 True if the two arrays are equal within the specified tolerance.
 
array mlx::core::isclose (const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
 Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
 
array mlx::core::all (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axes.
 
array mlx::core::all (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axis.
 
array mlx::core::any (const array &a, bool keepdims, StreamOrDevice s={})
 True if any elements in the array are true (or non-zero).
 
array mlx::core::any (const array &a, StreamOrDevice s={})
 
array mlx::core::any (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axes.
 
array mlx::core::any (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Reduces the input along the given axis.
 
array mlx::core::sum (const array &a, bool keepdims, StreamOrDevice s={})
 Sums the elements of an array.
 
array mlx::core::sum (const array &a, StreamOrDevice s={})
 
array mlx::core::sum (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Sums the elements of an array along the given axes.
 
array mlx::core::sum (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Sums the elements of an array along the given axis.
 
array mlx::core::mean (const array &a, bool keepdims, StreamOrDevice s={})
 Computes the mean of the elements of an array.
 
array mlx::core::mean (const array &a, StreamOrDevice s={})
 
array mlx::core::mean (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Computes the mean of the elements of an array along the given axes.
 
array mlx::core::mean (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Computes the mean of the elements of an array along the given axis.
 
array mlx::core::var (const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array.
 
array mlx::core::var (const array &a, StreamOrDevice s={})
 
array mlx::core::var (const array &a, const std::vector< int > &axes, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array along the given axes.
 
array mlx::core::var (const array &a, int axis, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array along the given axis.
 
array mlx::core::std (const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array.
 
array mlx::core::std (const array &a, StreamOrDevice s={})
 
array mlx::core::std (const array &a, const std::vector< int > &axes, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array along the given axes.
 
array mlx::core::std (const array &a, int axis, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array along the given axis.
 
array mlx::core::prod (const array &a, bool keepdims, StreamOrDevice s={})
 The product of all elements of the array.
 
array mlx::core::prod (const array &a, StreamOrDevice s={})
 
array mlx::core::prod (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The product of the elements of an array along the given axes.
 
array mlx::core::prod (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The product of the elements of an array along the given axis.
 
array mlx::core::max (const array &a, bool keepdims, StreamOrDevice s={})
 The maximum of all elements of the array.
 
array mlx::core::max (const array &a, StreamOrDevice s={})
 
array mlx::core::max (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The maximum of the elements of an array along the given axes.
 
array mlx::core::max (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The maximum of the elements of an array along the given axis.
 
array mlx::core::min (const array &a, bool keepdims, StreamOrDevice s={})
 The minimum of all elements of the array.
 
array mlx::core::min (const array &a, StreamOrDevice s={})
 
array mlx::core::min (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The minimum of the elements of an array along the given axes.
 
array mlx::core::min (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The minimum of the elements of an array along the given axis.
 
array mlx::core::argmin (const array &a, bool keepdims, StreamOrDevice s={})
 Returns the index of the minimum value in the array.
 
array mlx::core::argmin (const array &a, StreamOrDevice s={})
 
array mlx::core::argmin (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Returns the indices of the minimum values along a given axis.
 
array mlx::core::argmax (const array &a, bool keepdims, StreamOrDevice s={})
 Returns the index of the maximum value in the array.
 
array mlx::core::argmax (const array &a, StreamOrDevice s={})
 
array mlx::core::argmax (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Returns the indices of the maximum values along a given axis.
 
array mlx::core::sort (const array &a, StreamOrDevice s={})
 Returns a sorted copy of the flattened array.
 
array mlx::core::sort (const array &a, int axis, StreamOrDevice s={})
 Returns a sorted copy of the array along a given axis.
 
array mlx::core::argsort (const array &a, StreamOrDevice s={})
 Returns indices that sort the flattened array.
 
array mlx::core::argsort (const array &a, int axis, StreamOrDevice s={})
 Returns indices that sort the array along a given axis.
 
array mlx::core::partition (const array &a, int kth, StreamOrDevice s={})
 Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
 
array mlx::core::partition (const array &a, int kth, int axis, StreamOrDevice s={})
 Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.
 
array mlx::core::argpartition (const array &a, int kth, StreamOrDevice s={})
 Returns indices that partition the flattened array such that the smaller kth elements are first.
 
array mlx::core::argpartition (const array &a, int kth, int axis, StreamOrDevice s={})
 Returns indices that partition the array along a given axis such that the smaller kth elements are first.
 
array mlx::core::topk (const array &a, int k, StreamOrDevice s={})
 Returns the top k elements of the flattened array.
 
array mlx::core::topk (const array &a, int k, int axis, StreamOrDevice s={})
 Returns the top k elements of the array along a given axis.
 
array mlx::core::logsumexp (const array &a, bool keepdims, StreamOrDevice s={})
 The logsumexp of all elements of the array.
 
array mlx::core::logsumexp (const array &a, StreamOrDevice s={})
 
array mlx::core::logsumexp (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The logsumexp of the elements of an array along the given axes.
 
array mlx::core::logsumexp (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The logsumexp of the elements of an array along the given axis.
 
array mlx::core::abs (const array &a, StreamOrDevice s={})
 Absolute value of elements in an array.
 
array mlx::core::negative (const array &a, StreamOrDevice s={})
 Negate an array.
 
array mlx::core::operator- (const array &a)
 
array mlx::core::sign (const array &a, StreamOrDevice s={})
 The sign of the elements in an array.
 
array mlx::core::logical_not (const array &a, StreamOrDevice s={})
 Logical not of an array.
 
array mlx::core::logical_and (const array &a, const array &b, StreamOrDevice s={})
 Logical and of two arrays.
 
array mlx::core::operator&& (const array &a, const array &b)
 
array mlx::core::logical_or (const array &a, const array &b, StreamOrDevice s={})
 Logical or of two arrays.
 
array mlx::core::operator|| (const array &a, const array &b)
 
array mlx::core::reciprocal (const array &a, StreamOrDevice s={})
 The reciprocal (1/x) of the elements in an array.
 
array mlx::core::add (const array &a, const array &b, StreamOrDevice s={})
 Add two arrays.
 
array mlx::core::operator+ (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator+ (T a, const array &b)
 
template<typename T >
array mlx::core::operator+ (const array &a, T b)
 
array mlx::core::subtract (const array &a, const array &b, StreamOrDevice s={})
 Subtract two arrays.
 
array mlx::core::operator- (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator- (T a, const array &b)
 
template<typename T >
array mlx::core::operator- (const array &a, T b)
 
array mlx::core::multiply (const array &a, const array &b, StreamOrDevice s={})
 Multiply two arrays.
 
array mlx::core::operator* (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator* (T a, const array &b)
 
template<typename T >
array mlx::core::operator* (const array &a, T b)
 
array mlx::core::divide (const array &a, const array &b, StreamOrDevice s={})
 Divide two arrays.
 
array mlx::core::operator/ (const array &a, const array &b)
 
array mlx::core::operator/ (double a, const array &b)
 
array mlx::core::operator/ (const array &a, double b)
 
std::vector< array > mlx::core::divmod (const array &a, const array &b, StreamOrDevice s={})
 Compute the element-wise quotient and remainder.
 
array mlx::core::floor_divide (const array &a, const array &b, StreamOrDevice s={})
 Compute integer division.
 
array mlx::core::remainder (const array &a, const array &b, StreamOrDevice s={})
 Compute the element-wise remainder of division.
 
array mlx::core::operator% (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator% (T a, const array &b)
 
template<typename T >
array mlx::core::operator% (const array &a, T b)
 
array mlx::core::maximum (const array &a, const array &b, StreamOrDevice s={})
 Element-wise maximum between two arrays.
 
array mlx::core::minimum (const array &a, const array &b, StreamOrDevice s={})
 Element-wise minimum between two arrays.
 
array mlx::core::floor (const array &a, StreamOrDevice s={})
 Floor the elements of an array.
 
array mlx::core::ceil (const array &a, StreamOrDevice s={})
 Ceil the elements of an array.
 
array mlx::core::square (const array &a, StreamOrDevice s={})
 Square the elements of an array.
 
array mlx::core::exp (const array &a, StreamOrDevice s={})
 Exponential of the elements of an array.
 
array mlx::core::sin (const array &a, StreamOrDevice s={})
 Sine of the elements of an array.
 
array mlx::core::cos (const array &a, StreamOrDevice s={})
 Cosine of the elements of an array.
 
array mlx::core::tan (const array &a, StreamOrDevice s={})
 Tangent of the elements of an array.
 
array mlx::core::arcsin (const array &a, StreamOrDevice s={})
 Arc Sine of the elements of an array.
 
array mlx::core::arccos (const array &a, StreamOrDevice s={})
 Arc Cosine of the elements of an array.
 
array mlx::core::arctan (const array &a, StreamOrDevice s={})
 Arc Tangent of the elements of an array.
 
array mlx::core::arctan2 (const array &a, const array &b, StreamOrDevice s={})
 Inverse tangent of the ratio of two arrays.
 
array mlx::core::sinh (const array &a, StreamOrDevice s={})
 Hyperbolic Sine of the elements of an array.
 
array mlx::core::cosh (const array &a, StreamOrDevice s={})
 Hyperbolic Cosine of the elements of an array.
 
array mlx::core::tanh (const array &a, StreamOrDevice s={})
 Hyperbolic Tangent of the elements of an array.
 
array mlx::core::arcsinh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Sine of the elements of an array.
 
array mlx::core::arccosh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Cosine of the elements of an array.
 
array mlx::core::arctanh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Tangent of the elements of an array.
 
array mlx::core::degrees (const array &a, StreamOrDevice s={})
 Convert the elements of an array from Radians to Degrees.
 
array mlx::core::radians (const array &a, StreamOrDevice s={})
 Convert the elements of an array from Degrees to Radians.
 
array mlx::core::log (const array &a, StreamOrDevice s={})
 Natural logarithm of the elements of an array.
 
array mlx::core::log2 (const array &a, StreamOrDevice s={})
 Log base 2 of the elements of an array.
 
array mlx::core::log10 (const array &a, StreamOrDevice s={})
 Log base 10 of the elements of an array.
 
array mlx::core::log1p (const array &a, StreamOrDevice s={})
 Natural logarithm of one plus elements in the array: log(1 + a).
 
array mlx::core::logaddexp (const array &a, const array &b, StreamOrDevice s={})
 Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).
 
array mlx::core::sigmoid (const array &a, StreamOrDevice s={})
 Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
 
array mlx::core::erf (const array &a, StreamOrDevice s={})
 Computes the error function of the elements of an array.
 
array mlx::core::erfinv (const array &a, StreamOrDevice s={})
 Computes the inverse error function of the elements of an array.
 
array mlx::core::expm1 (const array &a, StreamOrDevice s={})
 Computes the expm1 function of the elements of an array.
 
array mlx::core::stop_gradient (const array &a, StreamOrDevice s={})
 Stop the flow of gradients.
 
array mlx::core::round (const array &a, int decimals, StreamOrDevice s={})
 Round a floating point number.
 
array mlx::core::round (const array &a, StreamOrDevice s={})
 
array mlx::core::matmul (const array &a, const array &b, StreamOrDevice s={})
 Matrix-matrix multiplication.
 
array mlx::core::gather (const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const std::vector< int > &slice_sizes, StreamOrDevice s={})
 Gather array entries given indices and slices.
 
array mlx::core::gather (const array &a, const array &indices, int axis, const std::vector< int > &slice_sizes, StreamOrDevice s={})
 
array mlx::core::take (const array &a, const array &indices, int axis, StreamOrDevice s={})
 Take array slices at the given indices of the specified axis.
 
array mlx::core::take (const array &a, const array &indices, StreamOrDevice s={})
 Take array entries at the given indices treating the array as flattened.
 
array mlx::core::take_along_axis (const array &a, const array &indices, int axis, StreamOrDevice s={})
 Take array entries given indices along the axis.
 
array mlx::core::scatter (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter updates to given linear indices.
 
array mlx::core::scatter (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_add (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and add updates to given indices.
 
array mlx::core::scatter_add (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_prod (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and multiply updates to given indices.
 
array mlx::core::scatter_prod (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_max (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and max updates to given linear indices.
 
array mlx::core::scatter_max (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_min (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and min updates to given linear indices.
 
array mlx::core::scatter_min (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::sqrt (const array &a, StreamOrDevice s={})
 Square root of the elements of an array.
 
array mlx::core::rsqrt (const array &a, StreamOrDevice s={})
 Reciprocal square root of the elements of an array: 1 / sqrt(a).
 
array mlx::core::softmax (const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
 Softmax of an array.
 
array mlx::core::softmax (const array &a, bool precise=false, StreamOrDevice s={})
 Softmax of an array.
 
array mlx::core::softmax (const array &a, int axis, bool precise=false, StreamOrDevice s={})
 Softmax of an array.
 
array mlx::core::power (const array &a, const array &b, StreamOrDevice s={})
 Raise elements of a to the power of b element-wise.
 
array mlx::core::cumsum (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative sum of an array.
 
array mlx::core::cumprod (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative product of an array.
 
array mlx::core::cummax (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative max of an array.
 
array mlx::core::cummin (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative min of an array.
 
array mlx::core::conv_general (array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
 General convolution with a filter.
 
array mlx::core::conv_general (const array &input, const array &weight, std::vector< int > stride={}, std::vector< int > padding={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
 General convolution with a filter.
 
array mlx::core::conv1d (const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
 1D convolution with a filter
 
array mlx::core::conv2d (const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
 2D convolution with a filter
 
array mlx::core::quantized_matmul (const array &x, const array &w, const array &scales, const array &biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
 Quantized matmul multiplies x with a quantized matrix w.
 
std::tuple< array, array, array > mlx::core::quantize (const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
 Quantize a matrix along its last axis.
 
array mlx::core::dequantize (const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
 Dequantize a matrix produced by quantize().
 
array mlx::core::tensordot (const array &a, const array &b, const int axis=2, StreamOrDevice s={})
 Returns a contraction of a and b over multiple dimensions.
 
array mlx::core::tensordot (const array &a, const array &b, const std::vector< int > &axes_a, const std::vector< int > &axes_b, StreamOrDevice s={})
 
array mlx::core::outer (const array &a, const array &b, StreamOrDevice s={})
 Compute the outer product of two vectors.
 
array mlx::core::inner (const array &a, const array &b, StreamOrDevice s={})
 Compute the inner product of two vectors.
 
array mlx::core::addmm (array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
 Compute D = beta * C + alpha * (A @ B)
 
array mlx::core::block_masked_mm (array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
 Compute matrix product with block masking.
 
array mlx::core::block_sparse_mm (array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
 Compute matrix product with matrix-level gather.
 
array mlx::core::diagonal (const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
 Extract a diagonal or construct a diagonal array.
 
array mlx::core::diag (const array &a, int k=0, StreamOrDevice s={})
 Extract diagonal from a 2d array or create a diagonal matrix.
 
std::vector< array > mlx::core::depends (const std::vector< array > &inputs, const std::vector< array > &dependencies)
 Implements the identity function but allows injecting dependencies to other arrays.
 
array mlx::core::atleast_1d (const array &a, StreamOrDevice s={})
 Convert an array to an array with at least ndim dimensions.
 
std::vector< array > mlx::core::atleast_1d (const std::vector< array > &a, StreamOrDevice s={})
 
array mlx::core::atleast_2d (const array &a, StreamOrDevice s={})
 
std::vector< array > mlx::core::atleast_2d (const std::vector< array > &a, StreamOrDevice s={})
 
array mlx::core::atleast_3d (const array &a, StreamOrDevice s={})
 
std::vector< array > mlx::core::atleast_3d (const std::vector< array > &a, StreamOrDevice s={})
 
array mlx::core::number_of_elements (const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
 Extract the number of elements along some axes as a scalar array.
 
array mlx::core::conjugate (const array &a, StreamOrDevice s={})
 
array mlx::core::bitwise_and (const array &a, const array &b, StreamOrDevice s={})
 Bitwise and.
 
array mlx::core::operator& (const array &a, const array &b)
 
array mlx::core::bitwise_or (const array &a, const array &b, StreamOrDevice s={})
 Bitwise inclusive or.
 
array mlx::core::operator| (const array &a, const array &b)
 
array mlx::core::bitwise_xor (const array &a, const array &b, StreamOrDevice s={})
 Bitwise exclusive or.
 
array mlx::core::operator^ (const array &a, const array &b)
 
array mlx::core::left_shift (const array &a, const array &b, StreamOrDevice s={})
 Shift bits to the left.
 
array mlx::core::operator<< (const array &a, const array &b)
 
array mlx::core::right_shift (const array &a, const array &b, StreamOrDevice s={})
 Shift bits to the right.
 
array mlx::core::operator>> (const array &a, const array &b)
 

Detailed Description

Function Documentation

◆ abs()

array mlx::core::abs ( const array & a,
StreamOrDevice s = {} )

Absolute value of elements in an array.

◆ add()

array mlx::core::add ( const array & a,
const array & b,
StreamOrDevice s = {} )

Add two arrays.

◆ addmm()

array mlx::core::addmm ( array c,
array a,
array b,
const float & alpha = 1.f,
const float & beta = 1.f,
StreamOrDevice s = {} )

Compute D = beta * C + alpha * (A @ B)

◆ all() [1/4]

array mlx::core::all ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

True if all elements in the array are true (or non-zero).

◆ all() [2/4]

array mlx::core::all ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axes.

An output value is true if all the corresponding inputs are true.

◆ all() [3/4]

array mlx::core::all ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axis.

An output value is true if all the corresponding inputs are true.

◆ all() [4/4]

array mlx::core::all ( const array & a,
StreamOrDevice s = {} )
inline

◆ allclose()

array mlx::core::allclose ( const array & a,
const array & b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {} )

True if the two arrays are equal within the specified tolerance.

◆ any() [1/4]

array mlx::core::any ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

True if any elements in the array are true (or non-zero).

◆ any() [2/4]

array mlx::core::any ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axes.

An output value is true if any of the corresponding inputs are true.

◆ any() [3/4]

array mlx::core::any ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axis.

An output value is true if any of the corresponding inputs are true.

◆ any() [4/4]

array mlx::core::any ( const array & a,
StreamOrDevice s = {} )
inline

◆ arange() [1/9]

array mlx::core::arange ( double start,
double stop,
double step,
Dtype dtype,
StreamOrDevice s = {} )

A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).

◆ arange() [2/9]

array mlx::core::arange ( double start,
double stop,
double step,
StreamOrDevice s = {} )

◆ arange() [3/9]

array mlx::core::arange ( double start,
double stop,
Dtype dtype,
StreamOrDevice s = {} )

◆ arange() [4/9]

array mlx::core::arange ( double start,
double stop,
StreamOrDevice s = {} )

◆ arange() [5/9]

array mlx::core::arange ( double stop,
Dtype dtype,
StreamOrDevice s = {} )

◆ arange() [6/9]

array mlx::core::arange ( double stop,
StreamOrDevice s = {} )

◆ arange() [7/9]

array mlx::core::arange ( int start,
int stop,
int step,
StreamOrDevice s = {} )

◆ arange() [8/9]

array mlx::core::arange ( int start,
int stop,
StreamOrDevice s = {} )

◆ arange() [9/9]

array mlx::core::arange ( int stop,
StreamOrDevice s = {} )

◆ arccos()

array mlx::core::arccos ( const array & a,
StreamOrDevice s = {} )

Arc Cosine of the elements of an array.

◆ arccosh()

array mlx::core::arccosh ( const array & a,
StreamOrDevice s = {} )

Inverse Hyperbolic Cosine of the elements of an array.

◆ arcsin()

array mlx::core::arcsin ( const array & a,
StreamOrDevice s = {} )

Arc Sine of the elements of an array.

◆ arcsinh()

array mlx::core::arcsinh ( const array & a,
StreamOrDevice s = {} )

Inverse Hyperbolic Sine of the elements of an array.

◆ arctan()

array mlx::core::arctan ( const array & a,
StreamOrDevice s = {} )

Arc Tangent of the elements of an array.

◆ arctan2()

array mlx::core::arctan2 ( const array & a,
const array & b,
StreamOrDevice s = {} )

Inverse tangent of the ratio of two arrays.

◆ arctanh()

array mlx::core::arctanh ( const array & a,
StreamOrDevice s = {} )

Inverse Hyperbolic Tangent of the elements of an array.

◆ argmax() [1/3]

array mlx::core::argmax ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Returns the index of the maximum value in the array.

◆ argmax() [2/3]

array mlx::core::argmax ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Returns the indices of the maximum values along a given axis.

◆ argmax() [3/3]

array mlx::core::argmax ( const array & a,
StreamOrDevice s = {} )
inline

◆ argmin() [1/3]

array mlx::core::argmin ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Returns the index of the minimum value in the array.

◆ argmin() [2/3]

array mlx::core::argmin ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Returns the indices of the minimum values along a given axis.

◆ argmin() [3/3]

array mlx::core::argmin ( const array & a,
StreamOrDevice s = {} )
inline

◆ argpartition() [1/2]

array mlx::core::argpartition ( const array & a,
int kth,
int axis,
StreamOrDevice s = {} )

Returns indices that partition the array along a given axis such that the smaller kth elements are first.

◆ argpartition() [2/2]

array mlx::core::argpartition ( const array & a,
int kth,
StreamOrDevice s = {} )

Returns indices that partition the flattened array such that the smaller kth elements are first.

◆ argsort() [1/2]

array mlx::core::argsort ( const array & a,
int axis,
StreamOrDevice s = {} )

Returns indices that sort the array along a given axis.

◆ argsort() [2/2]

array mlx::core::argsort ( const array & a,
StreamOrDevice s = {} )

Returns indices that sort the flattened array.

◆ array_equal() [1/2]

array mlx::core::array_equal ( const array & a,
const array & b,
bool equal_nan,
StreamOrDevice s = {} )

True if two arrays have the same shape and elements.

◆ array_equal() [2/2]

array mlx::core::array_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )
inline

◆ as_strided()

array mlx::core::as_strided ( array a,
std::vector< int > shape,
std::vector< size_t > strides,
size_t offset,
StreamOrDevice s = {} )

Create a view of an array with the given shape and strides.

◆ astype()

array mlx::core::astype ( array a,
Dtype dtype,
StreamOrDevice s = {} )

Convert an array to the given data type.

◆ atleast_1d() [1/2]

array mlx::core::atleast_1d ( const array & a,
StreamOrDevice s = {} )

Convert an array to an array with at least one dimension.

◆ atleast_1d() [2/2]

std::vector< array > mlx::core::atleast_1d ( const std::vector< array > & a,
StreamOrDevice s = {} )

◆ atleast_2d() [1/2]

array mlx::core::atleast_2d ( const array & a,
StreamOrDevice s = {} )

◆ atleast_2d() [2/2]

std::vector< array > mlx::core::atleast_2d ( const std::vector< array > & a,
StreamOrDevice s = {} )

◆ atleast_3d() [1/2]

array mlx::core::atleast_3d ( const array & a,
StreamOrDevice s = {} )

◆ atleast_3d() [2/2]

std::vector< array > mlx::core::atleast_3d ( const std::vector< array > & a,
StreamOrDevice s = {} )

◆ bitwise_and()

array mlx::core::bitwise_and ( const array & a,
const array & b,
StreamOrDevice s = {} )

Bitwise and.

◆ bitwise_or()

array mlx::core::bitwise_or ( const array & a,
const array & b,
StreamOrDevice s = {} )

Bitwise inclusive or.

◆ bitwise_xor()

array mlx::core::bitwise_xor ( const array & a,
const array & b,
StreamOrDevice s = {} )

Bitwise exclusive or.

◆ block_masked_mm()

array mlx::core::block_masked_mm ( array a,
array b,
int block_size,
std::optional< array > mask_out = std::nullopt,
std::optional< array > mask_lhs = std::nullopt,
std::optional< array > mask_rhs = std::nullopt,
StreamOrDevice s = {} )

Compute matrix product with block masking.

◆ block_sparse_mm()

array mlx::core::block_sparse_mm ( array a,
array b,
std::optional< array > lhs_indices = std::nullopt,
std::optional< array > rhs_indices = std::nullopt,
StreamOrDevice s = {} )

Compute matrix product with matrix-level gather.

◆ broadcast_arrays()

std::vector< array > mlx::core::broadcast_arrays ( const std::vector< array > & inputs,
StreamOrDevice s = {} )

Broadcast a vector of arrays against one another.

◆ broadcast_to()

array mlx::core::broadcast_to ( const array & a,
const std::vector< int > & shape,
StreamOrDevice s = {} )

Broadcast an array to a given shape.

◆ ceil()

array mlx::core::ceil ( const array & a,
StreamOrDevice s = {} )

Ceiling of the elements of an array.

◆ clip()

array mlx::core::clip ( const array & a,
const std::optional< array > & a_min = std::nullopt,
const std::optional< array > & a_max = std::nullopt,
StreamOrDevice s = {} )

Clip (limit) the values in an array.

◆ concatenate() [1/2]

array mlx::core::concatenate ( const std::vector< array > & arrays,
int axis,
StreamOrDevice s = {} )

Concatenate arrays along a given axis.

◆ concatenate() [2/2]

array mlx::core::concatenate ( const std::vector< array > & arrays,
StreamOrDevice s = {} )

◆ conjugate()

array mlx::core::conjugate ( const array & a,
StreamOrDevice s = {} )

◆ conv1d()

array mlx::core::conv1d ( const array & input,
const array & weight,
int stride = 1,
int padding = 0,
int dilation = 1,
int groups = 1,
StreamOrDevice s = {} )

1D convolution with a filter

◆ conv2d()

array mlx::core::conv2d ( const array & input,
const array & weight,
const std::pair< int, int > & stride = {1, 1},
const std::pair< int, int > & padding = {0, 0},
const std::pair< int, int > & dilation = {1, 1},
int groups = 1,
StreamOrDevice s = {} )

2D convolution with a filter

◆ conv_general() [1/2]

array mlx::core::conv_general ( array input,
array weight,
std::vector< int > stride = {},
std::vector< int > padding_lo = {},
std::vector< int > padding_hi = {},
std::vector< int > kernel_dilation = {},
std::vector< int > input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {} )

General convolution with a filter.

◆ conv_general() [2/2]

array mlx::core::conv_general ( const array & input,
const array & weight,
std::vector< int > stride = {},
std::vector< int > padding = {},
std::vector< int > kernel_dilation = {},
std::vector< int > input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {} )
inline

General convolution with a filter.

◆ copy()

array mlx::core::copy ( array a,
StreamOrDevice s = {} )

Copy another array.

◆ cos()

array mlx::core::cos ( const array & a,
StreamOrDevice s = {} )

Cosine of the elements of an array.

◆ cosh()

array mlx::core::cosh ( const array & a,
StreamOrDevice s = {} )

Hyperbolic Cosine of the elements of an array.

◆ cummax()

array mlx::core::cummax ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative max of an array.

◆ cummin()

array mlx::core::cummin ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative min of an array.

◆ cumprod()

array mlx::core::cumprod ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative product of an array.

◆ cumsum()

array mlx::core::cumsum ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative sum of an array.

◆ degrees()

array mlx::core::degrees ( const array & a,
StreamOrDevice s = {} )

Convert the elements of an array from Radians to Degrees.

◆ depends()

std::vector< array > mlx::core::depends ( const std::vector< array > & inputs,
const std::vector< array > & dependencies )

Implements the identity function but allows injecting dependencies to other arrays.

This ensures that these other arrays will have been computed when the outputs of this function are computed.

◆ dequantize()

array mlx::core::dequantize ( const array & w,
const array & scales,
const array & biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {} )

Dequantize a matrix produced by quantize()

◆ diag()

array mlx::core::diag ( const array & a,
int k = 0,
StreamOrDevice s = {} )

Extract diagonal from a 2d array or create a diagonal matrix.

◆ diagonal()

array mlx::core::diagonal ( const array & a,
int offset = 0,
int axis1 = 0,
int axis2 = 1,
StreamOrDevice s = {} )

Extract a diagonal or construct a diagonal array.

◆ divide()

array mlx::core::divide ( const array & a,
const array & b,
StreamOrDevice s = {} )

Divide two arrays.

◆ divmod()

std::vector< array > mlx::core::divmod ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the element-wise quotient and remainder.

◆ equal()

array mlx::core::equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns the bool array with (a == b) element-wise.

◆ erf()

array mlx::core::erf ( const array & a,
StreamOrDevice s = {} )

Computes the error function of the elements of an array.

◆ erfinv()

array mlx::core::erfinv ( const array & a,
StreamOrDevice s = {} )

Computes the inverse error function of the elements of an array.

◆ exp()

array mlx::core::exp ( const array & a,
StreamOrDevice s = {} )

Exponential of the elements of an array.

◆ expand_dims() [1/2]

array mlx::core::expand_dims ( const array & a,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Add a singleton dimension at the given axes.

◆ expand_dims() [2/2]

array mlx::core::expand_dims ( const array & a,
int axis,
StreamOrDevice s = {} )

Add a singleton dimension at the given axis.

◆ expm1()

array mlx::core::expm1 ( const array & a,
StreamOrDevice s = {} )

Computes the expm1 function of the elements of an array.

◆ eye() [1/5]

array mlx::core::eye ( int n,
Dtype dtype,
StreamOrDevice s = {} )
inline

◆ eye() [2/5]

array mlx::core::eye ( int n,
int m,
int k,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.

◆ eye() [3/5]

array mlx::core::eye ( int n,
int m,
int k,
StreamOrDevice s = {} )
inline

◆ eye() [4/5]

array mlx::core::eye ( int n,
int m,
StreamOrDevice s = {} )
inline

◆ eye() [5/5]

array mlx::core::eye ( int n,
StreamOrDevice s = {} )
inline

◆ flatten() [1/2]

array mlx::core::flatten ( const array & a,
int start_axis,
int end_axis = -1,
StreamOrDevice s = {} )

Flatten the dimensions in the range [start_axis, end_axis].

◆ flatten() [2/2]

array mlx::core::flatten ( const array & a,
StreamOrDevice s = {} )

Flatten the array to 1D.

◆ floor()

array mlx::core::floor ( const array & a,
StreamOrDevice s = {} )

Floor of the elements of an array.

◆ floor_divide()

array mlx::core::floor_divide ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute integer division.

Equivalent to doing floor(a / b).

◆ full() [1/4]

array mlx::core::full ( std::vector< int > shape,
array vals,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape with the given value(s).

◆ full() [2/4]

array mlx::core::full ( std::vector< int > shape,
array vals,
StreamOrDevice s = {} )

◆ full() [3/4]

template<typename T >
array mlx::core::full ( std::vector< int > shape,
T val,
Dtype dtype,
StreamOrDevice s = {} )

◆ full() [4/4]

template<typename T >
array mlx::core::full ( std::vector< int > shape,
T val,
StreamOrDevice s = {} )

◆ gather() [1/2]

array mlx::core::gather ( const array & a,
const array & indices,
int axis,
const std::vector< int > & slice_sizes,
StreamOrDevice s = {} )
inline

◆ gather() [2/2]

array mlx::core::gather ( const array & a,
const std::vector< array > & indices,
const std::vector< int > & axes,
const std::vector< int > & slice_sizes,
StreamOrDevice s = {} )

Gather array entries given indices and slices.

◆ greater()

array mlx::core::greater ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a > b) element-wise.

◆ greater_equal()

array mlx::core::greater_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a >= b) element-wise.

◆ identity() [1/2]

array mlx::core::identity ( int n,
Dtype dtype,
StreamOrDevice s = {} )

Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.

◆ identity() [2/2]

array mlx::core::identity ( int n,
StreamOrDevice s = {} )
inline

◆ inner()

array mlx::core::inner ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the inner product of two vectors.

◆ isclose()

array mlx::core::isclose ( const array & a,
const array & b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {} )

Returns a boolean array where two arrays are element-wise equal within the specified tolerance.

◆ isinf()

array mlx::core::isinf ( const array & a,
StreamOrDevice s = {} )

◆ isnan()

array mlx::core::isnan ( const array & a,
StreamOrDevice s = {} )

◆ isneginf()

array mlx::core::isneginf ( const array & a,
StreamOrDevice s = {} )

◆ isposinf()

array mlx::core::isposinf ( const array & a,
StreamOrDevice s = {} )

◆ left_shift()

array mlx::core::left_shift ( const array & a,
const array & b,
StreamOrDevice s = {} )

Shift bits to the left.

◆ less()

array mlx::core::less ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a < b) element-wise.

◆ less_equal()

array mlx::core::less_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a <= b) element-wise.

◆ linspace()

array mlx::core::linspace ( double start,
double stop,
int num = 50,
Dtype dtype = float32,
StreamOrDevice s = {} )

A 1D array of num evenly spaced numbers in the range [start, stop]

◆ log()

array mlx::core::log ( const array & a,
StreamOrDevice s = {} )

Natural logarithm of the elements of an array.

◆ log10()

array mlx::core::log10 ( const array & a,
StreamOrDevice s = {} )

Log base 10 of the elements of an array.

◆ log1p()

array mlx::core::log1p ( const array & a,
StreamOrDevice s = {} )

Natural logarithm of one plus elements in the array: log(1 + a).

◆ log2()

array mlx::core::log2 ( const array & a,
StreamOrDevice s = {} )

Log base 2 of the elements of an array.

◆ logaddexp()

array mlx::core::logaddexp ( const array & a,
const array & b,
StreamOrDevice s = {} )

Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).

◆ logical_and()

array mlx::core::logical_and ( const array & a,
const array & b,
StreamOrDevice s = {} )

Logical and of two arrays.

◆ logical_not()

array mlx::core::logical_not ( const array & a,
StreamOrDevice s = {} )

Logical not of an array.

◆ logical_or()

array mlx::core::logical_or ( const array & a,
const array & b,
StreamOrDevice s = {} )

Logical or of two arrays.

◆ logsumexp() [1/4]

array mlx::core::logsumexp ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The logsumexp of all elements of the array.

◆ logsumexp() [2/4]

array mlx::core::logsumexp ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The logsumexp of the elements of an array along the given axes.

◆ logsumexp() [3/4]

array mlx::core::logsumexp ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The logsumexp of the elements of an array along the given axis.

◆ logsumexp() [4/4]

array mlx::core::logsumexp ( const array & a,
StreamOrDevice s = {} )
inline

◆ matmul()

array mlx::core::matmul ( const array & a,
const array & b,
StreamOrDevice s = {} )

Matrix-matrix multiplication.

◆ max() [1/4]

array mlx::core::max ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The maximum of all elements of the array.

◆ max() [2/4]

array mlx::core::max ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The maximum of the elements of an array along the given axes.

◆ max() [3/4]

array mlx::core::max ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The maximum of the elements of an array along the given axis.

◆ max() [4/4]

array mlx::core::max ( const array & a,
StreamOrDevice s = {} )
inline

◆ maximum()

array mlx::core::maximum ( const array & a,
const array & b,
StreamOrDevice s = {} )

Element-wise maximum between two arrays.

◆ mean() [1/4]

array mlx::core::mean ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Computes the mean of the elements of an array.

◆ mean() [2/4]

array mlx::core::mean ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Computes the mean of the elements of an array along the given axes.

◆ mean() [3/4]

array mlx::core::mean ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Computes the mean of the elements of an array along the given axis.

◆ mean() [4/4]

array mlx::core::mean ( const array & a,
StreamOrDevice s = {} )
inline

◆ meshgrid()

std::vector< array > mlx::core::meshgrid ( const std::vector< array > & arrays,
bool sparse = false,
std::string indexing = "xy",
StreamOrDevice s = {} )

A vector of coordinate arrays from coordinate vectors.

◆ min() [1/4]

array mlx::core::min ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The minimum of all elements of the array.

◆ min() [2/4]

array mlx::core::min ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The minimum of the elements of an array along the given axes.

◆ min() [3/4]

array mlx::core::min ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The minimum of the elements of an array along the given axis.

◆ min() [4/4]

array mlx::core::min ( const array & a,
StreamOrDevice s = {} )
inline

◆ minimum()

array mlx::core::minimum ( const array & a,
const array & b,
StreamOrDevice s = {} )

Element-wise minimum between two arrays.

◆ moveaxis()

array mlx::core::moveaxis ( const array & a,
int source,
int destination,
StreamOrDevice s = {} )

Move an axis of an array.

◆ multiply()

array mlx::core::multiply ( const array & a,
const array & b,
StreamOrDevice s = {} )

Multiply two arrays.

◆ negative()

array mlx::core::negative ( const array & a,
StreamOrDevice s = {} )

Negate an array.

◆ not_equal()

array mlx::core::not_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns the bool array with (a != b) element-wise.

◆ number_of_elements()

array mlx::core::number_of_elements ( const array & a,
std::vector< int > axes,
bool inverted,
Dtype dtype = int32,
StreamOrDevice s = {} )

Extract the number of elements along some axes as a scalar array.

Used to allow shape dependent shapeless compilation (pun intended).

◆ ones() [1/2]

array mlx::core::ones ( const std::vector< int > & shape,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape with ones.

◆ ones() [2/2]

array mlx::core::ones ( const std::vector< int > & shape,
StreamOrDevice s = {} )
inline

◆ ones_like()

array mlx::core::ones_like ( const array & a,
StreamOrDevice s = {} )

◆ operator!=() [1/3]

array mlx::core::operator!= ( const array & a,
const array & b )
inline

◆ operator!=() [2/3]

template<typename T >
array mlx::core::operator!= ( const array & a,
T b )

◆ operator!=() [3/3]

template<typename T >
array mlx::core::operator!= ( T a,
const array & b )

◆ operator%() [1/3]

array mlx::core::operator% ( const array & a,
const array & b )

◆ operator%() [2/3]

template<typename T >
array mlx::core::operator% ( const array & a,
T b )

◆ operator%() [3/3]

template<typename T >
array mlx::core::operator% ( T a,
const array & b )

◆ operator&()

array mlx::core::operator& ( const array & a,
const array & b )

◆ operator&&()

array mlx::core::operator&& ( const array & a,
const array & b )

◆ operator*() [1/3]

array mlx::core::operator* ( const array & a,
const array & b )

◆ operator*() [2/3]

template<typename T >
array mlx::core::operator* ( const array & a,
T b )

◆ operator*() [3/3]

template<typename T >
array mlx::core::operator* ( T a,
const array & b )

◆ operator+() [1/3]

array mlx::core::operator+ ( const array & a,
const array & b )

◆ operator+() [2/3]

template<typename T >
array mlx::core::operator+ ( const array & a,
T b )

◆ operator+() [3/3]

template<typename T >
array mlx::core::operator+ ( T a,
const array & b )

◆ operator-() [1/4]

array mlx::core::operator- ( const array & a)

◆ operator-() [2/4]

array mlx::core::operator- ( const array & a,
const array & b )

◆ operator-() [3/4]

template<typename T >
array mlx::core::operator- ( const array & a,
T b )

◆ operator-() [4/4]

template<typename T >
array mlx::core::operator- ( T a,
const array & b )

◆ operator/() [1/3]

array mlx::core::operator/ ( const array & a,
const array & b )

◆ operator/() [2/3]

array mlx::core::operator/ ( const array & a,
double b )

◆ operator/() [3/3]

array mlx::core::operator/ ( double a,
const array & b )

◆ operator<() [1/3]

array mlx::core::operator< ( const array & a,
const array & b )
inline

◆ operator<() [2/3]

template<typename T >
array mlx::core::operator< ( const array & a,
T b )

◆ operator<() [3/3]

template<typename T >
array mlx::core::operator< ( T a,
const array & b )

◆ operator<<()

array mlx::core::operator<< ( const array & a,
const array & b )

◆ operator<=() [1/3]

array mlx::core::operator<= ( const array & a,
const array & b )
inline

◆ operator<=() [2/3]

template<typename T >
array mlx::core::operator<= ( const array & a,
T b )

◆ operator<=() [3/3]

template<typename T >
array mlx::core::operator<= ( T a,
const array & b )

◆ operator==() [1/3]

array mlx::core::operator== ( const array & a,
const array & b )
inline

◆ operator==() [2/3]

template<typename T >
array mlx::core::operator== ( const array & a,
T b )

◆ operator==() [3/3]

template<typename T >
array mlx::core::operator== ( T a,
const array & b )

◆ operator>() [1/3]

array mlx::core::operator> ( const array & a,
const array & b )
inline

◆ operator>() [2/3]

template<typename T >
array mlx::core::operator> ( const array & a,
T b )

◆ operator>() [3/3]

template<typename T >
array mlx::core::operator> ( T a,
const array & b )

◆ operator>=() [1/3]

array mlx::core::operator>= ( const array & a,
const array & b )
inline

◆ operator>=() [2/3]

template<typename T >
array mlx::core::operator>= ( const array & a,
T b )

◆ operator>=() [3/3]

template<typename T >
array mlx::core::operator>= ( T a,
const array & b )

◆ operator>>()

array mlx::core::operator>> ( const array & a,
const array & b )

◆ operator^()

array mlx::core::operator^ ( const array & a,
const array & b )

◆ operator|()

array mlx::core::operator| ( const array & a,
const array & b )

◆ operator||()

array mlx::core::operator|| ( const array & a,
const array & b )

◆ outer()

array mlx::core::outer ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the outer product of two vectors.

◆ pad() [1/4]

array mlx::core::pad ( const array & a,
const std::pair< int, int > & pad_width,
const array & pad_value = array(0),
StreamOrDevice s = {} )

◆ pad() [2/4]

array mlx::core::pad ( const array & a,
const std::vector< int > & axes,
const std::vector< int > & low_pad_size,
const std::vector< int > & high_pad_size,
const array & pad_value = array(0),
StreamOrDevice s = {} )

Pad an array with a constant value.

◆ pad() [3/4]

array mlx::core::pad ( const array & a,
const std::vector< std::pair< int, int > > & pad_width,
const array & pad_value = array(0),
StreamOrDevice s = {} )

Pad an array with a constant value along all axes.

◆ pad() [4/4]

array mlx::core::pad ( const array & a,
int pad_width,
const array & pad_value = array(0),
StreamOrDevice s = {} )

◆ partition() [1/2]

array mlx::core::partition ( const array & a,
int kth,
int axis,
StreamOrDevice s = {} )

Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.

◆ partition() [2/2]

array mlx::core::partition ( const array & a,
int kth,
StreamOrDevice s = {} )

Returns a partitioned copy of the flattened array such that the smaller kth elements are first.

◆ power()

array mlx::core::power ( const array & a,
const array & b,
StreamOrDevice s = {} )

Raise elements of a to the power of b element-wise.

◆ prod() [1/4]

array mlx::core::prod ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The product of all elements of the array.

◆ prod() [2/4]

array mlx::core::prod ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The product of the elements of an array along the given axes.

◆ prod() [3/4]

array mlx::core::prod ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The product of the elements of an array along the given axis.

◆ prod() [4/4]

array mlx::core::prod ( const array & a,
StreamOrDevice s = {} )
inline

◆ quantize()

std::tuple< array, array, array > mlx::core::quantize ( const array & w,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {} )

Quantize a matrix along its last axis.

◆ quantized_matmul()

array mlx::core::quantized_matmul ( const array & x,
const array & w,
const array & scales,
const array & biases,
bool transpose = true,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {} )

Quantized matmul multiplies x with a quantized matrix w.

◆ radians()

array mlx::core::radians ( const array & a,
StreamOrDevice s = {} )

Convert the elements of an array from Degrees to Radians.

◆ reciprocal()

array mlx::core::reciprocal ( const array & a,
StreamOrDevice s = {} )

The reciprocal (1/x) of the elements in an array.

◆ remainder()

array mlx::core::remainder ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the element-wise remainder of division.

◆ repeat() [1/2]

array mlx::core::repeat ( const array & arr,
int repeats,
int axis,
StreamOrDevice s = {} )

Repeat an array along an axis.

◆ repeat() [2/2]

array mlx::core::repeat ( const array & arr,
int repeats,
StreamOrDevice s = {} )

◆ reshape()

array mlx::core::reshape ( const array & a,
std::vector< int > shape,
StreamOrDevice s = {} )

Reshape an array to the given shape.

◆ right_shift()

array mlx::core::right_shift ( const array & a,
const array & b,
StreamOrDevice s = {} )

Shift bits to the right.

◆ round() [1/2]

array mlx::core::round ( const array & a,
int decimals,
StreamOrDevice s = {} )

Round a floating point number.

◆ round() [2/2]

array mlx::core::round ( const array & a,
StreamOrDevice s = {} )
inline

◆ rsqrt()

array mlx::core::rsqrt ( const array & a,
StreamOrDevice s = {} )

Reciprocal square root of the elements of an array: 1 / sqrt(a).

◆ scatter() [1/2]

array mlx::core::scatter ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter() [2/2]

array mlx::core::scatter ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter updates to given linear indices.

◆ scatter_add() [1/2]

array mlx::core::scatter_add ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_add() [2/2]

array mlx::core::scatter_add ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter and add updates to given indices.

◆ scatter_max() [1/2]

array mlx::core::scatter_max ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_max() [2/2]

array mlx::core::scatter_max ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter and max updates to given linear indices.

◆ scatter_min() [1/2]

array mlx::core::scatter_min ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_min() [2/2]

array mlx::core::scatter_min ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter and min updates to given linear indices.

◆ scatter_prod() [1/2]

array mlx::core::scatter_prod ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_prod() [2/2]

array mlx::core::scatter_prod ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter and multiply updates into the given indices.

◆ sigmoid()

array mlx::core::sigmoid ( const array & a,
StreamOrDevice s = {} )

Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).

◆ sign()

array mlx::core::sign ( const array & a,
StreamOrDevice s = {} )

The sign of the elements in an array.

◆ sin()

array mlx::core::sin ( const array & a,
StreamOrDevice s = {} )

Sine of the elements of an array.

◆ sinh()

array mlx::core::sinh ( const array & a,
StreamOrDevice s = {} )

Hyperbolic sine of the elements of an array.

◆ slice() [1/2]

array mlx::core::slice ( const array & a,
const std::vector< int > & start,
const std::vector< int > & stop,
StreamOrDevice s = {} )

Slice an array with a stride of 1 in each dimension.

◆ slice() [2/2]

array mlx::core::slice ( const array & a,
std::vector< int > start,
std::vector< int > stop,
std::vector< int > strides,
StreamOrDevice s = {} )

Slice an array.

◆ slice_update() [1/2]

array mlx::core::slice_update ( const array & src,
const array & update,
std::vector< int > start,
std::vector< int > stop,
std::vector< int > strides,
StreamOrDevice s = {} )

Update a slice from the source array.

◆ slice_update() [2/2]

array mlx::core::slice_update ( const array & src,
const array & update,
std::vector< int > start,
std::vector< int > stop,
StreamOrDevice s = {} )

Update a slice from the source array with stride 1 in each dimension.

◆ softmax() [1/3]

array mlx::core::softmax ( const array & a,
bool precise = false,
StreamOrDevice s = {} )

Softmax of an array.

◆ softmax() [2/3]

array mlx::core::softmax ( const array & a,
const std::vector< int > & axes,
bool precise = false,
StreamOrDevice s = {} )

Softmax of an array.

◆ softmax() [3/3]

array mlx::core::softmax ( const array & a,
int axis,
bool precise = false,
StreamOrDevice s = {} )
inline

Softmax of an array.

◆ sort() [1/2]

array mlx::core::sort ( const array & a,
int axis,
StreamOrDevice s = {} )

Returns a sorted copy of the array along a given axis.

◆ sort() [2/2]

array mlx::core::sort ( const array & a,
StreamOrDevice s = {} )

Returns a sorted copy of the flattened array.

◆ split() [1/4]

std::vector< array > mlx::core::split ( const array & a,
const std::vector< int > & indices,
int axis,
StreamOrDevice s = {} )

◆ split() [2/4]

std::vector< array > mlx::core::split ( const array & a,
const std::vector< int > & indices,
StreamOrDevice s = {} )

◆ split() [3/4]

std::vector< array > mlx::core::split ( const array & a,
int num_splits,
int axis,
StreamOrDevice s = {} )

Split an array into sub-arrays along a given axis.

◆ split() [4/4]

std::vector< array > mlx::core::split ( const array & a,
int num_splits,
StreamOrDevice s = {} )

◆ sqrt()

array mlx::core::sqrt ( const array & a,
StreamOrDevice s = {} )

Square root of the elements of an array.

◆ square()

array mlx::core::square ( const array & a,
StreamOrDevice s = {} )

Square the elements of an array.

◆ squeeze() [1/3]

array mlx::core::squeeze ( const array & a,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Remove singleton dimensions at the given axes.

◆ squeeze() [2/3]

array mlx::core::squeeze ( const array & a,
int axis,
StreamOrDevice s = {} )
inline

Remove singleton dimensions at the given axis.

◆ squeeze() [3/3]

array mlx::core::squeeze ( const array & a,
StreamOrDevice s = {} )

Remove all singleton dimensions.

◆ stack() [1/2]

array mlx::core::stack ( const std::vector< array > & arrays,
int axis,
StreamOrDevice s = {} )

Stack arrays along a new axis.

◆ stack() [2/2]

array mlx::core::stack ( const std::vector< array > & arrays,
StreamOrDevice s = {} )

◆ std() [1/4]

array mlx::core::std ( const array & a,
bool keepdims,
int ddof = 0,
StreamOrDevice s = {} )

Computes the standard deviation of the elements of an array.

◆ std() [2/4]

array mlx::core::std ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the standard deviation of the elements of an array along the given axes.

◆ std() [3/4]

array mlx::core::std ( const array & a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the standard deviation of the elements of an array along the given axis.

◆ std() [4/4]

array mlx::core::std ( const array & a,
StreamOrDevice s = {} )
inline

◆ stop_gradient()

array mlx::core::stop_gradient ( const array & a,
StreamOrDevice s = {} )

Stop the flow of gradients.

◆ subtract()

array mlx::core::subtract ( const array & a,
const array & b,
StreamOrDevice s = {} )

Subtract two arrays.

◆ sum() [1/4]

array mlx::core::sum ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Sums the elements of an array.

◆ sum() [2/4]

array mlx::core::sum ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Sums the elements of an array along the given axes.

◆ sum() [3/4]

array mlx::core::sum ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Sums the elements of an array along the given axis.

◆ sum() [4/4]

array mlx::core::sum ( const array & a,
StreamOrDevice s = {} )
inline

◆ swapaxes()

array mlx::core::swapaxes ( const array & a,
int axis1,
int axis2,
StreamOrDevice s = {} )

Swap two axes of an array.

◆ take() [1/2]

array mlx::core::take ( const array & a,
const array & indices,
int axis,
StreamOrDevice s = {} )

Take array slices at the given indices of the specified axis.

◆ take() [2/2]

array mlx::core::take ( const array & a,
const array & indices,
StreamOrDevice s = {} )

Take array entries at the given indices treating the array as flattened.

◆ take_along_axis()

array mlx::core::take_along_axis ( const array & a,
const array & indices,
int axis,
StreamOrDevice s = {} )

Take array entries at the given indices along the axis.

◆ tan()

array mlx::core::tan ( const array & a,
StreamOrDevice s = {} )

Tangent of the elements of an array.

◆ tanh()

array mlx::core::tanh ( const array & a,
StreamOrDevice s = {} )

Hyperbolic tangent of the elements of an array.

◆ tensordot() [1/2]

array mlx::core::tensordot ( const array & a,
const array & b,
const int axis = 2,
StreamOrDevice s = {} )

Returns a contraction of a and b over multiple dimensions.

◆ tensordot() [2/2]

array mlx::core::tensordot ( const array & a,
const array & b,
const std::vector< int > & axes_a,
const std::vector< int > & axes_b,
StreamOrDevice s = {} )

◆ tile()

array mlx::core::tile ( const array & arr,
std::vector< int > reps,
StreamOrDevice s = {} )

◆ topk() [1/2]

array mlx::core::topk ( const array & a,
int k,
int axis,
StreamOrDevice s = {} )

Returns the top k elements of the array along a given axis.

◆ topk() [2/2]

array mlx::core::topk ( const array & a,
int k,
StreamOrDevice s = {} )

Returns the top k elements of the flattened array.

◆ transpose() [1/3]

array mlx::core::transpose ( const array & a,
std::initializer_list< int > axes,
StreamOrDevice s = {} )
inline

◆ transpose() [2/3]

array mlx::core::transpose ( const array & a,
std::vector< int > axes,
StreamOrDevice s = {} )

Permutes the dimensions according to the given axes.

◆ transpose() [3/3]

array mlx::core::transpose ( const array & a,
StreamOrDevice s = {} )

Permutes the dimensions in reverse order.

◆ tri() [1/2]

array mlx::core::tri ( int n,
Dtype type,
StreamOrDevice s = {} )
inline

◆ tri() [2/2]

array mlx::core::tri ( int n,
int m,
int k,
Dtype type,
StreamOrDevice s = {} )

◆ tril()

array mlx::core::tril ( array x,
int k = 0,
StreamOrDevice s = {} )

◆ triu()

array mlx::core::triu ( array x,
int k = 0,
StreamOrDevice s = {} )

◆ var() [1/4]

array mlx::core::var ( const array & a,
bool keepdims,
int ddof = 0,
StreamOrDevice s = {} )

Computes the variance of the elements of an array.

◆ var() [2/4]

array mlx::core::var ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the variance of the elements of an array along the given axes.

◆ var() [3/4]

array mlx::core::var ( const array & a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the variance of the elements of an array along the given axis.

◆ var() [4/4]

array mlx::core::var ( const array & a,
StreamOrDevice s = {} )
inline

◆ where()

array mlx::core::where ( const array & condition,
const array & x,
const array & y,
StreamOrDevice s = {} )

Select from x or y depending on condition.

◆ zeros() [1/2]

array mlx::core::zeros ( const std::vector< int > & shape,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape with zeros.

◆ zeros() [2/2]

array mlx::core::zeros ( const std::vector< int > & shape,
StreamOrDevice s = {} )
inline

◆ zeros_like()

array mlx::core::zeros_like ( const array & a,
StreamOrDevice s = {} )