MLX
Loading...
Searching...
No Matches
Functions
Core array operations

Functions

array mlx::core::arange (double start, double stop, double step, Dtype dtype, StreamOrDevice s={})
 A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).
 
array mlx::core::arange (double start, double stop, double step, StreamOrDevice s={})
 
array mlx::core::arange (double start, double stop, Dtype dtype, StreamOrDevice s={})
 
array mlx::core::arange (double start, double stop, StreamOrDevice s={})
 
array mlx::core::arange (double stop, Dtype dtype, StreamOrDevice s={})
 
array mlx::core::arange (double stop, StreamOrDevice s={})
 
array mlx::core::arange (int start, int stop, int step, StreamOrDevice s={})
 
array mlx::core::arange (int start, int stop, StreamOrDevice s={})
 
array mlx::core::arange (int stop, StreamOrDevice s={})
 
array mlx::core::linspace (double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
 A 1D array of num evenly spaced numbers in the range [start, stop].
 
array mlx::core::astype (array a, Dtype dtype, StreamOrDevice s={})
 Convert an array to the given data type.
 
array mlx::core::as_strided (array a, std::vector< int > shape, std::vector< size_t > strides, size_t offset, StreamOrDevice s={})
 Create a view of an array with the given shape and strides.
 
array mlx::core::copy (array a, StreamOrDevice s={})
 Copy another array.
 
array mlx::core::full (std::vector< int > shape, array vals, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with the given value(s).
 
array mlx::core::full (std::vector< int > shape, array vals, StreamOrDevice s={})
 
template<typename T >
array mlx::core::full (std::vector< int > shape, T val, Dtype dtype, StreamOrDevice s={})
 
template<typename T >
array mlx::core::full (std::vector< int > shape, T val, StreamOrDevice s={})
 
array mlx::core::zeros (const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with zeros.
 
array mlx::core::zeros (const std::vector< int > &shape, StreamOrDevice s={})
 
array mlx::core::zeros_like (const array &a, StreamOrDevice s={})
 
array mlx::core::ones (const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape with ones.
 
array mlx::core::ones (const std::vector< int > &shape, StreamOrDevice s={})
 
array mlx::core::ones_like (const array &a, StreamOrDevice s={})
 
array mlx::core::eye (int n, int m, int k, Dtype dtype, StreamOrDevice s={})
 Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
 
array mlx::core::eye (int n, Dtype dtype, StreamOrDevice s={})
 
array mlx::core::eye (int n, int m, StreamOrDevice s={})
 
array mlx::core::eye (int n, int m, int k, StreamOrDevice s={})
 
array mlx::core::eye (int n, StreamOrDevice s={})
 
array mlx::core::identity (int n, Dtype dtype, StreamOrDevice s={})
 Create a square matrix of shape (n,n) with ones on the main diagonal and zeros elsewhere.
 
array mlx::core::identity (int n, StreamOrDevice s={})
 
array mlx::core::tri (int n, int m, int k, Dtype type, StreamOrDevice s={})
 
array mlx::core::tri (int n, Dtype type, StreamOrDevice s={})
 
array mlx::core::tril (array x, int k=0, StreamOrDevice s={})
 
array mlx::core::triu (array x, int k=0, StreamOrDevice s={})
 
array mlx::core::reshape (const array &a, std::vector< int > shape, StreamOrDevice s={})
 Reshape an array to the given shape.
 
array mlx::core::flatten (const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
 Flatten the dimensions in the range [start_axis, end_axis].
 
array mlx::core::flatten (const array &a, StreamOrDevice s={})
 Flatten the array to 1D.
 
array mlx::core::hadamard_transform (const array &a, std::optional< float > scale=std::nullopt, StreamOrDevice s={})
 Multiply the array by the Hadamard matrix of corresponding size.
 
array mlx::core::squeeze (const array &a, const std::vector< int > &axes, StreamOrDevice s={})
 Remove singleton dimensions at the given axes.
 
array mlx::core::squeeze (const array &a, int axis, StreamOrDevice s={})
 Remove singleton dimensions at the given axis.
 
array mlx::core::squeeze (const array &a, StreamOrDevice s={})
 Remove all singleton dimensions.
 
array mlx::core::expand_dims (const array &a, const std::vector< int > &axes, StreamOrDevice s={})
 Add a singleton dimension at the given axes.
 
array mlx::core::expand_dims (const array &a, int axis, StreamOrDevice s={})
 Add a singleton dimension at the given axis.
 
array mlx::core::slice (const array &a, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
 Slice an array.
 
array mlx::core::slice (const array &a, const std::vector< int > &start, const std::vector< int > &stop, StreamOrDevice s={})
 Slice an array with a stride of 1 in each dimension.
 
array mlx::core::slice_update (const array &src, const array &update, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
 Update a slice from the source array.
 
array mlx::core::slice_update (const array &src, const array &update, std::vector< int > start, std::vector< int > stop, StreamOrDevice s={})
 Update a slice from the source array with stride 1 in each dimension.
 
std::vector< array > mlx::core::split (const array &a, int num_splits, int axis, StreamOrDevice s={})
 Split an array into sub-arrays along a given axis.
 
std::vector< array > mlx::core::split (const array &a, int num_splits, StreamOrDevice s={})
 
std::vector< array > mlx::core::split (const array &a, const std::vector< int > &indices, int axis, StreamOrDevice s={})
 
std::vector< array > mlx::core::split (const array &a, const std::vector< int > &indices, StreamOrDevice s={})
 
std::vector< array > mlx::core::meshgrid (const std::vector< array > &arrays, bool sparse=false, std::string indexing="xy", StreamOrDevice s={})
 A vector of coordinate arrays from coordinate vectors.
 
array mlx::core::clip (const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
 Clip (limit) the values in an array.
 
array mlx::core::concatenate (const std::vector< array > &arrays, int axis, StreamOrDevice s={})
 Concatenate arrays along a given axis.
 
array mlx::core::concatenate (const std::vector< array > &arrays, StreamOrDevice s={})
 
array mlx::core::stack (const std::vector< array > &arrays, int axis, StreamOrDevice s={})
 Stack arrays along a new axis.
 
array mlx::core::stack (const std::vector< array > &arrays, StreamOrDevice s={})
 
array mlx::core::repeat (const array &arr, int repeats, int axis, StreamOrDevice s={})
 Repeat an array along an axis.
 
array mlx::core::repeat (const array &arr, int repeats, StreamOrDevice s={})
 
array mlx::core::tile (const array &arr, std::vector< int > reps, StreamOrDevice s={})
 
array mlx::core::transpose (const array &a, std::vector< int > axes, StreamOrDevice s={})
 Permutes the dimensions according to the given axes.
 
array mlx::core::transpose (const array &a, std::initializer_list< int > axes, StreamOrDevice s={})
 
array mlx::core::swapaxes (const array &a, int axis1, int axis2, StreamOrDevice s={})
 Swap two axes of an array.
 
array mlx::core::moveaxis (const array &a, int source, int destination, StreamOrDevice s={})
 Move an axis of an array.
 
array mlx::core::pad (const array &a, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size, const array &pad_value=array(0), const std::string mode="constant", StreamOrDevice s={})
 Pad an array with a constant value.
 
array mlx::core::pad (const array &a, const std::vector< std::pair< int, int > > &pad_width, const array &pad_value=array(0), const std::string mode="constant", StreamOrDevice s={})
 Pad an array with a constant value along all axes.
 
array mlx::core::pad (const array &a, const std::pair< int, int > &pad_width, const array &pad_value=array(0), const std::string mode="constant", StreamOrDevice s={})
 
array mlx::core::pad (const array &a, int pad_width, const array &pad_value=array(0), const std::string mode="constant", StreamOrDevice s={})
 
array mlx::core::transpose (const array &a, StreamOrDevice s={})
 Permutes the dimensions in reverse order.
 
array mlx::core::broadcast_to (const array &a, const std::vector< int > &shape, StreamOrDevice s={})
 Broadcast an array to a given shape.
 
std::vector< array > mlx::core::broadcast_arrays (const std::vector< array > &inputs, StreamOrDevice s={})
 Broadcast a vector of arrays against one another.
 
array mlx::core::equal (const array &a, const array &b, StreamOrDevice s={})
 Returns the bool array with (a == b) element-wise.
 
array mlx::core::operator== (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator== (T a, const array &b)
 
template<typename T >
array mlx::core::operator== (const array &a, T b)
 
array mlx::core::not_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns the bool array with (a != b) element-wise.
 
array mlx::core::operator!= (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator!= (T a, const array &b)
 
template<typename T >
array mlx::core::operator!= (const array &a, T b)
 
array mlx::core::greater (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a > b) element-wise.
 
array mlx::core::operator> (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator> (T a, const array &b)
 
template<typename T >
array mlx::core::operator> (const array &a, T b)
 
array mlx::core::greater_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a >= b) element-wise.
 
array mlx::core::operator>= (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator>= (T a, const array &b)
 
template<typename T >
array mlx::core::operator>= (const array &a, T b)
 
array mlx::core::less (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a < b) element-wise.
 
array mlx::core::operator< (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator< (T a, const array &b)
 
template<typename T >
array mlx::core::operator< (const array &a, T b)
 
array mlx::core::less_equal (const array &a, const array &b, StreamOrDevice s={})
 Returns bool array with (a <= b) element-wise.
 
array mlx::core::operator<= (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator<= (T a, const array &b)
 
template<typename T >
array mlx::core::operator<= (const array &a, T b)
 
array mlx::core::array_equal (const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
 True if two arrays have the same shape and elements.
 
array mlx::core::array_equal (const array &a, const array &b, StreamOrDevice s={})
 
array mlx::core::isnan (const array &a, StreamOrDevice s={})
 
array mlx::core::isinf (const array &a, StreamOrDevice s={})
 
array mlx::core::isfinite (const array &a, StreamOrDevice s={})
 
array mlx::core::isposinf (const array &a, StreamOrDevice s={})
 
array mlx::core::isneginf (const array &a, StreamOrDevice s={})
 
array mlx::core::where (const array &condition, const array &x, const array &y, StreamOrDevice s={})
 Select from x or y depending on condition.
 
array mlx::core::nan_to_num (const array &a, float nan=0.0f, const std::optional< float > &posinf=std::nullopt, const std::optional< float > &neginf=std::nullopt, StreamOrDevice s={})
 Replace NaN and infinities with finite numbers.
 
array mlx::core::all (const array &a, bool keepdims, StreamOrDevice s={})
 True if all elements in the array are true (or non-zero).
 
array mlx::core::all (const array &a, StreamOrDevice s={})
 
array mlx::core::allclose (const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
 True if the two arrays are equal within the specified tolerance.
 
array mlx::core::isclose (const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
 Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
 
array mlx::core::all (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 True if all elements are true (or non-zero) along the given axes.
 
array mlx::core::all (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 True if all elements are true (or non-zero) along the given axis.
 
array mlx::core::any (const array &a, bool keepdims, StreamOrDevice s={})
 True if any elements in the array are true (or non-zero).
 
array mlx::core::any (const array &a, StreamOrDevice s={})
 
array mlx::core::any (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 True if any element is true (or non-zero) along the given axes.
 
array mlx::core::any (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 True if any element is true (or non-zero) along the given axis.
 
array mlx::core::sum (const array &a, bool keepdims, StreamOrDevice s={})
 Sums the elements of an array.
 
array mlx::core::sum (const array &a, StreamOrDevice s={})
 
array mlx::core::sum (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Sums the elements of an array along the given axes.
 
array mlx::core::sum (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Sums the elements of an array along the given axis.
 
array mlx::core::mean (const array &a, bool keepdims, StreamOrDevice s={})
 Computes the mean of the elements of an array.
 
array mlx::core::mean (const array &a, StreamOrDevice s={})
 
array mlx::core::mean (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 Computes the mean of the elements of an array along the given axes.
 
array mlx::core::mean (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Computes the mean of the elements of an array along the given axis.
 
array mlx::core::var (const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array.
 
array mlx::core::var (const array &a, StreamOrDevice s={})
 
array mlx::core::var (const array &a, const std::vector< int > &axes, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array along the given axes.
 
array mlx::core::var (const array &a, int axis, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the variance of the elements of an array along the given axis.
 
array mlx::core::std (const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array.
 
array mlx::core::std (const array &a, StreamOrDevice s={})
 
array mlx::core::std (const array &a, const std::vector< int > &axes, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array along the given axes.
 
array mlx::core::std (const array &a, int axis, bool keepdims=false, int ddof=0, StreamOrDevice s={})
 Computes the standard deviation of the elements of an array along the given axis.
 
array mlx::core::prod (const array &a, bool keepdims, StreamOrDevice s={})
 The product of all elements of the array.
 
array mlx::core::prod (const array &a, StreamOrDevice s={})
 
array mlx::core::prod (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The product of the elements of an array along the given axes.
 
array mlx::core::prod (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The product of the elements of an array along the given axis.
 
array mlx::core::max (const array &a, bool keepdims, StreamOrDevice s={})
 The maximum of all elements of the array.
 
array mlx::core::max (const array &a, StreamOrDevice s={})
 
array mlx::core::max (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The maximum of the elements of an array along the given axes.
 
array mlx::core::max (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The maximum of the elements of an array along the given axis.
 
array mlx::core::min (const array &a, bool keepdims, StreamOrDevice s={})
 The minimum of all elements of the array.
 
array mlx::core::min (const array &a, StreamOrDevice s={})
 
array mlx::core::min (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The minimum of the elements of an array along the given axes.
 
array mlx::core::min (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The minimum of the elements of an array along the given axis.
 
array mlx::core::argmin (const array &a, bool keepdims, StreamOrDevice s={})
 Returns the index of the minimum value in the array.
 
array mlx::core::argmin (const array &a, StreamOrDevice s={})
 
array mlx::core::argmin (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Returns the indices of the minimum values along a given axis.
 
array mlx::core::argmax (const array &a, bool keepdims, StreamOrDevice s={})
 Returns the index of the maximum value in the array.
 
array mlx::core::argmax (const array &a, StreamOrDevice s={})
 
array mlx::core::argmax (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 Returns the indices of the maximum values along a given axis.
 
array mlx::core::sort (const array &a, StreamOrDevice s={})
 Returns a sorted copy of the flattened array.
 
array mlx::core::sort (const array &a, int axis, StreamOrDevice s={})
 Returns a sorted copy of the array along a given axis.
 
array mlx::core::argsort (const array &a, StreamOrDevice s={})
 Returns indices that sort the flattened array.
 
array mlx::core::argsort (const array &a, int axis, StreamOrDevice s={})
 Returns indices that sort the array along a given axis.
 
array mlx::core::partition (const array &a, int kth, StreamOrDevice s={})
 Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
 
array mlx::core::partition (const array &a, int kth, int axis, StreamOrDevice s={})
 Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.
 
array mlx::core::argpartition (const array &a, int kth, StreamOrDevice s={})
 Returns indices that partition the flattened array such that the smaller kth elements are first.
 
array mlx::core::argpartition (const array &a, int kth, int axis, StreamOrDevice s={})
 Returns indices that partition the array along a given axis such that the smaller kth elements are first.
 
array mlx::core::topk (const array &a, int k, StreamOrDevice s={})
 Returns the top k elements of the flattened array.
 
array mlx::core::topk (const array &a, int k, int axis, StreamOrDevice s={})
 Returns the top k elements of the array along a given axis.
 
array mlx::core::logsumexp (const array &a, bool keepdims, StreamOrDevice s={})
 The logsumexp of all elements of the array.
 
array mlx::core::logsumexp (const array &a, StreamOrDevice s={})
 
array mlx::core::logsumexp (const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})
 The logsumexp of the elements of an array along the given axes.
 
array mlx::core::logsumexp (const array &a, int axis, bool keepdims=false, StreamOrDevice s={})
 The logsumexp of the elements of an array along the given axis.
 
array mlx::core::abs (const array &a, StreamOrDevice s={})
 Absolute value of elements in an array.
 
array mlx::core::negative (const array &a, StreamOrDevice s={})
 Negate an array.
 
array mlx::core::operator- (const array &a)
 
array mlx::core::sign (const array &a, StreamOrDevice s={})
 The sign of the elements in an array.
 
array mlx::core::logical_not (const array &a, StreamOrDevice s={})
 Logical not of an array.
 
array mlx::core::logical_and (const array &a, const array &b, StreamOrDevice s={})
 Logical and of two arrays.
 
array mlx::core::operator&& (const array &a, const array &b)
 
array mlx::core::logical_or (const array &a, const array &b, StreamOrDevice s={})
 Logical or of two arrays.
 
array mlx::core::operator|| (const array &a, const array &b)
 
array mlx::core::reciprocal (const array &a, StreamOrDevice s={})
 The reciprocal (1/x) of the elements in an array.
 
array mlx::core::add (const array &a, const array &b, StreamOrDevice s={})
 Add two arrays.
 
array mlx::core::operator+ (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator+ (T a, const array &b)
 
template<typename T >
array mlx::core::operator+ (const array &a, T b)
 
array mlx::core::subtract (const array &a, const array &b, StreamOrDevice s={})
 Subtract two arrays.
 
array mlx::core::operator- (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator- (T a, const array &b)
 
template<typename T >
array mlx::core::operator- (const array &a, T b)
 
array mlx::core::multiply (const array &a, const array &b, StreamOrDevice s={})
 Multiply two arrays.
 
array mlx::core::operator* (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator* (T a, const array &b)
 
template<typename T >
array mlx::core::operator* (const array &a, T b)
 
array mlx::core::divide (const array &a, const array &b, StreamOrDevice s={})
 Divide two arrays.
 
array mlx::core::operator/ (const array &a, const array &b)
 
array mlx::core::operator/ (double a, const array &b)
 
array mlx::core::operator/ (const array &a, double b)
 
std::vector< array > mlx::core::divmod (const array &a, const array &b, StreamOrDevice s={})
 Compute the element-wise quotient and remainder.
 
array mlx::core::floor_divide (const array &a, const array &b, StreamOrDevice s={})
 Compute integer division.
 
array mlx::core::remainder (const array &a, const array &b, StreamOrDevice s={})
 Compute the element-wise remainder of division.
 
array mlx::core::operator% (const array &a, const array &b)
 
template<typename T >
array mlx::core::operator% (T a, const array &b)
 
template<typename T >
array mlx::core::operator% (const array &a, T b)
 
array mlx::core::maximum (const array &a, const array &b, StreamOrDevice s={})
 Element-wise maximum between two arrays.
 
array mlx::core::minimum (const array &a, const array &b, StreamOrDevice s={})
 Element-wise minimum between two arrays.
 
array mlx::core::floor (const array &a, StreamOrDevice s={})
 Floor the elements of an array.
 
array mlx::core::ceil (const array &a, StreamOrDevice s={})
 Ceil the elements of an array.
 
array mlx::core::square (const array &a, StreamOrDevice s={})
 Square the elements of an array.
 
array mlx::core::exp (const array &a, StreamOrDevice s={})
 Exponential of the elements of an array.
 
array mlx::core::sin (const array &a, StreamOrDevice s={})
 Sine of the elements of an array.
 
array mlx::core::cos (const array &a, StreamOrDevice s={})
 Cosine of the elements of an array.
 
array mlx::core::tan (const array &a, StreamOrDevice s={})
 Tangent of the elements of an array.
 
array mlx::core::arcsin (const array &a, StreamOrDevice s={})
 Arc Sine of the elements of an array.
 
array mlx::core::arccos (const array &a, StreamOrDevice s={})
 Arc Cosine of the elements of an array.
 
array mlx::core::arctan (const array &a, StreamOrDevice s={})
 Arc Tangent of the elements of an array.
 
array mlx::core::arctan2 (const array &a, const array &b, StreamOrDevice s={})
 Inverse tangent of the ratio of two arrays.
 
array mlx::core::sinh (const array &a, StreamOrDevice s={})
 Hyperbolic Sine of the elements of an array.
 
array mlx::core::cosh (const array &a, StreamOrDevice s={})
 Hyperbolic Cosine of the elements of an array.
 
array mlx::core::tanh (const array &a, StreamOrDevice s={})
 Hyperbolic Tangent of the elements of an array.
 
array mlx::core::arcsinh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Sine of the elements of an array.
 
array mlx::core::arccosh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Cosine of the elements of an array.
 
array mlx::core::arctanh (const array &a, StreamOrDevice s={})
 Inverse Hyperbolic Tangent of the elements of an array.
 
array mlx::core::degrees (const array &a, StreamOrDevice s={})
 Convert the elements of an array from Radians to Degrees.
 
array mlx::core::radians (const array &a, StreamOrDevice s={})
 Convert the elements of an array from Degrees to Radians.
 
array mlx::core::log (const array &a, StreamOrDevice s={})
 Natural logarithm of the elements of an array.
 
array mlx::core::log2 (const array &a, StreamOrDevice s={})
 Log base 2 of the elements of an array.
 
array mlx::core::log10 (const array &a, StreamOrDevice s={})
 Log base 10 of the elements of an array.
 
array mlx::core::log1p (const array &a, StreamOrDevice s={})
 Natural logarithm of one plus elements in the array: log(1 + a).
 
array mlx::core::logaddexp (const array &a, const array &b, StreamOrDevice s={})
 Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).
 
array mlx::core::sigmoid (const array &a, StreamOrDevice s={})
 Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
 
array mlx::core::erf (const array &a, StreamOrDevice s={})
 Computes the error function of the elements of an array.
 
array mlx::core::erfinv (const array &a, StreamOrDevice s={})
 Computes the inverse error function of the elements of an array.
 
array mlx::core::expm1 (const array &a, StreamOrDevice s={})
 Computes exp(x) - 1 for the elements of an array.
 
array mlx::core::stop_gradient (const array &a, StreamOrDevice s={})
 Stop the flow of gradients.
 
array mlx::core::round (const array &a, int decimals, StreamOrDevice s={})
 Round the elements of an array to the given number of decimals.
 
array mlx::core::round (const array &a, StreamOrDevice s={})
 
array mlx::core::matmul (const array &a, const array &b, StreamOrDevice s={})
 Matrix-matrix multiplication.
 
array mlx::core::gather (const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const std::vector< int > &slice_sizes, StreamOrDevice s={})
 Gather array entries given indices and slices.
 
array mlx::core::gather (const array &a, const array &indices, int axis, const std::vector< int > &slice_sizes, StreamOrDevice s={})
 
array mlx::core::take (const array &a, const array &indices, int axis, StreamOrDevice s={})
 Take array slices at the given indices of the specified axis.
 
array mlx::core::take (const array &a, const array &indices, StreamOrDevice s={})
 Take array entries at the given indices treating the array as flattened.
 
array mlx::core::take_along_axis (const array &a, const array &indices, int axis, StreamOrDevice s={})
 Take array entries given indices along the axis.
 
array mlx::core::scatter (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter updates to the given indices.
 
array mlx::core::scatter (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_add (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and add updates to given indices.
 
array mlx::core::scatter_add (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_prod (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and prod updates to given indices.
 
array mlx::core::scatter_prod (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_max (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and max updates to given indices.
 
array mlx::core::scatter_max (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::scatter_min (const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
 Scatter and min updates to given indices.
 
array mlx::core::scatter_min (const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})
 
array mlx::core::sqrt (const array &a, StreamOrDevice s={})
 Square root of the elements of an array.
 
array mlx::core::rsqrt (const array &a, StreamOrDevice s={})
 Reciprocal of the square root of the elements of an array.
 
array mlx::core::softmax (const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
 Softmax of an array along the given axes.
 
array mlx::core::softmax (const array &a, bool precise=false, StreamOrDevice s={})
 Softmax of an array.
 
array mlx::core::softmax (const array &a, int axis, bool precise=false, StreamOrDevice s={})
 Softmax of an array along the given axis.
 
array mlx::core::power (const array &a, const array &b, StreamOrDevice s={})
 Raise elements of a to the power of b element-wise.
 
array mlx::core::cumsum (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative sum of an array.
 
array mlx::core::cumprod (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative product of an array.
 
array mlx::core::cummax (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative max of an array.
 
array mlx::core::cummin (const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
 Cumulative min of an array.
 
array mlx::core::conv_general (array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
 General convolution with a filter.
 
array mlx::core::conv_general (const array &input, const array &weight, std::vector< int > stride={}, std::vector< int > padding={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
 General convolution with a filter.
 
array mlx::core::conv1d (const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
 1D convolution with a filter
 
array mlx::core::conv2d (const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
 2D convolution with a filter
 
array mlx::core::conv3d (const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
 3D convolution with a filter
 
array mlx::core::quantized_matmul (const array &x, const array &w, const array &scales, const array &biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
 Quantized matmul multiplies x with a quantized matrix w.
 
std::tuple< array, array, array > mlx::core::quantize (const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
 Quantize a matrix along its last axis.
 
array mlx::core::dequantize (const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
 Dequantize a matrix produced by quantize()
 
array mlx::core::gather_qmm (const array &x, const array &w, const array &scales, const array &biases, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
 Compute matrix products with matrix-level gather.
 
array mlx::core::tensordot (const array &a, const array &b, const int axis=2, StreamOrDevice s={})
 Returns a contraction of a and b over multiple dimensions.
 
array mlx::core::tensordot (const array &a, const array &b, const std::vector< int > &axes_a, const std::vector< int > &axes_b, StreamOrDevice s={})
 
array mlx::core::outer (const array &a, const array &b, StreamOrDevice s={})
 Compute the outer product of two vectors.
 
array mlx::core::inner (const array &a, const array &b, StreamOrDevice s={})
 Compute the inner product of two vectors.
 
array mlx::core::addmm (array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
 Compute D = beta * C + alpha * (A @ B)
 
array mlx::core::block_masked_mm (array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
 Compute matrix product with block masking.
 
array mlx::core::gather_mm (array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
 Compute matrix product with matrix-level gather.
 
array mlx::core::diagonal (const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
 Extract a diagonal or construct a diagonal array.
 
array mlx::core::diag (const array &a, int k=0, StreamOrDevice s={})
 Extract diagonal from a 2d array or create a diagonal matrix.
 
array mlx::core::trace (const array &a, int offset, int axis1, int axis2, Dtype dtype, StreamOrDevice s={})
 Return the sum along a specified diagonal in the given array.
 
array mlx::core::trace (const array &a, int offset, int axis1, int axis2, StreamOrDevice s={})
 
array mlx::core::trace (const array &a, StreamOrDevice s={})
 
std::vector< array > mlx::core::depends (const std::vector< array > &inputs, const std::vector< array > &dependencies)
 Implements the identity function but allows injecting dependencies to other arrays.
 
array mlx::core::atleast_1d (const array &a, StreamOrDevice s={})
 Convert an array to an array with at least the given number of dimensions.
 
std::vector< array > mlx::core::atleast_1d (const std::vector< array > &a, StreamOrDevice s={})
 
array mlx::core::atleast_2d (const array &a, StreamOrDevice s={})
 
std::vector< array > mlx::core::atleast_2d (const std::vector< array > &a, StreamOrDevice s={})
 
array mlx::core::atleast_3d (const array &a, StreamOrDevice s={})
 
std::vector< array > mlx::core::atleast_3d (const std::vector< array > &a, StreamOrDevice s={})
 
array mlx::core::number_of_elements (const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
 Extract the number of elements along some axes as a scalar array.
 
array mlx::core::conjugate (const array &a, StreamOrDevice s={})
 
array mlx::core::bitwise_and (const array &a, const array &b, StreamOrDevice s={})
 Bitwise and.
 
array mlx::core::operator& (const array &a, const array &b)
 
array mlx::core::bitwise_or (const array &a, const array &b, StreamOrDevice s={})
 Bitwise inclusive or.
 
array mlx::core::operator| (const array &a, const array &b)
 
array mlx::core::bitwise_xor (const array &a, const array &b, StreamOrDevice s={})
 Bitwise exclusive or.
 
array mlx::core::operator^ (const array &a, const array &b)
 
array mlx::core::left_shift (const array &a, const array &b, StreamOrDevice s={})
 Shift bits to the left.
 
array mlx::core::operator<< (const array &a, const array &b)
 
array mlx::core::right_shift (const array &a, const array &b, StreamOrDevice s={})
 Shift bits to the right.
 
array mlx::core::operator>> (const array &a, const array &b)
 
array mlx::core::view (const array &a, const Dtype &dtype, StreamOrDevice s={})
 

Detailed Description

Function Documentation

◆ abs()

array mlx::core::abs ( const array & a,
StreamOrDevice s = {} )

Absolute value of elements in an array.

◆ add()

array mlx::core::add ( const array & a,
const array & b,
StreamOrDevice s = {} )

Add two arrays.

◆ addmm()

array mlx::core::addmm ( array c,
array a,
array b,
const float & alpha = 1.f,
const float & beta = 1.f,
StreamOrDevice s = {} )

Compute D = beta * C + alpha * (A @ B)

◆ all() [1/4]

array mlx::core::all ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

True if all elements in the array are true (or non-zero).

◆ all() [2/4]

array mlx::core::all ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axes.

An output value is true if all the corresponding inputs are true.

◆ all() [3/4]

array mlx::core::all ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axis.

An output value is true if all the corresponding inputs are true.

◆ all() [4/4]

array mlx::core::all ( const array & a,
StreamOrDevice s = {} )
inline

◆ allclose()

array mlx::core::allclose ( const array & a,
const array & b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {} )

True if the two arrays are equal within the specified tolerance.

◆ any() [1/4]

array mlx::core::any ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

True if any elements in the array are true (or non-zero).

◆ any() [2/4]

array mlx::core::any ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axes.

An output value is true if any of the corresponding inputs are true.

◆ any() [3/4]

array mlx::core::any ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Reduces the input along the given axis.

An output value is true if any of the corresponding inputs are true.

◆ any() [4/4]

array mlx::core::any ( const array & a,
StreamOrDevice s = {} )
inline

◆ arange() [1/9]

array mlx::core::arange ( double start,
double stop,
double step,
Dtype dtype,
StreamOrDevice s = {} )

A 1D array of numbers starting at start (optional), stopping at stop, stepping by step (optional).

◆ arange() [2/9]

array mlx::core::arange ( double start,
double stop,
double step,
StreamOrDevice s = {} )

◆ arange() [3/9]

array mlx::core::arange ( double start,
double stop,
Dtype dtype,
StreamOrDevice s = {} )

◆ arange() [4/9]

array mlx::core::arange ( double start,
double stop,
StreamOrDevice s = {} )

◆ arange() [5/9]

array mlx::core::arange ( double stop,
Dtype dtype,
StreamOrDevice s = {} )

◆ arange() [6/9]

array mlx::core::arange ( double stop,
StreamOrDevice s = {} )

◆ arange() [7/9]

array mlx::core::arange ( int start,
int stop,
int step,
StreamOrDevice s = {} )

◆ arange() [8/9]

array mlx::core::arange ( int start,
int stop,
StreamOrDevice s = {} )

◆ arange() [9/9]

array mlx::core::arange ( int stop,
StreamOrDevice s = {} )

◆ arccos()

array mlx::core::arccos ( const array & a,
StreamOrDevice s = {} )

Arc Cosine of the elements of an array.

◆ arccosh()

array mlx::core::arccosh ( const array & a,
StreamOrDevice s = {} )

Inverse Hyperbolic Cosine of the elements of an array.

◆ arcsin()

array mlx::core::arcsin ( const array & a,
StreamOrDevice s = {} )

Arc Sine of the elements of an array.

◆ arcsinh()

array mlx::core::arcsinh ( const array & a,
StreamOrDevice s = {} )

Inverse Hyperbolic Sine of the elements of an array.

◆ arctan()

array mlx::core::arctan ( const array & a,
StreamOrDevice s = {} )

Arc Tangent of the elements of an array.

◆ arctan2()

array mlx::core::arctan2 ( const array & a,
const array & b,
StreamOrDevice s = {} )

Inverse tangent of the ratio of two arrays.

◆ arctanh()

array mlx::core::arctanh ( const array & a,
StreamOrDevice s = {} )

Inverse Hyperbolic Tangent of the elements of an array.

◆ argmax() [1/3]

array mlx::core::argmax ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Returns the index of the maximum value in the array.

◆ argmax() [2/3]

array mlx::core::argmax ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Returns the indices of the maximum values along a given axis.

◆ argmax() [3/3]

array mlx::core::argmax ( const array & a,
StreamOrDevice s = {} )
inline

◆ argmin() [1/3]

array mlx::core::argmin ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Returns the index of the minimum value in the array.

◆ argmin() [2/3]

array mlx::core::argmin ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Returns the indices of the minimum values along a given axis.

◆ argmin() [3/3]

array mlx::core::argmin ( const array & a,
StreamOrDevice s = {} )
inline

◆ argpartition() [1/2]

array mlx::core::argpartition ( const array & a,
int kth,
int axis,
StreamOrDevice s = {} )

Returns indices that partition the array along a given axis such that the smaller kth elements are first.

◆ argpartition() [2/2]

array mlx::core::argpartition ( const array & a,
int kth,
StreamOrDevice s = {} )

Returns indices that partition the flattened array such that the smaller kth elements are first.

◆ argsort() [1/2]

array mlx::core::argsort ( const array & a,
int axis,
StreamOrDevice s = {} )

Returns indices that sort the array along a given axis.

◆ argsort() [2/2]

array mlx::core::argsort ( const array & a,
StreamOrDevice s = {} )

Returns indices that sort the flattened array.

◆ array_equal() [1/2]

array mlx::core::array_equal ( const array & a,
const array & b,
bool equal_nan,
StreamOrDevice s = {} )

True if two arrays have the same shape and elements.

◆ array_equal() [2/2]

array mlx::core::array_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )
inline

◆ as_strided()

array mlx::core::as_strided ( array a,
std::vector< int > shape,
std::vector< size_t > strides,
size_t offset,
StreamOrDevice s = {} )

Create a view of an array with the given shape and strides.

◆ astype()

array mlx::core::astype ( array a,
Dtype dtype,
StreamOrDevice s = {} )

Convert an array to the given data type.

◆ atleast_1d() [1/2]

array mlx::core::atleast_1d ( const array & a,
StreamOrDevice s = {} )

Convert an array to an array with at least the given number of dimensions.

◆ atleast_1d() [2/2]

std::vector< array > mlx::core::atleast_1d ( const std::vector< array > & a,
StreamOrDevice s = {} )

◆ atleast_2d() [1/2]

array mlx::core::atleast_2d ( const array & a,
StreamOrDevice s = {} )

◆ atleast_2d() [2/2]

std::vector< array > mlx::core::atleast_2d ( const std::vector< array > & a,
StreamOrDevice s = {} )

◆ atleast_3d() [1/2]

array mlx::core::atleast_3d ( const array & a,
StreamOrDevice s = {} )

◆ atleast_3d() [2/2]

std::vector< array > mlx::core::atleast_3d ( const std::vector< array > & a,
StreamOrDevice s = {} )

◆ bitwise_and()

array mlx::core::bitwise_and ( const array & a,
const array & b,
StreamOrDevice s = {} )

Bitwise and.

◆ bitwise_or()

array mlx::core::bitwise_or ( const array & a,
const array & b,
StreamOrDevice s = {} )

Bitwise inclusive or.

◆ bitwise_xor()

array mlx::core::bitwise_xor ( const array & a,
const array & b,
StreamOrDevice s = {} )

Bitwise exclusive or.

◆ block_masked_mm()

array mlx::core::block_masked_mm ( array a,
array b,
int block_size,
std::optional< array > mask_out = std::nullopt,
std::optional< array > mask_lhs = std::nullopt,
std::optional< array > mask_rhs = std::nullopt,
StreamOrDevice s = {} )

Compute matrix product with block masking.

◆ broadcast_arrays()

std::vector< array > mlx::core::broadcast_arrays ( const std::vector< array > & inputs,
StreamOrDevice s = {} )

Broadcast a vector of arrays against one another.

◆ broadcast_to()

array mlx::core::broadcast_to ( const array & a,
const std::vector< int > & shape,
StreamOrDevice s = {} )

Broadcast an array to a given shape.

◆ ceil()

array mlx::core::ceil ( const array & a,
StreamOrDevice s = {} )

Ceil the elements of an array.

◆ clip()

array mlx::core::clip ( const array & a,
const std::optional< array > & a_min = std::nullopt,
const std::optional< array > & a_max = std::nullopt,
StreamOrDevice s = {} )

Clip (limit) the values in an array.

◆ concatenate() [1/2]

array mlx::core::concatenate ( const std::vector< array > & arrays,
int axis,
StreamOrDevice s = {} )

Concatenate arrays along a given axis.

◆ concatenate() [2/2]

array mlx::core::concatenate ( const std::vector< array > & arrays,
StreamOrDevice s = {} )

◆ conjugate()

array mlx::core::conjugate ( const array & a,
StreamOrDevice s = {} )

◆ conv1d()

array mlx::core::conv1d ( const array & input,
const array & weight,
int stride = 1,
int padding = 0,
int dilation = 1,
int groups = 1,
StreamOrDevice s = {} )

1D convolution with a filter

◆ conv2d()

array mlx::core::conv2d ( const array & input,
const array & weight,
const std::pair< int, int > & stride = {1, 1},
const std::pair< int, int > & padding = {0, 0},
const std::pair< int, int > & dilation = {1, 1},
int groups = 1,
StreamOrDevice s = {} )

2D convolution with a filter

◆ conv3d()

array mlx::core::conv3d ( const array & input,
const array & weight,
const std::tuple< int, int, int > & stride = {1, 1, 1},
const std::tuple< int, int, int > & padding = {0, 0, 0},
const std::tuple< int, int, int > & dilation = {1, 1, 1},
int groups = 1,
StreamOrDevice s = {} )

3D convolution with a filter

◆ conv_general() [1/2]

array mlx::core::conv_general ( array input,
array weight,
std::vector< int > stride = {},
std::vector< int > padding_lo = {},
std::vector< int > padding_hi = {},
std::vector< int > kernel_dilation = {},
std::vector< int > input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {} )

General convolution with a filter.

◆ conv_general() [2/2]

array mlx::core::conv_general ( const array & input,
const array & weight,
std::vector< int > stride = {},
std::vector< int > padding = {},
std::vector< int > kernel_dilation = {},
std::vector< int > input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {} )
inline

General convolution with a filter.

◆ copy()

array mlx::core::copy ( array a,
StreamOrDevice s = {} )

Copy another array.

◆ cos()

array mlx::core::cos ( const array & a,
StreamOrDevice s = {} )

Cosine of the elements of an array.

◆ cosh()

array mlx::core::cosh ( const array & a,
StreamOrDevice s = {} )

Hyperbolic Cosine of the elements of an array.

◆ cummax()

array mlx::core::cummax ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative max of an array.

◆ cummin()

array mlx::core::cummin ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative min of an array.

◆ cumprod()

array mlx::core::cumprod ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative product of an array.

◆ cumsum()

array mlx::core::cumsum ( const array & a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {} )

Cumulative sum of an array.

◆ degrees()

array mlx::core::degrees ( const array & a,
StreamOrDevice s = {} )

Convert the elements of an array from Radians to Degrees.

◆ depends()

std::vector< array > mlx::core::depends ( const std::vector< array > & inputs,
const std::vector< array > & dependencies )

Implements the identity function but allows injecting dependencies to other arrays.

This ensures that these other arrays will have been computed when the outputs of this function are computed.

◆ dequantize()

array mlx::core::dequantize ( const array & w,
const array & scales,
const array & biases,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {} )

Dequantize a matrix produced by quantize()

◆ diag()

array mlx::core::diag ( const array & a,
int k = 0,
StreamOrDevice s = {} )

Extract diagonal from a 2d array or create a diagonal matrix.

◆ diagonal()

array mlx::core::diagonal ( const array & a,
int offset = 0,
int axis1 = 0,
int axis2 = 1,
StreamOrDevice s = {} )

Extract a diagonal or construct a diagonal array.

◆ divide()

array mlx::core::divide ( const array & a,
const array & b,
StreamOrDevice s = {} )

Divide two arrays.

◆ divmod()

std::vector< array > mlx::core::divmod ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the element-wise quotient and remainder.

◆ equal()

array mlx::core::equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns the bool array with (a == b) element-wise.

◆ erf()

array mlx::core::erf ( const array & a,
StreamOrDevice s = {} )

Computes the error function of the elements of an array.

◆ erfinv()

array mlx::core::erfinv ( const array & a,
StreamOrDevice s = {} )

Computes the inverse error function of the elements of an array.

◆ exp()

array mlx::core::exp ( const array & a,
StreamOrDevice s = {} )

Exponential of the elements of an array.

◆ expand_dims() [1/2]

array mlx::core::expand_dims ( const array & a,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Add a singleton dimension at the given axes.

◆ expand_dims() [2/2]

array mlx::core::expand_dims ( const array & a,
int axis,
StreamOrDevice s = {} )

Add a singleton dimension at the given axis.

◆ expm1()

array mlx::core::expm1 ( const array & a,
StreamOrDevice s = {} )

Computes the expm1 function of the elements of an array.

◆ eye() [1/5]

array mlx::core::eye ( int n,
Dtype dtype,
StreamOrDevice s = {} )
inline

◆ eye() [2/5]

array mlx::core::eye ( int n,
int m,
int k,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.

◆ eye() [3/5]

array mlx::core::eye ( int n,
int m,
int k,
StreamOrDevice s = {} )
inline

◆ eye() [4/5]

array mlx::core::eye ( int n,
int m,
StreamOrDevice s = {} )
inline

◆ eye() [5/5]

array mlx::core::eye ( int n,
StreamOrDevice s = {} )
inline

◆ flatten() [1/2]

array mlx::core::flatten ( const array & a,
int start_axis,
int end_axis = -1,
StreamOrDevice s = {} )

Flatten the dimensions in the range [start_axis, end_axis].

◆ flatten() [2/2]

array mlx::core::flatten ( const array & a,
StreamOrDevice s = {} )

Flatten the array to 1D.

◆ floor()

array mlx::core::floor ( const array & a,
StreamOrDevice s = {} )

Floor the elements of an array.

◆ floor_divide()

array mlx::core::floor_divide ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute integer division.

Equivalent to doing floor(a / b).

◆ full() [1/4]

array mlx::core::full ( std::vector< int > shape,
array vals,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape with the given value(s).

◆ full() [2/4]

array mlx::core::full ( std::vector< int > shape,
array vals,
StreamOrDevice s = {} )

◆ full() [3/4]

template<typename T >
array mlx::core::full ( std::vector< int > shape,
T val,
Dtype dtype,
StreamOrDevice s = {} )

◆ full() [4/4]

template<typename T >
array mlx::core::full ( std::vector< int > shape,
T val,
StreamOrDevice s = {} )

◆ gather() [1/2]

array mlx::core::gather ( const array & a,
const array & indices,
int axis,
const std::vector< int > & slice_sizes,
StreamOrDevice s = {} )
inline

◆ gather() [2/2]

array mlx::core::gather ( const array & a,
const std::vector< array > & indices,
const std::vector< int > & axes,
const std::vector< int > & slice_sizes,
StreamOrDevice s = {} )

Gather array entries given indices and slices.

◆ gather_mm()

array mlx::core::gather_mm ( array a,
array b,
std::optional< array > lhs_indices = std::nullopt,
std::optional< array > rhs_indices = std::nullopt,
StreamOrDevice s = {} )

Compute matrix product with matrix-level gather.

◆ gather_qmm()

array mlx::core::gather_qmm ( const array & x,
const array & w,
const array & scales,
const array & biases,
std::optional< array > lhs_indices = std::nullopt,
std::optional< array > rhs_indices = std::nullopt,
bool transpose = true,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {} )

Compute matrix products with matrix-level gather.

◆ greater()

array mlx::core::greater ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a > b) element-wise.

◆ greater_equal()

array mlx::core::greater_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a >= b) element-wise.

◆ hadamard_transform()

array mlx::core::hadamard_transform ( const array & a,
std::optional< float > scale = std::nullopt,
StreamOrDevice s = {} )

Multiply the array by the Hadamard matrix of corresponding size.

◆ identity() [1/2]

array mlx::core::identity ( int n,
Dtype dtype,
StreamOrDevice s = {} )

Create a square matrix of shape (n,n) of zeros, and ones in the main diagonal.

◆ identity() [2/2]

array mlx::core::identity ( int n,
StreamOrDevice s = {} )
inline

◆ inner()

array mlx::core::inner ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the inner product of two vectors.

◆ isclose()

array mlx::core::isclose ( const array & a,
const array & b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {} )

Returns a boolean array where two arrays are element-wise equal within the specified tolerance.

◆ isfinite()

array mlx::core::isfinite ( const array & a,
StreamOrDevice s = {} )

◆ isinf()

array mlx::core::isinf ( const array & a,
StreamOrDevice s = {} )

◆ isnan()

array mlx::core::isnan ( const array & a,
StreamOrDevice s = {} )

◆ isneginf()

array mlx::core::isneginf ( const array & a,
StreamOrDevice s = {} )

◆ isposinf()

array mlx::core::isposinf ( const array & a,
StreamOrDevice s = {} )

◆ left_shift()

array mlx::core::left_shift ( const array & a,
const array & b,
StreamOrDevice s = {} )

Shift bits to the left.

◆ less()

array mlx::core::less ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a < b) element-wise.

◆ less_equal()

array mlx::core::less_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns bool array with (a <= b) element-wise.

◆ linspace()

array mlx::core::linspace ( double start,
double stop,
int num = 50,
Dtype dtype = float32,
StreamOrDevice s = {} )

A 1D array of num evenly spaced numbers in the range [start, stop]

◆ log()

array mlx::core::log ( const array & a,
StreamOrDevice s = {} )

Natural logarithm of the elements of an array.

◆ log10()

array mlx::core::log10 ( const array & a,
StreamOrDevice s = {} )

Log base 10 of the elements of an array.

◆ log1p()

array mlx::core::log1p ( const array & a,
StreamOrDevice s = {} )

Natural logarithm of one plus elements in the array: log(1 + a).

◆ log2()

array mlx::core::log2 ( const array & a,
StreamOrDevice s = {} )

Log base 2 of the elements of an array.

◆ logaddexp()

array mlx::core::logaddexp ( const array & a,
const array & b,
StreamOrDevice s = {} )

Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).

◆ logical_and()

array mlx::core::logical_and ( const array & a,
const array & b,
StreamOrDevice s = {} )

Logical and of two arrays.

◆ logical_not()

array mlx::core::logical_not ( const array & a,
StreamOrDevice s = {} )

Logical not of an array.

◆ logical_or()

array mlx::core::logical_or ( const array & a,
const array & b,
StreamOrDevice s = {} )

Logical or of two arrays.

◆ logsumexp() [1/4]

array mlx::core::logsumexp ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The logsumexp of all elements of the array.

◆ logsumexp() [2/4]

array mlx::core::logsumexp ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The logsumexp of the elements of an array along the given axes.

◆ logsumexp() [3/4]

array mlx::core::logsumexp ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The logsumexp of the elements of an array along the given axis.

◆ logsumexp() [4/4]

array mlx::core::logsumexp ( const array & a,
StreamOrDevice s = {} )
inline

◆ matmul()

array mlx::core::matmul ( const array & a,
const array & b,
StreamOrDevice s = {} )

Matrix-matrix multiplication.

◆ max() [1/4]

array mlx::core::max ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The maximum of all elements of the array.

◆ max() [2/4]

array mlx::core::max ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The maximum of the elements of an array along the given axes.

◆ max() [3/4]

array mlx::core::max ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The maximum of the elements of an array along the given axis.

◆ max() [4/4]

array mlx::core::max ( const array & a,
StreamOrDevice s = {} )
inline

◆ maximum()

array mlx::core::maximum ( const array & a,
const array & b,
StreamOrDevice s = {} )

Element-wise maximum between two arrays.

◆ mean() [1/4]

array mlx::core::mean ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Computes the mean of the elements of an array.

◆ mean() [2/4]

array mlx::core::mean ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Computes the mean of the elements of an array along the given axes.

◆ mean() [3/4]

array mlx::core::mean ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Computes the mean of the elements of an array along the given axis.

◆ mean() [4/4]

array mlx::core::mean ( const array & a,
StreamOrDevice s = {} )
inline

◆ meshgrid()

std::vector< array > mlx::core::meshgrid ( const std::vector< array > & arrays,
bool sparse = false,
std::string indexing = "xy",
StreamOrDevice s = {} )

A vector of coordinate arrays from coordinate vectors.

◆ min() [1/4]

array mlx::core::min ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The minimum of all elements of the array.

◆ min() [2/4]

array mlx::core::min ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The minimum of the elements of an array along the given axes.

◆ min() [3/4]

array mlx::core::min ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The minimum of the elements of an array along the given axis.

◆ min() [4/4]

array mlx::core::min ( const array & a,
StreamOrDevice s = {} )
inline

◆ minimum()

array mlx::core::minimum ( const array & a,
const array & b,
StreamOrDevice s = {} )

Element-wise minimum between two arrays.

◆ moveaxis()

array mlx::core::moveaxis ( const array & a,
int source,
int destination,
StreamOrDevice s = {} )

Move an axis of an array.

◆ multiply()

array mlx::core::multiply ( const array & a,
const array & b,
StreamOrDevice s = {} )

Multiply two arrays.

◆ nan_to_num()

array mlx::core::nan_to_num ( const array & a,
float nan = 0.0f,
const std::optional< float > & posinf = std::nullopt,
const std::optional< float > & neginf = std::nullopt,
StreamOrDevice s = {} )

Replace NaN and infinities with finite numbers.

◆ negative()

array mlx::core::negative ( const array & a,
StreamOrDevice s = {} )

Negate an array.

◆ not_equal()

array mlx::core::not_equal ( const array & a,
const array & b,
StreamOrDevice s = {} )

Returns the bool array with (a != b) element-wise.

◆ number_of_elements()

array mlx::core::number_of_elements ( const array & a,
std::vector< int > axes,
bool inverted,
Dtype dtype = int32,
StreamOrDevice s = {} )

Extract the number of elements along some axes as a scalar array.

Used to allow shape dependent shapeless compilation (pun intended).

◆ ones() [1/2]

array mlx::core::ones ( const std::vector< int > & shape,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape with ones.

◆ ones() [2/2]

array mlx::core::ones ( const std::vector< int > & shape,
StreamOrDevice s = {} )
inline

◆ ones_like()

array mlx::core::ones_like ( const array & a,
StreamOrDevice s = {} )

◆ operator!=() [1/3]

array mlx::core::operator!= ( const array & a,
const array & b )
inline

◆ operator!=() [2/3]

template<typename T >
array mlx::core::operator!= ( const array & a,
T b )

◆ operator!=() [3/3]

template<typename T >
array mlx::core::operator!= ( T a,
const array & b )

◆ operator%() [1/3]

array mlx::core::operator% ( const array & a,
const array & b )

◆ operator%() [2/3]

template<typename T >
array mlx::core::operator% ( const array & a,
T b )

◆ operator%() [3/3]

template<typename T >
array mlx::core::operator% ( T a,
const array & b )

◆ operator&()

array mlx::core::operator& ( const array & a,
const array & b )

◆ operator&&()

array mlx::core::operator&& ( const array & a,
const array & b )

◆ operator*() [1/3]

array mlx::core::operator* ( const array & a,
const array & b )

◆ operator*() [2/3]

template<typename T >
array mlx::core::operator* ( const array & a,
T b )

◆ operator*() [3/3]

template<typename T >
array mlx::core::operator* ( T a,
const array & b )

◆ operator+() [1/3]

array mlx::core::operator+ ( const array & a,
const array & b )

◆ operator+() [2/3]

template<typename T >
array mlx::core::operator+ ( const array & a,
T b )

◆ operator+() [3/3]

template<typename T >
array mlx::core::operator+ ( T a,
const array & b )

◆ operator-() [1/4]

array mlx::core::operator- ( const array & a)

◆ operator-() [2/4]

array mlx::core::operator- ( const array & a,
const array & b )

◆ operator-() [3/4]

template<typename T >
array mlx::core::operator- ( const array & a,
T b )

◆ operator-() [4/4]

template<typename T >
array mlx::core::operator- ( T a,
const array & b )

◆ operator/() [1/3]

array mlx::core::operator/ ( const array & a,
const array & b )

◆ operator/() [2/3]

array mlx::core::operator/ ( const array & a,
double b )

◆ operator/() [3/3]

array mlx::core::operator/ ( double a,
const array & b )

◆ operator<() [1/3]

array mlx::core::operator< ( const array & a,
const array & b )
inline

◆ operator<() [2/3]

template<typename T >
array mlx::core::operator< ( const array & a,
T b )

◆ operator<() [3/3]

template<typename T >
array mlx::core::operator< ( T a,
const array & b )

◆ operator<<()

array mlx::core::operator<< ( const array & a,
const array & b )

◆ operator<=() [1/3]

array mlx::core::operator<= ( const array & a,
const array & b )
inline

◆ operator<=() [2/3]

template<typename T >
array mlx::core::operator<= ( const array & a,
T b )

◆ operator<=() [3/3]

template<typename T >
array mlx::core::operator<= ( T a,
const array & b )

◆ operator==() [1/3]

array mlx::core::operator== ( const array & a,
const array & b )
inline

◆ operator==() [2/3]

template<typename T >
array mlx::core::operator== ( const array & a,
T b )

◆ operator==() [3/3]

template<typename T >
array mlx::core::operator== ( T a,
const array & b )

◆ operator>() [1/3]

array mlx::core::operator> ( const array & a,
const array & b )
inline

◆ operator>() [2/3]

template<typename T >
array mlx::core::operator> ( const array & a,
T b )

◆ operator>() [3/3]

template<typename T >
array mlx::core::operator> ( T a,
const array & b )

◆ operator>=() [1/3]

array mlx::core::operator>= ( const array & a,
const array & b )
inline

◆ operator>=() [2/3]

template<typename T >
array mlx::core::operator>= ( const array & a,
T b )

◆ operator>=() [3/3]

template<typename T >
array mlx::core::operator>= ( T a,
const array & b )

◆ operator>>()

array mlx::core::operator>> ( const array & a,
const array & b )

◆ operator^()

array mlx::core::operator^ ( const array & a,
const array & b )

◆ operator|()

array mlx::core::operator| ( const array & a,
const array & b )

◆ operator||()

array mlx::core::operator|| ( const array & a,
const array & b )

◆ outer()

array mlx::core::outer ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the outer product of two vectors.

◆ pad() [1/4]

array mlx::core::pad ( const array & a,
const std::pair< int, int > & pad_width,
const array & pad_value = array(0),
const std::string mode = "constant",
StreamOrDevice s = {} )

◆ pad() [2/4]

array mlx::core::pad ( const array & a,
const std::vector< int > & axes,
const std::vector< int > & low_pad_size,
const std::vector< int > & high_pad_size,
const array & pad_value = array(0),
const std::string mode = "constant",
StreamOrDevice s = {} )

Pad an array with a constant value.

◆ pad() [3/4]

array mlx::core::pad ( const array & a,
const std::vector< std::pair< int, int > > & pad_width,
const array & pad_value = array(0),
const std::string mode = "constant",
StreamOrDevice s = {} )

Pad an array with a constant value along all axes.

◆ pad() [4/4]

array mlx::core::pad ( const array & a,
int pad_width,
const array & pad_value = array(0),
const std::string mode = "constant",
StreamOrDevice s = {} )

◆ partition() [1/2]

array mlx::core::partition ( const array & a,
int kth,
int axis,
StreamOrDevice s = {} )

Returns a partitioned copy of the array along a given axis such that the smaller kth elements are first.

◆ partition() [2/2]

array mlx::core::partition ( const array & a,
int kth,
StreamOrDevice s = {} )

Returns a partitioned copy of the flattened array such that the smaller kth elements are first.

◆ power()

array mlx::core::power ( const array & a,
const array & b,
StreamOrDevice s = {} )

Raise elements of a to the power of b element-wise.

◆ prod() [1/4]

array mlx::core::prod ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

The product of all elements of the array.

◆ prod() [2/4]

array mlx::core::prod ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

The product of the elements of an array along the given axes.

◆ prod() [3/4]

array mlx::core::prod ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

The product of the elements of an array along the given axis.

◆ prod() [4/4]

array mlx::core::prod ( const array & a,
StreamOrDevice s = {} )
inline

◆ quantize()

std::tuple< array, array, array > mlx::core::quantize ( const array & w,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {} )

Quantize a matrix along its last axis.

◆ quantized_matmul()

array mlx::core::quantized_matmul ( const array & x,
const array & w,
const array & scales,
const array & biases,
bool transpose = true,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {} )

Quantized matmul multiplies x with a quantized matrix w.

◆ radians()

array mlx::core::radians ( const array & a,
StreamOrDevice s = {} )

Convert the elements of an array from degrees to radians.

◆ reciprocal()

array mlx::core::reciprocal ( const array & a,
StreamOrDevice s = {} )

The reciprocal (1/x) of the elements in an array.

◆ remainder()

array mlx::core::remainder ( const array & a,
const array & b,
StreamOrDevice s = {} )

Compute the element-wise remainder of division.

◆ repeat() [1/2]

array mlx::core::repeat ( const array & arr,
int repeats,
int axis,
StreamOrDevice s = {} )

Repeat an array along an axis.

◆ repeat() [2/2]

array mlx::core::repeat ( const array & arr,
int repeats,
StreamOrDevice s = {} )

◆ reshape()

array mlx::core::reshape ( const array & a,
std::vector< int > shape,
StreamOrDevice s = {} )

Reshape an array to the given shape.

◆ right_shift()

array mlx::core::right_shift ( const array & a,
const array & b,
StreamOrDevice s = {} )

Shift bits to the right.

◆ round() [1/2]

array mlx::core::round ( const array & a,
int decimals,
StreamOrDevice s = {} )

Round a floating point number.

◆ round() [2/2]

array mlx::core::round ( const array & a,
StreamOrDevice s = {} )
inline

◆ rsqrt()

array mlx::core::rsqrt ( const array & a,
StreamOrDevice s = {} )

Reciprocal of the square root of the elements of an array.

◆ scatter() [1/2]

array mlx::core::scatter ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter() [2/2]

array mlx::core::scatter ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter updates to the given indices.

The parameters indices and axes determine the locations of a that are updated with the values in updates. Assuming 1-d indices for simplicity, indices[i] are the indices on axis axes[i] to which the values in updates will be applied. Note each array in indices is assigned to a corresponding axis and hence indices.size() == axes.size(). If an index/axis pair is not provided then indices along that axis are assumed to be zero.

Note the rank of updates must be equal to the sum of the rank of the broadcasted indices and the rank of a. In other words, assuming the arrays in indices have the same shape, updates.ndim() == indices[0].ndim() + a.ndim(). The leading dimensions of updates correspond to the indices, and the remaining a.ndim() dimensions are the values that will be applied to the given location in a.

For example:

auto in = zeros({4, 4}, float32);
auto indices = array({2});
auto updates = reshape(arange(1, 3, float32), {1, 1, 2});
std::vector<int> axes{0};
auto out = scatter(in, {indices}, updates, axes);

will produce:

array([[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 2, 0, 0],
[0, 0, 0, 0]], dtype=float32)

This scatters the two-element row vector [1, 2] starting at the (2, 0) position of a.

Adding another element to indices will scatter into another location of a. We also have to add another update for the new index:

auto in = zeros({4, 4}, float32);
auto indices = array({2, 0});
auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
std::vector<int> axes{0};
auto out = scatter(in, {indices}, updates, axes);

will produce:

array([[3, 4, 0, 0],
[0, 0, 0, 0],
[1, 2, 0, 0],
[0, 0, 0, 0]], dtype=float32)

To control the scatter location on an additional axis, add another index array to indices and another axis to axes:

auto in = zeros({4, 4}, float32);
auto indices = std::vector{array({2, 0}), array({1, 2})};
auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
std::vector<int> axes{0, 1};
auto out = scatter(in, indices, updates, axes);

will produce:

array([[0, 0, 3, 4],
[0, 0, 0, 0],
[0, 1, 2, 0],
[0, 0, 0, 0]], dtype=float32)

Items in indices are broadcasted together. This means:

auto indices = std::vector{array({2, 0}), array({1})};

is equivalent to:

auto indices = std::vector{array({2, 0}), array({1, 1})};

Note, scatter does not perform bounds checking on the indices and updates. Out-of-bounds accesses on a are undefined and typically result in unintended or invalid memory writes.

◆ scatter_add() [1/2]

array mlx::core::scatter_add ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_add() [2/2]

array mlx::core::scatter_add ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter and add updates to given indices.

◆ scatter_max() [1/2]

array mlx::core::scatter_max ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_max() [2/2]

array mlx::core::scatter_max ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter updates to the given indices, combining them with the existing values using max.

◆ scatter_min() [1/2]

array mlx::core::scatter_min ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_min() [2/2]

array mlx::core::scatter_min ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter updates to the given indices, combining them with the existing values using min.

◆ scatter_prod() [1/2]

array mlx::core::scatter_prod ( const array & a,
const array & indices,
const array & updates,
int axis,
StreamOrDevice s = {} )
inline

◆ scatter_prod() [2/2]

array mlx::core::scatter_prod ( const array & a,
const std::vector< array > & indices,
const array & updates,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Scatter and prod updates to given indices.

◆ sigmoid()

array mlx::core::sigmoid ( const array & a,
StreamOrDevice s = {} )

Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).

◆ sign()

array mlx::core::sign ( const array & a,
StreamOrDevice s = {} )

The sign of the elements in an array.

◆ sin()

array mlx::core::sin ( const array & a,
StreamOrDevice s = {} )

Sine of the elements of an array.

◆ sinh()

array mlx::core::sinh ( const array & a,
StreamOrDevice s = {} )

Hyperbolic sine of the elements of an array.

◆ slice() [1/2]

array mlx::core::slice ( const array & a,
const std::vector< int > & start,
const std::vector< int > & stop,
StreamOrDevice s = {} )

Slice an array with a stride of 1 in each dimension.

◆ slice() [2/2]

array mlx::core::slice ( const array & a,
std::vector< int > start,
std::vector< int > stop,
std::vector< int > strides,
StreamOrDevice s = {} )

Slice an array.

◆ slice_update() [1/2]

array mlx::core::slice_update ( const array & src,
const array & update,
std::vector< int > start,
std::vector< int > stop,
std::vector< int > strides,
StreamOrDevice s = {} )

Update a slice from the source array.

◆ slice_update() [2/2]

array mlx::core::slice_update ( const array & src,
const array & update,
std::vector< int > start,
std::vector< int > stop,
StreamOrDevice s = {} )

Update a slice from the source array with stride 1 in each dimension.

◆ softmax() [1/3]

array mlx::core::softmax ( const array & a,
bool precise = false,
StreamOrDevice s = {} )

Softmax of an array.

◆ softmax() [2/3]

array mlx::core::softmax ( const array & a,
const std::vector< int > & axes,
bool precise = false,
StreamOrDevice s = {} )

Softmax of an array.

◆ softmax() [3/3]

array mlx::core::softmax ( const array & a,
int axis,
bool precise = false,
StreamOrDevice s = {} )
inline

Softmax of an array.

◆ sort() [1/2]

array mlx::core::sort ( const array & a,
int axis,
StreamOrDevice s = {} )

Returns a sorted copy of the array along a given axis.

◆ sort() [2/2]

array mlx::core::sort ( const array & a,
StreamOrDevice s = {} )

Returns a sorted copy of the flattened array.

◆ split() [1/4]

std::vector< array > mlx::core::split ( const array & a,
const std::vector< int > & indices,
int axis,
StreamOrDevice s = {} )

◆ split() [2/4]

std::vector< array > mlx::core::split ( const array & a,
const std::vector< int > & indices,
StreamOrDevice s = {} )

◆ split() [3/4]

std::vector< array > mlx::core::split ( const array & a,
int num_splits,
int axis,
StreamOrDevice s = {} )

Split an array into sub-arrays along a given axis.

◆ split() [4/4]

std::vector< array > mlx::core::split ( const array & a,
int num_splits,
StreamOrDevice s = {} )

◆ sqrt()

array mlx::core::sqrt ( const array & a,
StreamOrDevice s = {} )

Square root the elements of an array.

◆ square()

array mlx::core::square ( const array & a,
StreamOrDevice s = {} )

Square the elements of an array.

◆ squeeze() [1/3]

array mlx::core::squeeze ( const array & a,
const std::vector< int > & axes,
StreamOrDevice s = {} )

Remove singleton dimensions at the given axes.

◆ squeeze() [2/3]

array mlx::core::squeeze ( const array & a,
int axis,
StreamOrDevice s = {} )
inline

Remove singleton dimensions at the given axis.

◆ squeeze() [3/3]

array mlx::core::squeeze ( const array & a,
StreamOrDevice s = {} )

Remove all singleton dimensions.

◆ stack() [1/2]

array mlx::core::stack ( const std::vector< array > & arrays,
int axis,
StreamOrDevice s = {} )

Stack arrays along a new axis.

◆ stack() [2/2]

array mlx::core::stack ( const std::vector< array > & arrays,
StreamOrDevice s = {} )

◆ std() [1/4]

array mlx::core::std ( const array & a,
bool keepdims,
int ddof = 0,
StreamOrDevice s = {} )

Computes the standard deviation of the elements of an array.

◆ std() [2/4]

array mlx::core::std ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the standard deviation of the elements of an array along the given axes.

◆ std() [3/4]

array mlx::core::std ( const array & a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the standard deviation of the elements of an array along the given axis.

◆ std() [4/4]

array mlx::core::std ( const array & a,
StreamOrDevice s = {} )
inline

◆ stop_gradient()

array mlx::core::stop_gradient ( const array & a,
StreamOrDevice s = {} )

Stop the flow of gradients.

◆ subtract()

array mlx::core::subtract ( const array & a,
const array & b,
StreamOrDevice s = {} )

Subtract two arrays.

◆ sum() [1/4]

array mlx::core::sum ( const array & a,
bool keepdims,
StreamOrDevice s = {} )

Sums the elements of an array.

◆ sum() [2/4]

array mlx::core::sum ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
StreamOrDevice s = {} )

Sums the elements of an array along the given axes.

◆ sum() [3/4]

array mlx::core::sum ( const array & a,
int axis,
bool keepdims = false,
StreamOrDevice s = {} )

Sums the elements of an array along the given axis.

◆ sum() [4/4]

array mlx::core::sum ( const array & a,
StreamOrDevice s = {} )
inline

◆ swapaxes()

array mlx::core::swapaxes ( const array & a,
int axis1,
int axis2,
StreamOrDevice s = {} )

Swap two axes of an array.

◆ take() [1/2]

array mlx::core::take ( const array & a,
const array & indices,
int axis,
StreamOrDevice s = {} )

Take array slices at the given indices of the specified axis.

◆ take() [2/2]

array mlx::core::take ( const array & a,
const array & indices,
StreamOrDevice s = {} )

Take array entries at the given indices treating the array as flattened.

◆ take_along_axis()

array mlx::core::take_along_axis ( const array & a,
const array & indices,
int axis,
StreamOrDevice s = {} )

Take array entries given indices along the axis.

◆ tan()

array mlx::core::tan ( const array & a,
StreamOrDevice s = {} )

Tangent of the elements of an array.

◆ tanh()

array mlx::core::tanh ( const array & a,
StreamOrDevice s = {} )

Hyperbolic tangent of the elements of an array.

◆ tensordot() [1/2]

array mlx::core::tensordot ( const array & a,
const array & b,
const int axis = 2,
StreamOrDevice s = {} )

Returns a contraction of a and b over multiple dimensions.

◆ tensordot() [2/2]

array mlx::core::tensordot ( const array & a,
const array & b,
const std::vector< int > & axes_a,
const std::vector< int > & axes_b,
StreamOrDevice s = {} )

◆ tile()

array mlx::core::tile ( const array & arr,
std::vector< int > reps,
StreamOrDevice s = {} )

◆ topk() [1/2]

array mlx::core::topk ( const array & a,
int k,
int axis,
StreamOrDevice s = {} )

Returns the top k elements of the array along a given axis.

◆ topk() [2/2]

array mlx::core::topk ( const array & a,
int k,
StreamOrDevice s = {} )

Returns the top k elements of the flattened array.

◆ trace() [1/3]

array mlx::core::trace ( const array & a,
int offset,
int axis1,
int axis2,
Dtype dtype,
StreamOrDevice s = {} )

Return the sum along a specified diagonal in the given array.

◆ trace() [2/3]

array mlx::core::trace ( const array & a,
int offset,
int axis1,
int axis2,
StreamOrDevice s = {} )

◆ trace() [3/3]

array mlx::core::trace ( const array & a,
StreamOrDevice s = {} )

◆ transpose() [1/3]

array mlx::core::transpose ( const array & a,
std::initializer_list< int > axes,
StreamOrDevice s = {} )
inline

◆ transpose() [2/3]

array mlx::core::transpose ( const array & a,
std::vector< int > axes,
StreamOrDevice s = {} )

Permutes the dimensions according to the given axes.

◆ transpose() [3/3]

array mlx::core::transpose ( const array & a,
StreamOrDevice s = {} )

Permutes the dimensions in reverse order.

◆ tri() [1/2]

array mlx::core::tri ( int n,
Dtype type,
StreamOrDevice s = {} )
inline

◆ tri() [2/2]

array mlx::core::tri ( int n,
int m,
int k,
Dtype type,
StreamOrDevice s = {} )

◆ tril()

array mlx::core::tril ( array x,
int k = 0,
StreamOrDevice s = {} )

◆ triu()

array mlx::core::triu ( array x,
int k = 0,
StreamOrDevice s = {} )

◆ var() [1/4]

array mlx::core::var ( const array & a,
bool keepdims,
int ddof = 0,
StreamOrDevice s = {} )

Computes the variance of the elements of an array.

◆ var() [2/4]

array mlx::core::var ( const array & a,
const std::vector< int > & axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the variance of the elements of an array along the given axes.

◆ var() [3/4]

array mlx::core::var ( const array & a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {} )

Computes the variance of the elements of an array along the given axis.

◆ var() [4/4]

array mlx::core::var ( const array & a,
StreamOrDevice s = {} )
inline

◆ view()

array mlx::core::view ( const array & a,
const Dtype & dtype,
StreamOrDevice s = {} )

◆ where()

array mlx::core::where ( const array & condition,
const array & x,
const array & y,
StreamOrDevice s = {} )

Select from x or y depending on condition.

◆ zeros() [1/2]

array mlx::core::zeros ( const std::vector< int > & shape,
Dtype dtype,
StreamOrDevice s = {} )

Fill an array of the given shape with zeros.

◆ zeros() [2/2]

array mlx::core::zeros ( const std::vector< int > & shape,
StreamOrDevice s = {} )
inline

◆ zeros_like()

array mlx::core::zeros_like ( const array & a,
StreamOrDevice s = {} )