52    std::vector<int> shape,
 
   53    std::vector<size_t> strides,
 
   62    std::vector<int> shape,
 
   69  return full(std::move(shape), array(val, dtype), 
to_stream(s));
 
 
   94  return eye(n, n, 0, dtype, s);
 
 
  115  return tri(n, n, 0, type, s);
 
 
  137    std::optional<float> scale = std::nullopt,
 
  143    const std::vector<int>& axes,
 
  155    const std::vector<int>& axes,
 
  164    std::vector<int> start,
 
  165    std::vector<int> stop,
 
  166    std::vector<int> strides,
 
  172    std::vector<int> start,
 
  173    std::vector<int> stop,
 
  180    std::vector<int> start,
 
  181    std::vector<int> stop,
 
  182    std::vector<int> strides,
 
  189    std::vector<int> start,
 
  190    std::vector<int> stop,
 
  199    const std::vector<int>& indices,
 
  207    const std::vector<array>& arrays,
 
  209    std::string indexing = 
"xy",
 
  217    const std::optional<array>& a_min = std::nullopt,
 
  218    const std::optional<array>& a_max = std::nullopt,
 
  223    const std::vector<array>& arrays,
 
  242    std::initializer_list<int> axes,
 
  244  return transpose(a, std::vector<int>(axes), s);
 
 
  260    const std::vector<int>& axes,
 
  261    const std::vector<int>& low_pad_size,
 
  262    const std::vector<int>& high_pad_size,
 
  264    const std::string mode = 
"constant",
 
  270    const std::vector<std::pair<int, int>>& pad_width,
 
  272    const std::string mode = 
"constant",
 
  276    const std::pair<int, int>& pad_width,
 
  278    const std::string mode = 
"constant",
 
  284    const std::string mode = 
"constant",
 
  293    const std::vector<int>& shape,
 
  298    const std::vector<array>& inputs,
 
  408    const array& condition,
 
  417    const std::optional<float> posinf = std::nullopt,
 
  418    const std::optional<float> neginf = std::nullopt,
 
  433    bool equal_nan = 
false,
 
  443    bool equal_nan = 
false,
 
  452    const std::vector<int>& axes,
 
  453    bool keepdims = 
false,
 
  463    bool keepdims = 
false,
 
  478    const std::vector<int>& axes,
 
  479    bool keepdims = 
false,
 
  489    bool keepdims = 
false,
 
  501    const std::vector<int>& axes,
 
  502    bool keepdims = 
false,
 
  509    bool keepdims = 
false,
 
  521    const std::vector<int>& axes,
 
  522    bool keepdims = 
false,
 
  529    bool keepdims = 
false,
 
  542    const std::vector<int>& axes,
 
  543    bool keepdims = 
false,
 
  552    bool keepdims = 
false,
 
  566    const std::vector<int>& axes,
 
  567    bool keepdims = 
false,
 
  576    bool keepdims = 
false,
 
  589    const std::vector<int>& axes,
 
  590    bool keepdims = 
false,
 
  597    bool keepdims = 
false,
 
  609    const std::vector<int>& axes,
 
  610    bool keepdims = 
false,
 
  617    bool keepdims = 
false,
 
  629    const std::vector<int>& axes,
 
  630    bool keepdims = 
false,
 
  637    bool keepdims = 
false,
 
  643  return argmin(a, 
false, s);
 
 
  650    bool keepdims = 
false,
 
  656  return argmax(a, 
false, s);
 
 
  663    bool keepdims = 
false,
 
  717    const std::vector<int>& axes,
 
  718    bool keepdims = 
false,
 
  725    bool keepdims = 
