mlx/mlx/ops.h
Cyril Zakka, MD e080290ba4
Added eye/identity ops (#119)
`eye` and `identity` C++ and Python ops
2023-12-11 12:38:17 -08:00

970 lines
28 KiB
C++

// Copyright © 2023 Apple Inc.
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "load.h"
#include "stream.h"
namespace mlx::core {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
Stream to_stream(StreamOrDevice s);
/** Creation operations */
/**
* A 1D array of numbers starting at `start` (optional),
* stopping at stop, stepping by `step` (optional). **/
array arange(
double start,
double stop,
double step,
Dtype dtype,
StreamOrDevice s = {});
array arange(double start, double stop, double step, StreamOrDevice s = {});
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double start, double stop, StreamOrDevice s = {});
array arange(double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double stop, StreamOrDevice s = {});
array arange(int start, int stop, int step, StreamOrDevice s = {});
array arange(int start, int stop, StreamOrDevice s = {});
array arange(int stop, StreamOrDevice s = {});
/** Convert an array to the given data type. */
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
/** Create a view of an array with the given shape and strides. */
array as_strided(
const array& a,
std::vector<int> shape,
std::vector<size_t> strides,
size_t offset,
StreamOrDevice s = {});
/** Copy another array. */
array copy(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape with the given value(s). */
array full(
const std::vector<int>& shape,
const array& vals,
Dtype dtype,
StreamOrDevice s = {});
array full(
const std::vector<int>& shape,
const array& vals,
StreamOrDevice s = {});
template <typename T>
array full(
const std::vector<int>& shape,
T val,
Dtype dtype,
StreamOrDevice s = {}) {
return full(shape, array(val, dtype), to_stream(s));
}
template <typename T>
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
return full(shape, array(val), to_stream(s));
}
/** Fill an array of the given shape with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
return zeros(shape, float32, s);
}
array zeros_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
return ones(shape, float32, s);
}
array ones_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape (n,m) with ones in the specified diagonal
* k, and zeros everywhere else. */
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
return eye(n, n, 0, dtype, s);
}
inline array eye(int n, int m, StreamOrDevice s = {}) {
return eye(n, m, 0, float32, s);
}
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
return eye(n, m, k, float32, s);
}
inline array eye(int n, StreamOrDevice s = {}) {
return eye(n, n, 0, float32, s);
}
/** Create a square matrix of shape (n,n) of zeros, and ones in the major
* diagonal. */
array identity(int n, Dtype dtype, StreamOrDevice s = {});
inline array identity(int n, StreamOrDevice s = {}) {
return identity(n, float32, s);
}
/** array manipulation */
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
/** Remove singleton dimensions at the given axes. */
array squeeze(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** Remove singleton dimensions at the given axis. */
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
return squeeze(a, std::vector<int>{axis}, s);
}
/** Remove all singleton dimensions. */
array squeeze(const array& a, StreamOrDevice s = {});
/** Add a singleton dimension at the given axes. */
array expand_dims(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** Add a singleton dimension at the given axis. */
inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
return expand_dims(a, std::vector<int>{axis}, s);
}
/** Slice an array. */
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
StreamOrDevice s = {});
/** Slice an array with a stride of 1 in each dimension. */
array slice(
const array& a,
const std::vector<int>& start,
const std::vector<int>& stop,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
int axis,
StreamOrDevice s = {});
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
/** Concatenate arrays along a given axis. */
array concatenate(
const std::vector<array>& arrays,
int axis,
StreamOrDevice s = {});
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
/** Permutes the dimensions according to the given axes. */
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
inline array transpose(
const array& a,
std::initializer_list<int> axes,
StreamOrDevice s = {}) {
return transpose(a, std::vector<int>(axes), s);
}
/** Pad an array with a constant value */
array pad(
const array& a,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size,
const array& pad_value = array(0),
StreamOrDevice s = {});
/** Pad an array with a constant value along all axes */
array pad(
const array& a,
const std::vector<std::pair<int, int>>& pad_width,
const array& pad_value = array(0),
StreamOrDevice s = {});
array pad(
const array& a,
const std::pair<int, int>& pad_width,
const array& pad_value = array(0),
StreamOrDevice s = {});
array pad(
const array& a,
int pad_width,
const array& pad_value = array(0),
StreamOrDevice s = {});
/** Permutes the dimensions in reverse order. */
array transpose(const array& a, StreamOrDevice s = {});
/** Broadcast an array to a given shape. */
array broadcast_to(
const array& a,
const std::vector<int>& shape,
StreamOrDevice s = {});
/** Broadcast a vector of arrays against one another. */
std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
StreamOrDevice s = {});
/** Comparison operations */
/** Returns the bool array with (a == b) element-wise. */
array equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator==(const array& a, const array& b) {
return equal(a, b);
}
template <typename T>
array operator==(T a, const array& b) {
return equal(array(a), b);
}
template <typename T>
array operator==(const array& a, T b) {
return equal(a, array(b));
}
/** Returns the bool array with (a != b) element-wise. */
array not_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator!=(const array& a, const array& b) {
return not_equal(a, b);
}
template <typename T>
array operator!=(T a, const array& b) {
return not_equal(array(a), b);
}
template <typename T>
array operator!=(const array& a, T b) {
return not_equal(a, array(b));
}
/** Returns bool array with (a > b) element-wise. */
array greater(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>(const array& a, const array& b) {
return greater(a, b);
}
template <typename T>
array operator>(T a, const array& b) {
return greater(array(a), b);
}
template <typename T>
array operator>(const array& a, T b) {
return greater(a, array(b));
}
/** Returns bool array with (a >= b) element-wise. */
array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>=(const array& a, const array& b) {
return greater_equal(a, b);
}
template <typename T>
array operator>=(T a, const array& b) {
return greater_equal(array(a), b);
}
template <typename T>
array operator>=(const array& a, T b) {
return greater_equal(a, array(b));
}
/** Returns bool array with (a < b) element-wise. */
array less(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<(const array& a, const array& b) {
return less(a, b);
}
template <typename T>
array operator<(T a, const array& b) {
return less(array(a), b);
}
template <typename T>
array operator<(const array& a, T b) {
return less(a, array(b));
}
/** Returns bool array with (a <= b) element-wise. */
array less_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<=(const array& a, const array& b) {
return less_equal(a, b);
}
template <typename T>
array operator<=(T a, const array& b) {
return less_equal(array(a), b);
}
template <typename T>
array operator<=(const array& a, T b) {
return less_equal(a, array(b));
}
/** True if two arrays have the same shape and elements. */
array array_equal(
const array& a,
const array& b,
bool equal_nan,
StreamOrDevice s = {});
inline array
array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
return array_equal(a, b, false, s);
}
/** Select from x or y depending on condition. */
array where(
const array& condition,
const array& x,
const array& y,
StreamOrDevice s = {});
/** Reduction operations */
/** True if all elements in the array are true (or non-zero). **/
array all(const array& a, bool keepdims, StreamOrDevice s = {});
inline array all(const array& a, StreamOrDevice s = {}) {
return all(a, false, to_stream(s));
}
/** True if the two arrays are equal within the specified tolerance. */
array allclose(
const array& a,
const array& b,
double rtol = 1e-5,
double atol = 1e-8,
StreamOrDevice s = {});
/**
* Reduces the input along the given axes. An output value is true
* if all the corresponding inputs are true.
**/
array all(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/**
* Reduces the input along the given axis. An output value is true
* if all the corresponding inputs are true.
**/
array all(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** True if any elements in the array are true (or non-zero). **/
array any(const array& a, bool keepdims, StreamOrDevice s = {});
inline array any(const array& a, StreamOrDevice s = {}) {
return any(a, false, to_stream(s));
}
/**
* Reduces the input along the given axes. An output value is true
* if any of the corresponding inputs are true.
**/
array any(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/**
* Reduces the input along the given axis. An output value is true
* if any of the corresponding inputs are true.
**/
array any(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Sums the elements of an array. */
array sum(const array& a, bool keepdims, StreamOrDevice s = {});
inline array sum(const array& a, StreamOrDevice s = {}) {
return sum(a, false, to_stream(s));
}
/** Sums the elements of an array along the given axes. */
array sum(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** Sums the elements of an array along the given axis. */
array sum(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array. */
array mean(const array& a, bool keepdims, StreamOrDevice s = {});
inline array mean(const array& a, StreamOrDevice s = {}) {
return mean(a, false, to_stream(s));
}
/** Computes the mean of the elements of an array along the given axes */
array mean(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array along the given axis */
array mean(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Computes the mean of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {
return var(a, false, 0, to_stream(s));
}
/** Computes the var of the elements of an array along the given axes */
array var(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** Computes the var of the elements of an array along the given axis */
array var(
const array& a,
int axis,
bool keepdims = false,
int ddof = 0,
StreamOrDevice s = {});
/** The product of all elements of the array. */
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
inline array prod(const array& a, StreamOrDevice s = {}) {
return prod(a, false, to_stream(s));
}
/** The product of the elements of an array along the given axes. */
array prod(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The product of the elements of an array along the given axis. */
array prod(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** The maximum of all elements of the array. */
array max(const array& a, bool keepdims, StreamOrDevice s = {});
inline array max(const array& a, StreamOrDevice s = {}) {
return max(a, false, to_stream(s));
}
/** The maximum of the elements of an array along the given axes. */
array max(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The maximum of the elements of an array along the given axis. */
array max(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** The minimum of all elements of the array. */
array min(const array& a, bool keepdims, StreamOrDevice s = {});
inline array min(const array& a, StreamOrDevice s = {}) {
return min(a, false, to_stream(s));
}
/** The minimum of the elements of an array along the given axes. */
array min(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The minimum of the elements of an array along the given axis. */
array min(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Returns the index of the minimum value in the array. */
array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
inline array argmin(const array& a, StreamOrDevice s = {}) {
return argmin(a, false, s);
}
/** Returns the indices of the minimum values along a given axis. */
array argmin(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Returns the index of the maximum value in the array. */
array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
inline array argmax(const array& a, StreamOrDevice s = {}) {
return argmax(a, false, s);
}
/** Returns the indices of the maximum values along a given axis. */
array argmax(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Returns a sorted copy of the flattened array. */
array sort(const array& a, StreamOrDevice s = {});
/** Returns a sorted copy of the array along a given axis. */
array sort(const array& a, int axis, StreamOrDevice s = {});
/** Returns indices that sort the flattened array. */
array argsort(const array& a, StreamOrDevice s = {});
/** Returns indices that sort the array along a given axis. */
array argsort(const array& a, int axis, StreamOrDevice s = {});
/**
* Returns a partitioned copy of the flattened array
* such that the smaller kth elements are first.
**/
array partition(const array& a, int kth, StreamOrDevice s = {});
/**
* Returns a partitioned copy of the array along a given axis
* such that the smaller kth elements are first.
**/
array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
/**
* Returns indices that partition the flattened array
* such that the smaller kth elements are first.
**/
array argpartition(const array& a, int kth, StreamOrDevice s = {});
/**
* Returns indices that partition the array along a given axis
* such that the smaller kth elements are first.
**/
array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
/** Returns topk elements of the flattened array. */
array topk(const array& a, int k, StreamOrDevice s = {});
/** Returns topk elements of the array along a given axis. */
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
/** The logsumexp of all elements of the array. */
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
return logsumexp(a, false, to_stream(s));
}
/** The logsumexp of the elements of an array along the given axes. */
array logsumexp(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
StreamOrDevice s = {});
/** The logsumexp of the elements of an array along the given axis. */
array logsumexp(
const array& a,
int axis,
bool keepdims = false,
StreamOrDevice s = {});
/** Simple arithmetic operations */
/** Absolute value of elements in an array. */
array abs(const array& a, StreamOrDevice s = {});
/** Negate an array. */
array negative(const array& a, StreamOrDevice s = {});
array operator-(const array& a);
/** The sign of the elements in an array. */
array sign(const array& a, StreamOrDevice s = {});
/** Logical not of an array */
array logical_not(const array& a, StreamOrDevice s = {});
/** The reciprocal (1/x) of the elements in an array. */
array reciprocal(const array& a, StreamOrDevice s = {});
/** Add two arrays. */
array add(const array& a, const array& b, StreamOrDevice s = {});
array operator+(const array& a, const array& b);
template <typename T>
array operator+(T a, const array& b) {
return add(array(a), b);
}
template <typename T>
array operator+(const array& a, T b) {
return add(a, array(b));
}
/** Subtract two arrays. */
array subtract(const array& a, const array& b, StreamOrDevice s = {});
array operator-(const array& a, const array& b);
template <typename T>
array operator-(T a, const array& b) {
return subtract(array(a), b);
}
template <typename T>
array operator-(const array& a, T b) {
return subtract(a, array(b));
}
/** Multiply two arrays. */
array multiply(const array& a, const array& b, StreamOrDevice s = {});
array operator*(const array& a, const array& b);
template <typename T>
array operator*(T a, const array& b) {
return multiply(array(a), b);
}
template <typename T>
array operator*(const array& a, T b) {
return multiply(a, array(b));
}
/** Divide two arrays. */
array divide(const array& a, const array& b, StreamOrDevice s = {});
array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);
/** Compute the element-wise remainder of division */
array remainder(const array& a, const array& b, StreamOrDevice s = {});
array operator%(const array& a, const array& b);
template <typename T>
array operator%(T a, const array& b) {
return remainder(array(a), b);
}
template <typename T>
array operator%(const array& a, T b) {
return remainder(a, array(b));
}
/** Element-wise maximum between two arrays. */
array maximum(const array& a, const array& b, StreamOrDevice s = {});
/** Element-wise minimum between two arrays. */
array minimum(const array& a, const array& b, StreamOrDevice s = {});
/** Square the elements of an array. */
array square(const array& a, StreamOrDevice s = {});
/** Exponential of the elements of an array. */
array exp(const array& a, StreamOrDevice s = {});
/** Sine of the elements of an array */
array sin(const array& a, StreamOrDevice s = {});
/** Cosine of the elements of an array */
array cos(const array& a, StreamOrDevice s = {});
/** Tangent of the elements of an array */
array tan(const array& a, StreamOrDevice s = {});
/** Arc Sine of the elements of an array */
array arcsin(const array& a, StreamOrDevice s = {});
/** Arc Cosine of the elements of an array */
array arccos(const array& a, StreamOrDevice s = {});
/** Arc Tangent of the elements of an array */
array arctan(const array& a, StreamOrDevice s = {});
/** Hyperbolic Sine of the elements of an array */
array sinh(const array& a, StreamOrDevice s = {});
/** Hyperbolic Cosine of the elements of an array */
array cosh(const array& a, StreamOrDevice s = {});
/** Hyperbolic Tangent of the elements of an array */
array tanh(const array& a, StreamOrDevice s = {});
/** Inverse Hyperbolic Sine of the elements of an array */
array arcsinh(const array& a, StreamOrDevice s = {});
/** Inverse Hyperbolic Cosine of the elements of an array */
array arccosh(const array& a, StreamOrDevice s = {});
/** Inverse Hyperbolic Tangent of the elements of an array */
array arctanh(const array& a, StreamOrDevice s = {});
/** Natural logarithm of the elements of an array. */
array log(const array& a, StreamOrDevice s = {});
/** Log base 2 of the elements of an array. */
array log2(const array& a, StreamOrDevice s = {});
/** Log base 10 of the elements of an array. */
array log10(const array& a, StreamOrDevice s = {});
/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */
array log1p(const array& a, StreamOrDevice s = {});
/** Log-add-exp of one elements in the array: `log(exp(a) + exp(b))`. */
array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x)`. */
array sigmoid(const array& a, StreamOrDevice s = {});
/** Computes the error function of the elements of an array. */
array erf(const array& a, StreamOrDevice s = {});
/** Computes the inverse error function of the elements of an array. */
array erfinv(const array& a, StreamOrDevice s = {});
/** Stop the flow of gradients. */
array stop_gradient(const array& a, StreamOrDevice s = {});
/** Matrix-matrix multiplication. */
array matmul(const array& a, const array& b, StreamOrDevice s = {});
/** Gather array entries given indices and slices */
array gather(
const array& a,
const std::vector<array>& indices,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes,
StreamOrDevice s = {});
inline array gather(
const array& a,
const array& indices,
int axis,
const std::vector<int>& slice_sizes,
StreamOrDevice s = {}) {
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
}
/** Take array slices at the given indices of the specified axis. */
array take(
const array& a,
const array& indices,
int axis,
StreamOrDevice s = {});
/** Take array entries at the given indices treating the array as flattened. */
array take(const array& a, const array& indices, StreamOrDevice s = {});
/** Take array entries given indices along the axis */
array take_along_axis(
const array& a,
const array& indices,
int axis,
StreamOrDevice s = {});
/** Scatter updates to given linear indices */
array scatter(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and add updates to given indices */
array scatter_add(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_add(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and prod updates to given indices */
array scatter_prod(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_prod(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and max updates to given linear indices */
array scatter_max(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_max(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and min updates to given linear indices */
array scatter_min(
const array& a,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array scatter_min(
const array& a,
const array& indices,
const array& updates,
int axis,
StreamOrDevice s = {}) {
return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Square root the elements of an array. */
array sqrt(const array& a, StreamOrDevice s = {});
/** Square root and reciprocal the elements of an array. */
array rsqrt(const array& a, StreamOrDevice s = {});
/** Softmax of an array. */
array softmax(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** Softmax of an array. */
array softmax(const array& a, StreamOrDevice s = {});
/** Softmax of an array. */
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
return softmax(a, std::vector<int>{axis}, s);
}
/** Raise elements of a to the power of b element-wise */
array power(const array& a, const array& b, StreamOrDevice s = {});
inline array operator^(const array& a, const array& b) {
return power(a, b);
}
template <typename T>
array operator^(T a, const array& b) {
return power(array(a), b);
}
template <typename T>
array operator^(const array& a, T b) {
return power(a, array(b));
}
/** Cumulative sum of an array. */
array cumsum(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative product of an array. */
array cumprod(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative max of an array. */
array cummax(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Cumulative min of an array. */
array cummin(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
StreamOrDevice s = {});
/** Convolution operations */
/** 1D convolution with a filter */
array conv1d(
const array& input,
const array& weight,
int stride = 1,
int padding = 0,
int dilation = 1,
int groups = 1,
StreamOrDevice s = {});
/** 2D convolution with a filter */
array conv2d(
const array& input,
const array& weight,
const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1},
int groups = 1,
StreamOrDevice s = {});
/** Serialization operations */
/** Save array to out stream in .npy format */
void save(
std::shared_ptr<io::Writer> out_stream,
array a,
bool retain_graph = true);
/** Save array to file in .npy format */
void save(const std::string& file, array a, bool retain_graph = true);
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});
} // namespace mlx::core