false,
 
  909  return round(a, 0, s);
 
 
  918    const std::vector<array>& indices,
 
  919    const std::vector<int>& axes,
 
  920    const std::vector<int>& slice_sizes,
 
  924    const array& indices,
 
  926    const std::vector<int>& slice_sizes,
 
  928  return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
 
 
  934    const array& indices,
 
  946    const array& indices,
 
  953    const array& indices,
 
 1058    const std::vector<array>& indices,
 
 1059    const array& updates,
 
 1060    const std::vector<int>& axes,
 
 1064    const array& indices,
 
 1065    const array& updates,
 
 1068  return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
 
 
 1074    const std::vector<array>& indices,
 
 1075    const array& updates,
 
 1076    const std::vector<int>& axes,
 
 1080    const array& indices,
 
 1081    const array& updates,
 
 1084  return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
 
 
 1090    const std::vector<array>& indices,
 
 1091    const array& updates,
 
 1092    const std::vector<int>& axes,
 
 1096    const array& indices,
 
 1097    const array& updates,
 
 1100  return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
 
 
 1106    const std::vector<array>& indices,
 
 1107    const array& updates,
 
 1108    const std::vector<int>& axes,
 
 1112    const array& indices,
 
 1113    const array& updates,
 
 1116  return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
 
 
 1121    const std::vector<array>& indices,
 
 1122    const array& updates,
 
 1123    const std::vector<int>& axes,
 
 1127    const array& indices,
 
 1128    const array& updates,
 
 1131  return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
 
 
 1143    const std::vector<int>& axes,
 
 1144    bool precise = 
false,
 
 1153  return softmax(a, std::vector<int>{axis}, precise, s);
 
 
 1163    bool reverse = 
false,
 
 1164    bool inclusive = 
true,
 
 1171    bool reverse = 
false,
 
 1172    bool inclusive = 
true,
 
 1179    bool reverse = 
false,
 
 1180    bool inclusive = 
true,
 
 1187    bool reverse = 
false,
 
 1188    bool inclusive = 
true,
 
 1195    std::vector<int> stride = {},
 
 1196    std::vector<int> padding_lo = {},
 
 1197    std::vector<int> padding_hi = {},
 
 1198    std::vector<int> kernel_dilation = {},
 
 1199    std::vector<int> input_dilation = {},
 
 1207    const array& weight,
 
 1208    std::vector<int> stride = {},
 
 1209    std::vector<int> padding = {},
 
 1210    std::vector<int> kernel_dilation = {},
 
 1211    std::vector<int> input_dilation = {},
 
 
 1231    const array& weight,
 
 1241    const array& weight,
 
 1242    const std::pair<int, int>& stride = {1, 1},
 
 1243    const std::pair<int, int>& padding = {0, 0},
 
 1244    const std::pair<int, int>& dilation = {1, 1},
 
 1251    const array& weight,
 
 1252    const std::tuple<int, int, int>& stride = {1, 1, 1},
 
 1253    const std::tuple<int, int, int>& padding = {0, 0, 0},
 
 1254    const std::tuple<int, int, int>& dilation = {1, 1, 1},
 
 1261    const array& weight,
 
 1271    const array& weight,
 
 1272    const std::pair<int, int>& stride = {1, 1},
 
 1273    const std::pair<int, int>& padding = {0, 0},
 
 1274    const std::pair<int, int>& dilation = {1, 1},
 
 1281    const array& weight,
 
 1282    const std::tuple<int, int, int>& stride = {1, 1, 1},
 
 1283    const std::tuple<int, int, int>& padding = {0, 0, 0},
 
 1284    const std::tuple<int, int, int>& dilation = {1, 1, 1},
 
 1295    int group_size = 64,
 
 1302    int group_size = 64,
 
 1309    const array& scales,
 
 1310    const array& biases,
 
 1311    int group_size = 64,
 
 1319    const array& scales,
 
 1320    const array& biases,
 
 1321    std::optional<array> lhs_indices = std::nullopt,
 
 1322    std::optional<array> rhs_indices = std::nullopt,
 
 1324    int group_size = 64,
 
 1338    const std::vector<int>& axes_a,
 
 1339    const std::vector<int>& axes_b,
 
 1353    const float& alpha = 1.f,
 
 1354    const float& beta = 1.f,
 
 1362    std::optional<array> mask_out = std::nullopt,
 
 1363    std::optional<array> mask_lhs = std::nullopt,
 
 1364    std::optional<array> mask_rhs = std::nullopt,
 
 1371    std::optional<array> lhs_indices = std::nullopt,
 
 1372    std::optional<array> rhs_indices = std::nullopt,
 
 1408    const std::vector<array>& inputs,
 
 1409    const std::vector<array>& dependencies);
 
 1414    const std::vector<array>& a,
 
 1418    const std::vector<array>& a,
 
 1422    const std::vector<array>& a,
 
 1431    std::vector<int> axes,
 
 1464    const std::vector<int>& shift,
 
 1470    const std::vector<int>& axes,
 
 1474    const std::vector<int>& shift,
 
 1479    const std::vector<int>& shift,
 
 1480    const std::vector<int>& axes,
 
array scatter_max(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and max updates to given linear indices.
 
array floor_divide(const array &a, const array &b, StreamOrDevice s={})
Compute integer division.
 
array radians(const array &a, StreamOrDevice s={})
Convert the elements of an array from Degrees to Radians.
 
array arccos(const array &a, StreamOrDevice s={})
Arc Cosine of the elements of an array.
 
array scatter_min(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and min updates to given linear indices.
 
array less_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a <= b) element-wise.
 
array cumprod(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative product of an array.
 
array astype(array a, Dtype dtype, StreamOrDevice s={})
Convert an array to the given data type.
 
array rsqrt(const array &a, StreamOrDevice s={})
Reciprocal square root (1 / sqrt(x)) of the elements of an array.
 
array diag(const array &a, int k=0, StreamOrDevice s={})
Extract diagonal from a 2d array or create a diagonal matrix.
 
array square(const array &a, StreamOrDevice s={})
Square the elements of an array.
 
array ceil(const array &a, StreamOrDevice s={})
Ceil the elements of an array.
 
array log2(const array &a, StreamOrDevice s={})
Log base 2 of the elements of an array.
 
array clip(const array &a, const std::optional< array > &a_min=std::nullopt, const std::optional< array > &a_max=std::nullopt, StreamOrDevice s={})
Clip (limit) the values in an array.
 
array isnan(const array &a, StreamOrDevice s={})
 
array isneginf(const array &a, StreamOrDevice s={})
 
array subtract(const array &a, const array &b, StreamOrDevice s={})
Subtract two arrays.
 
array cummin(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative min of an array.
 
array log10(const array &a, StreamOrDevice s={})
Log base 10 of the elements of an array.
 
array log1p(const array &a, StreamOrDevice s={})
Natural logarithm of one plus elements in the array: log(1 + a).
 
array sign(const array &a, StreamOrDevice s={})
The sign of the elements in an array.
 
array cosh(const array &a, StreamOrDevice s={})
Hyperbolic Cosine of the elements of an array.
 
array conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})
General convolution with a filter.
 
array logical_or(const array &a, const array &b, StreamOrDevice s={})
Logical or of two arrays.
 
array moveaxis(const array &a, int source, int destination, StreamOrDevice s={})
Move an axis of an array.
 
array operator*(const array &a, const array &b)
 
array operator+(const array &a, const array &b)
 
array operator||(const array &a, const array &b)
 
array not_equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a != b) element-wise.
 
array erf(const array &a, StreamOrDevice s={})
Computes the error function of the elements of an array.
 
array sqrt(const array &a, StreamOrDevice s={})
Square root the elements of an array.
 
array std(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the standard deviation of the elements of an array.
 
array add(const array &a, const array &b, StreamOrDevice s={})
Add two arrays.
 
array round(const array &a, int decimals, StreamOrDevice s={})
Round a floating point number.
 
array conv1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D convolution with a filter
 
array bitwise_xor(const array &a, const array &b, StreamOrDevice s={})
Bitwise exclusive or.
 
array equal(const array &a, const array &b, StreamOrDevice s={})
Returns the bool array with (a == b) element-wise.
 
array zeros(const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with zeros.
 
array view(const array &a, const Dtype &dtype, StreamOrDevice s={})
 
array gather_qmm(const array &x, const array &w, const array &scales, const array &biases, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Compute matrix products with matrix-level gather.
 
array stop_gradient(const array &a, StreamOrDevice s={})
Stop the flow of gradients.
 
array scatter_prod(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and prod updates to given indices.
 
array slice_update(const array &src, const array &update, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
Update a slice from the source array.
 
array cos(const array &a, StreamOrDevice s={})
Cosine of the elements of an array.
 
array operator>=(const array &a, const array &b)
Definition ops.h:345
 
array degrees(const array &a, StreamOrDevice s={})
Convert the elements of an array from Radians to Degrees.
 
array all(const array &a, bool keepdims, StreamOrDevice s={})
True if all elements in the array are true (or non-zero).
 
array tan(const array &a, StreamOrDevice s={})
Tangent of the elements of an array.
 
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape (n,m) with ones in the specified diagonal k, and zeros everywhere else.
 
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
 
array operator>>(const array &a, const array &b)
 
array minimum(const array &a, const array &b, StreamOrDevice s={})
Element-wise minimum between two arrays.
 
array prod(const array &a, bool keepdims, StreamOrDevice s={})
The product of all elements of the array.
 
array atleast_3d(const array &a, StreamOrDevice s={})
 
array operator<=(const array &a, const array &b)
Definition ops.h:373
 
array reciprocal(const array &a, StreamOrDevice s={})
The reciprocal (1/x) of the elements in an array.
 
array tri(int n, int m, int k, Dtype type, StreamOrDevice s={})
 
array flatten(const array &a, int start_axis, int end_axis=-1, StreamOrDevice s={})
Flatten the dimensions in the range [start_axis, end_axis] .
 
array isclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
Returns a boolean array where two arrays are element-wise equal within the specified tolerance.
 
array operator|(const array &a, const array &b)
 
array topk(const array &a, int k, StreamOrDevice s={})
Returns topk elements of the flattened array.
 
array expm1(const array &a, StreamOrDevice s={})
Computes the expm1 function of the elements of an array.
 
array ones(const std::vector< int > &shape, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with ones.
 
array abs(const array &a, StreamOrDevice s={})
Absolute value of elements in an array.
 
std::vector< array > meshgrid(const std::vector< array > &arrays, bool sparse=false, std::string indexing="xy", StreamOrDevice s={})
A vector of coordinate arrays from coordinate vectors.
 
array conjugate(const array &a, StreamOrDevice s={})
 
array tanh(const array &a, StreamOrDevice s={})
Hyperbolic Tangent of the elements of an array.
 
array inner(const array &a, const array &b, StreamOrDevice s={})
Compute the inner product of two vectors.
 
array block_masked_mm(array a, array b, int block_size, std::optional< array > mask_out=std::nullopt, std::optional< array > mask_lhs=std::nullopt, std::optional< array > mask_rhs=std::nullopt, StreamOrDevice s={})
Compute matrix product with block masking.
 
array arctan2(const array &a, const array &b, StreamOrDevice s={})
Inverse tangent of the ratio of two arrays.
 
array number_of_elements(const array &a, std::vector< int > axes, bool inverted, Dtype dtype=int32, StreamOrDevice s={})
Extract the number of elements along some axes as a scalar array.
 
array conv3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D convolution with a filter
 
array log(const array &a, StreamOrDevice s={})
Natural logarithm of the elements of an array.
 
array sigmoid(const array &a, StreamOrDevice s={})
Element-wise logistic sigmoid of the array: 1 / (1 + exp(-x)).
 
array squeeze(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Remove singleton dimensions at the given axes.
 
array greater_equal(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a >= b) element-wise.
 
array expand_dims(const array &a, const std::vector< int > &axes, StreamOrDevice s={})
Add a singleton dimension at the given axes.
 
array isfinite(const array &a, StreamOrDevice s={})
 
array conv2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D convolution with a filter
 
array operator>(const array &a, const array &b)
Definition ops.h:331
 
array bitwise_and(const array &a, const array &b, StreamOrDevice s={})
Bitwise and.
 
std::vector< array > split(const array &a, int num_splits, int axis, StreamOrDevice s={})
Split an array into sub-arrays along a given axis.
 
array matmul(const array &a, const array &b, StreamOrDevice s={})
Matrix-matrix multiplication.
 
array logical_and(const array &a, const array &b, StreamOrDevice s={})
Logical and of two arrays.
 
array erfinv(const array &a, StreamOrDevice s={})
Computes the inverse error function of the elements of an array.
 
array divide(const array &a, const array &b, StreamOrDevice s={})
Divide two arrays.
 
array power(const array &a, const array &b, StreamOrDevice s={})
Raise elements of a to the power of b element-wise.
 
array maximum(const array &a, const array &b, StreamOrDevice s={})
Element-wise maximum between two arrays.
 
array reshape(const array &a, std::vector< int > shape, StreamOrDevice s={})
Reshape an array to the given shape.
 
array argmin(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the minimum value in the array.
 
array var(const array &a, bool keepdims, int ddof=0, StreamOrDevice s={})
Computes the variance of the elements of an array.
 
array full(std::vector< int > shape, array vals, Dtype dtype, StreamOrDevice s={})
Fill an array of the given shape with the given value(s).
 
array softmax(const array &a, const std::vector< int > &axes, bool precise=false, StreamOrDevice s={})
Softmax of an array.
 
array sort(const array &a, StreamOrDevice s={})
Returns a sorted copy of the flattened array.
 
array max(const array &a, bool keepdims, StreamOrDevice s={})
The maximum of all elements of the array.
 
array imag(const array &a, StreamOrDevice s={})
 
array pad(const array &a, const std::vector< int > &axes, const std::vector< int > &low_pad_size, const std::vector< int > &high_pad_size, const array &pad_value=array(0), const std::string mode="constant", StreamOrDevice s={})
Pad an array with a constant value.
 
array addmm(array c, array a, array b, const float &alpha=1.f, const float &beta=1.f, StreamOrDevice s={})
Compute D = beta * C + alpha * (A @ B)
 
array tril(array x, int k=0, StreamOrDevice s={})
 
array any(const array &a, bool keepdims, StreamOrDevice s={})
True if any elements in the array are true (or non-zero).
 
array outer(const array &a, const array &b, StreamOrDevice s={})
Compute the outer product of two vectors.
 
array hadamard_transform(const array &a, std::optional< float > scale=std::nullopt, StreamOrDevice s={})
Multiply the array by the Hadamard matrix of corresponding size.
 
array arcsin(const array &a, StreamOrDevice s={})
Arc Sine of the elements of an array.
 
array left_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the left.
 
array where(const array &condition, const array &x, const array &y, StreamOrDevice s={})
Select from x or y depending on condition.
 
array exp(const array &a, StreamOrDevice s={})
Exponential of the elements of an array.
 
array bitwise_or(const array &a, const array &b, StreamOrDevice s={})
Bitwise inclusive or.
 
array gather_mm(array a, array b, std::optional< array > lhs_indices=std::nullopt, std::optional< array > rhs_indices=std::nullopt, StreamOrDevice s={})
Compute matrix product with matrix-level gather.
 
array floor(const array &a, StreamOrDevice s={})
Floor the elements of an array.
 
array conv_transpose3d(const array &input, const array &weight, const std::tuple< int, int, int > &stride={1, 1, 1}, const std::tuple< int, int, int > &padding={0, 0, 0}, const std::tuple< int, int, int > &dilation={1, 1, 1}, int groups=1, StreamOrDevice s={})
3D transposed convolution with a filter
 
array as_strided(array a, std::vector< int > shape, std::vector< size_t > strides, size_t offset, StreamOrDevice s={})
Create a view of an array with the given shape and strides.
 
array argsort(const array &a, StreamOrDevice s={})
Returns indices that sort the flattened array.
 
array put_along_axis(const array &a, const array &indices, const array &values, int axis, StreamOrDevice s={})
Put the values into the array at the given indices along the axis.
 
array array_equal(const array &a, const array &b, bool equal_nan, StreamOrDevice s={})
True if two arrays have the same shape and elements.
 
array isinf(const array &a, StreamOrDevice s={})
 
array less(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a < b) element-wise.
 
array diagonal(const array &a, int offset=0, int axis1=0, int axis2=1, StreamOrDevice s={})
Extract a diagonal or construct a diagonal array.
 
array ones_like(const array &a, StreamOrDevice s={})
 
array negative(const array &a, StreamOrDevice s={})
Negate an array.
 
array linspace(double start, double stop, int num=50, Dtype dtype=float32, StreamOrDevice s={})
A 1D array of num evenly spaced numbers in the range [start, stop]
 
array remainder(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise remainder of division.
 
array arctan(const array &a, StreamOrDevice s={})
Arc Tangent of the elements of an array.
 
array conv_transpose1d(const array &input, const array &weight, int stride=1, int padding=0, int dilation=1, int groups=1, StreamOrDevice s={})
1D transposed convolution with a filter
 
std::vector< array > divmod(const array &a, const array &b, StreamOrDevice s={})
Compute the element-wise quotient and remainder.
 
array triu(array x, int k=0, StreamOrDevice s={})
 
array arccosh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Cosine of the elements of an array.
 
array tile(const array &arr, std::vector< int > reps, StreamOrDevice s={})
 
array nan_to_num(const array &a, float nan=0.0f, const std::optional< float > posinf=std::nullopt, const std::optional< float > neginf=std::nullopt, StreamOrDevice s={})
Replace NaN and infinities with finite numbers.
 
array min(const array &a, bool keepdims, StreamOrDevice s={})
The minimum of all elements of the array.
 
array operator%(const array &a, const array &b)
 
std::tuple< array, array, array > quantize(const array &w, int group_size=64, int bits=4, StreamOrDevice s={})
Quantize a matrix along its last axis.
 
array arctanh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Tangent of the elements of an array.
 
array repeat(const array &arr, int repeats, int axis, StreamOrDevice s={})
Repeat an array along an axis.
 
array gather(const array &a, const std::vector< array > &indices, const std::vector< int > &axes, const std::vector< int > &slice_sizes, StreamOrDevice s={})
Gather array entries given indices and slices.
 
std::vector< array > broadcast_arrays(const std::vector< array > &inputs, StreamOrDevice s={})
Broadcast a vector of arrays against one another.
 
array atleast_1d(const array &a, StreamOrDevice s={})
Convert an array to an array with at least one dimension.
 
array swapaxes(const array &a, int axis1, int axis2, StreamOrDevice s={})
Swap two axes of an array.
 
array logical_not(const array &a, StreamOrDevice s={})
Logical not of an array.
 
array concatenate(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Concatenate arrays along a given axis.
 
array trace(const array &a, int offset, int axis1, int axis2, Dtype dtype, StreamOrDevice s={})
Return the sum along a specified diagonal in the given array.
 
array quantized_matmul(array x, array w, array scales, array biases, bool transpose=true, int group_size=64, int bits=4, StreamOrDevice s={})
Quantized matmul multiplies x with a quantized matrix w.
 
array dequantize(const array &w, const array &scales, const array &biases, int group_size=64, int bits=4, StreamOrDevice s={})
Dequantize a matrix produced by quantize()
 
array transpose(const array &a, std::vector< int > axes, StreamOrDevice s={})
Permutes the dimensions according to the given axes.
 
array partition(const array &a, int kth, StreamOrDevice s={})
Returns a partitioned copy of the flattened array such that the smaller kth elements are first.
 
array take(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array slices at the given indices of the specified axis.
 
array operator^(const array &a, const array &b)
 
array roll(const array &a, int shift, StreamOrDevice s={})
Roll elements along an axis and introduce them on the other side.
 
std::vector< array > depends(const std::vector< array > &inputs, const std::vector< array > &dependencies)
Implements the identity function but allows injecting dependencies to other arrays.
 
array arcsinh(const array &a, StreamOrDevice s={})
Inverse Hyperbolic Sine of the elements of an array.
 
array scatter_add(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter and add updates to given indices.
 
array logsumexp(const array &a, bool keepdims, StreamOrDevice s={})
The logsumexp of all elements of the array.
 
array broadcast_to(const array &a, const std::vector< int > &shape, StreamOrDevice s={})
Broadcast an array to a given shape.
 
array scatter(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})
Scatter updates to the given indices.
 
array operator<<(const array &a, const array &b)
 
array slice(const array &a, std::vector< int > start, std::vector< int > stop, std::vector< int > strides, StreamOrDevice s={})
Slice an array.
 
array isposinf(const array &a, StreamOrDevice s={})
 
array cumsum(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative sum of an array.
 
array operator-(const array &a)
 
array mean(const array &a, bool keepdims, StreamOrDevice s={})
Computes the mean of the elements of an array.
 
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
 
array take_along_axis(const array &a, const array &indices, int axis, StreamOrDevice s={})
Take array entries given indices along the axis.
 
array argmax(const array &a, bool keepdims, StreamOrDevice s={})
Returns the index of the maximum value in the array.
 
array conv_transpose2d(const array &input, const array &weight, const std::pair< int, int > &stride={1, 1}, const std::pair< int, int > &padding={0, 0}, const std::pair< int, int > &dilation={1, 1}, int groups=1, StreamOrDevice s={})
2D transposed convolution with a filter
 
array sin(const array &a, StreamOrDevice s={})
Sine of the elements of an array.
 
array operator&&(const array &a, const array &b)
 
array cummax(const array &a, int axis, bool reverse=false, bool inclusive=true, StreamOrDevice s={})
Cumulative max of an array.
 
array operator<(const array &a, const array &b)
Definition ops.h:359
 
array atleast_2d(const array &a, StreamOrDevice s={})
 
array operator/(const array &a, const array &b)
 
array allclose(const array &a, const array &b, double rtol=1e-5, double atol=1e-8, bool equal_nan=false, StreamOrDevice s={})
True if the two arrays are equal within the specified tolerance.
 
array operator&(const array &a, const array &b)
 
array argpartition(const array &a, int kth, StreamOrDevice s={})
Returns indices that partition the flattened array such that the smaller kth elements are first.
 
array greater(const array &a, const array &b, StreamOrDevice s={})
Returns bool array with (a > b) element-wise.
 
array sinh(const array &a, StreamOrDevice s={})
Hyperbolic Sine of the elements of an array.
 
array multiply(const array &a, const array &b, StreamOrDevice s={})
Multiply two arrays.
 
array tensordot(const array &a, const array &b, const int axis=2, StreamOrDevice s={})
Returns a contraction of a and b over multiple dimensions.
 
array real(const array &a, StreamOrDevice s={})
 
array stack(const std::vector< array > &arrays, int axis, StreamOrDevice s={})
Stack arrays along a new axis.
 
array logaddexp(const array &a, const array &b, StreamOrDevice s={})
Element-wise log-add-exp of two arrays: log(exp(a) + exp(b)).
 
array right_shift(const array &a, const array &b, StreamOrDevice s={})
Shift bits to the right.
 
array zeros_like(const array &a, StreamOrDevice s={})
 
void arange(const std::vector< array > &inputs, array &out, double start, double step)
Definition arange.h:24
 
Stream to_stream(StreamOrDevice s)
 
void copy(const array &src, array &dst, CopyType ctype)
 
constexpr Dtype int32
Definition dtype.h:76
 
constexpr Dtype float32
Definition dtype.h:80
 
bool operator==(const Device &lhs, const Device &rhs)
 
bool operator!=(const Device &lhs, const Device &rhs)
 
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:14