#include #include #include #include #include #include #include "mlx/ops.h" #include "mlx/utils.h" #include "python/src/load.h" #include "python/src/utils.h" namespace py = pybind11; using namespace py::literals; using namespace mlx::core; using Scalar = std::variant; /* Picks the dtype associated with a scalar argument: returns int32 when the std::holds_alternative check on `scalar` succeeds and float32 otherwise. NOTE(review): the template argument of holds_alternative is not visible in this copy of the file - presumably it tests for the integer alternative of Scalar; confirm. */ Dtype scalar_to_dtype(Scalar scalar) { if (std::holds_alternative(scalar)) { return int32; } else { return float32; } } /* Returns the value held by `s` as a double: when the std::holds_alternative check succeeds the stored value is returned directly, otherwise the other alternative is fetched with std::get and converted with static_cast. NOTE(review): the template arguments are not visible here - presumably the check is for the double alternative and the cast is int -> double; confirm. */ double scalar_to_double(Scalar s) { if (std::holds_alternative(s)) { return std::get(s); } else { return static_cast(std::get(s)); } } void init_ops(py::module_& m) { m.def( "reshape", &reshape, "a"_a, py::pos_only(), "shape"_a, py::kw_only(), "stream"_a = none, R"pbdoc( Reshape an array while preserving the size. Args: a (array): Input array. shape (tuple(int)): New shape. stream (Stream, optional): Stream or device. Defaults to ```None``` in which case the default stream of the default device is used. Returns: array: The reshaped array. )pbdoc"); m.def( "squeeze", [](const array& a, const IntOrVec& v, const StreamOrDevice& s) { if (std::holds_alternative(v)) { return squeeze(a, s); } else if (auto pv = std::get_if(&v); pv) { return squeeze(a, *pv, s); } else { return squeeze(a, std::get>(v), s); } }, "a"_a, py::pos_only(), "axis"_a = none, py::kw_only(), "stream"_a = none, R"pbdoc( Remove length one axes from an array. Args: a (array): Input array. axis (int or tuple(int), optional): Axes to remove. Defaults to ```None``` in which case all size one axes are removed. Returns: array: The output array with size one axes removed. )pbdoc"); m.def( "expand_dims", [](const array& a, const std::variant>& v, StreamOrDevice s) { if (auto pv = std::get_if(&v); pv) { return expand_dims(a, *pv, s); } else { return expand_dims(a, std::get>(v), s); } }, "a"_a, py::pos_only(), "axis"_a, py::kw_only(), "stream"_a = none, R"pbdoc( Add a size one dimension at the given axis. Args: a (array): Input array. axes (int or tuple(int)): The index of the inserted dimensions.
Returns: array: The array with inserted dimensions. )pbdoc"); m.def( "abs", &mlx::core::abs, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise absolute value. Args: a (array): Input array. Returns: array: The absolute value of ``a``. )pbdoc"); m.def( "sign", &sign, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise sign. Args: a (array): Input array. Returns: array: The sign of ``a``. )pbdoc"); m.def( "negative", &negative, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise negation. Args: a (array): Input array. Returns: array: The negative of ``a``. )pbdoc"); m.def( "add", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return add(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise addition. Add two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The sum of ``a`` and ``b``. )pbdoc"); m.def( "subtract", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return subtract(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise subtraction. Subtract one array from another with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The difference ``a - b``. )pbdoc"); m.def( "multiply", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return multiply(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise multiplication. Multiply two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. 
Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The multiplication ``a * b``. )pbdoc"); m.def( "divide", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return divide(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise division. Divide two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The quotient ``a / b``. )pbdoc"); m.def( "equal", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return equal(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise equality. Equality comparison on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The element-wise comparison ``a == b``. )pbdoc"); m.def( "not_equal", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return not_equal(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise not equal. Not equal comparison on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The element-wise comparison ``a != b``. )pbdoc"); m.def( "less", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return less(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise less than. Strict less than on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. 
Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The element-wise comparison ``a < b``. )pbdoc"); m.def( "less_equal", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return less_equal(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise less than or equal. Less than or equal on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The element-wise comparison ``a <= b``. )pbdoc"); m.def( "greater", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return greater(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise greater than. Strict greater than on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The element-wise comparison ``a > b``. )pbdoc"); m.def( "greater_equal", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return greater_equal(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise greater or equal. Greater than or equal on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The element-wise comparison ``a >= b``. 
)pbdoc"); m.def( "array_equal", [](const ScalarOrArray& a_, const ScalarOrArray& b_, bool equal_nan, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return array_equal(a, b, equal_nan, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "equal_nan"_a = false, "stream"_a = none, R"pbdoc( Array equality check. Compare two arrays for equality. Returns ``True`` if and only if the arrays have the same shape and their values are equal. The arrays need not have the same type to be considered equal. Args: a (array): Input array or scalar. b (array): Input array or scalar. equal_nan (bool): If ``True``, NaNs are treated as equal. Defaults to ``False``. Returns: array: A scalar boolean array. )pbdoc"); m.def( "matmul", &matmul, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Matrix multiplication. Perform the (possibly batched) matrix multiplication of two arrays. This function supports broadcasting for arrays with more than two dimensions. - If the first array is 1-D then a 1 is prepended to its shape to make it a matrix. Similarly if the second array is 1-D then a 1 is appended to its shape to make it a matrix. In either case the singleton dimension is removed from the result. - A batched matrix multiplication is performed if the arrays have more than 2 dimensions. The matrix dimensions for the matrix product are the last two dimensions of each input. - All but the last two dimensions of each input are broadcast with one another using standard numpy-style broadcasting semantics. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The matrix product of ``a`` and ``b``. )pbdoc"); m.def( "square", &square, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise square. Args: a (array): Input array. Returns: array: The square of ``a``. )pbdoc"); m.def( "sqrt", &mlx::core::sqrt, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise square root. 
Args: a (array): Input array. Returns: array: The square root of ``a``. )pbdoc"); m.def( "rsqrt", &rsqrt, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise reciprocal and square root. Args: a (array): Input array. Returns: array: One over the square root of ``a``. )pbdoc"); m.def( "reciprocal", &reciprocal, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise reciprocal. Args: a (array): Input array. Returns: array: The reciprocal of ``a``. )pbdoc"); m.def( "logical_not", [](const ScalarOrArray& a, StreamOrDevice s) { return logical_not(to_array(a), s); }, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise logical not. Args: a (array): Input array or scalar. Returns: array: The boolean array containing the logical not of ``a``. )pbdoc"); m.def( "logaddexp", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return logaddexp(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise log-add-exp. This is a numerically stable log-add-exp of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. The computation is a numerically stable version of ``log(exp(a) + exp(b))``. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The log-add-exp of ``a`` and ``b``. )pbdoc"); m.def( "exp", &mlx::core::exp, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise exponential. Args: a (array): Input array. Returns: array: The exponential of ``a``. )pbdoc"); m.def( "erf", &mlx::core::erf, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise error function. .. math:: \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt Args: a (array): Input array. Returns: array: The error function of ``a``.
)pbdoc"); m.def( "erfinv", &mlx::core::erfinv, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise inverse of :func:`erf`. Args: a (array): Input array. Returns: array: The inverse error function of ``a``. )pbdoc"); m.def( "sin", &mlx::core::sin, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise sine. Args: a (array): Input array. Returns: array: The sine of ``a``. )pbdoc"); m.def( "cos", &mlx::core::cos, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise cosine. Args: a (array): Input array. Returns: array: The cosine of ``a``. )pbdoc"); m.def( "tan", &mlx::core::tan, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise tangent. Args: a (array): Input array. Returns: array: The tangent of ``a``. )pbdoc"); m.def( "arcsin", &mlx::core::arcsin, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise inverse sine. Args: a (array): Input array. Returns: array: The inverse sine of ``a``. )pbdoc"); m.def( "arccos", &mlx::core::arccos, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise inverse cosine. Args: a (array): Input array. Returns: array: The inverse cosine of ``a``. )pbdoc"); m.def( "arctan", &mlx::core::arctan, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise inverse tangent. Args: a (array): Input array. Returns: array: The inverse tangent of ``a``. )pbdoc"); m.def( "sinh", &mlx::core::sinh, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise hyperbolic sine. Args: a (array): Input array. Returns: array: The hyperbolic sine of ``a``. )pbdoc"); m.def( "cosh", &mlx::core::cosh, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise hyperbolic cosine. Args: a (array): Input array. Returns: array: The hyperbolic cosine of ``a``. 
)pbdoc"); m.def( "tanh", &mlx::core::tanh, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise hyperbolic tangent. Args: a (array): Input array. Returns: array: The hyperbolic tangent of ``a``. )pbdoc"); m.def( "arcsinh", &mlx::core::arcsinh, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise inverse hyperbolic sine. Args: a (array): Input array. Returns: array: The inverse hyperbolic sine of ``a``. )pbdoc"); m.def( "arccosh", &mlx::core::arccosh, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise inverse hyperbolic cosine. Args: a (array): Input array. Returns: array: The inverse hyperbolic cosine of ``a``. )pbdoc"); m.def( "arctanh", &mlx::core::arctanh, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise inverse hyperbolic tangent. Args: a (array): Input array. Returns: array: The inverse hyperbolic tangent of ``a``. )pbdoc"); m.def( "log", &mlx::core::log, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise natural logarithm. Args: a (array): Input array. Returns: array: The natural logarithm of ``a``. )pbdoc"); m.def( "log2", &mlx::core::log2, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise base-2 logarithm. Args: a (array): Input array. Returns: array: The base-2 logarithm of ``a``. )pbdoc"); m.def( "log10", &mlx::core::log10, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise base-10 logarithm. Args: a (array): Input array. Returns: array: The base-10 logarithm of ``a``. )pbdoc"); m.def( "log1p", &mlx::core::log1p, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise natural log of one plus the array. Args: a (array): Input array. Returns: array: The natural logarithm of one plus ``a``. )pbdoc"); m.def( "stop_gradient", &stop_gradient, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Stop gradients from being computed. 
The operation is the identity but it prevents gradients from flowing through the array. Args: a (array): Input array. Returns: array: The unchanged input ``a`` but without gradient flowing through it. )pbdoc"); m.def( "sigmoid", &sigmoid, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise logistic sigmoid. The logistic sigmoid function is: .. math:: \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}} Args: a (array): Input array. Returns: array: The logistic sigmoid of ``a``. )pbdoc"); m.def( "power", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return power(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise power operation. Raise the elements of a to the powers in elements of b with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: Bases of ``a`` raised to powers in ``b``. )pbdoc"); { // Disable function signature just for arange which we write manually py::options options; options.disable_function_signatures(); m.def( "arange", [](Scalar stop, std::optional dtype_, StreamOrDevice s) { Dtype dtype = dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop); return arange(0.0, scalar_to_double(stop), 1.0, dtype, s); }, "stop"_a, "dtype"_a = none, "stream"_a = none); m.def( "arange", [](Scalar start, Scalar stop, std::optional dtype_, StreamOrDevice s) { Dtype dtype = dtype_.has_value() ? dtype_.value() : promote_types(scalar_to_dtype(start), scalar_to_dtype(stop)); return arange( scalar_to_double(start), scalar_to_double(stop), dtype, s); }, "start"_a, "stop"_a, "dtype"_a = none, "stream"_a = none); m.def( "arange", [](Scalar stop, Scalar step, std::optional dtype_, StreamOrDevice s) { Dtype dtype = dtype_.has_value() ? 
dtype_.value() : promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)); return arange( 0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s); }, "stop"_a, "step"_a, "dtype"_a = none, "stream"_a = none); m.def( "arange", [](Scalar start, Scalar stop, Scalar step, std::optional dtype_, StreamOrDevice s) { // Determine the final dtype based on input types Dtype dtype = dtype_.has_value() ? dtype_.value() : promote_types( scalar_to_dtype(start), promote_types( scalar_to_dtype(stop), scalar_to_dtype(step))); return arange( scalar_to_double(start), scalar_to_double(stop), scalar_to_double(step), dtype, s); }, "start"_a, "stop"_a, "step"_a, "dtype"_a = none, "stream"_a = none, R"pbdoc( arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array Generates ranges of numbers. Generate numbers in the half-open interval ``[start, stop)`` in increments of ``step``. Args: start (float or int, optional): Starting value which defaults to ``0``. stop (float or int): Stopping value. step (float or int, optional): Increment which defaults to ``1``. dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``. Returns: array: The range of values. Note: Following the Numpy convention the actual increment used to generate numbers is ``dtype(start + step) - dtype(start)``. This can lead to unexpected results for example if `start + step` is a fractional value and the `dtype` is integral. )pbdoc"); } m.def( "take", [](const array& a, const array& indices, const std::optional& axis, StreamOrDevice s) { if (axis.has_value()) { return take(a, indices, axis.value(), s); } else { return take(a, indices, s); } }, "a"_a, py::pos_only(), "indices"_a, "axis"_a = std::nullopt, py::kw_only(), "stream"_a = none, R"pbdoc( Take elements along an axis. 
The elements are taken from ``indices`` along the specified axis. If the axis is not specified the array is treated as a flattened 1-D array prior to performing the take. As an example, if the ``axis=1`` this is equivalent to ``a[:, indices, ...]``. Args: a (array): Input array. indices (array): Input array with integral type. axis (int, optional): Axis along which to perform the take. If unspecified the array is treated as a flattened 1-D vector. Returns: array: The indexed values of ``a``. )pbdoc"); m.def( "take_along_axis", [](const array& a, const array& indices, const std::optional& axis, StreamOrDevice s) { if (axis.has_value()) { return take_along_axis(a, indices, axis.value(), s); } else { return take_along_axis(reshape(a, {-1}, s), indices, 0, s); } }, "a"_a, py::pos_only(), "indices"_a, "axis"_a, py::kw_only(), "stream"_a = none, R"pbdoc( Take values along an axis at the specified indices. Args: a (array): Input array. indices (array): Indices array. These should be broadcastable with the input array excluding the `axis` dimension. axis (int or None): Axis in the input to take the values from. If ``axis == None`` the array is flattened to 1D prior to the indexing operation. Returns: array: The output array with the specified shape and values. )pbdoc"); m.def( "full", [](const std::variant>& shape, const ScalarOrArray& vals, std::optional dtype, StreamOrDevice s) { if (auto pv = std::get_if(&shape); pv) { return full({*pv}, to_array(vals, dtype), s); } else { return full( std::get>(shape), to_array(vals, dtype), s); } }, "shape"_a, "vals"_a, "dtype"_a = std::nullopt, py::kw_only(), "stream"_a = none, R"pbdoc( Construct an array with the given value. Constructs an array of size ``shape`` filled with ``vals``. If ``vals`` is an :obj:`array` it must be broadcastable to the given ``shape``. Args: shape (int or list(int)): The shape of the output array. vals (float or int or array): Values to fill the array with.
dtype (Dtype, optional): Data type of the output array. If unspecified the output type is inferred from ``vals``. Returns: array: The output array with the specified shape and values. )pbdoc"); m.def( "zeros", [](const std::variant>& shape, std::optional dtype, StreamOrDevice s) { auto t = dtype.value_or(float32); if (auto pv = std::get_if(&shape); pv) { return zeros({*pv}, t, s); } else { return zeros(std::get>(shape), t, s); } }, "shape"_a, "dtype"_a = std::nullopt, py::kw_only(), "stream"_a = none, R"pbdoc( Construct an array of zeros. Args: shape (int or list(int)): The shape of the output array. dtype (Dtype, optional): Data type of the output array. If unspecified the output type defaults to ``float32``. Returns: array: The array of zeros with the specified shape. )pbdoc"); m.def( "zeros_like", &zeros_like, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( An array of zeros like the input. Args: a (array): The input to take the shape and type from. Returns: array: The output array filled with zeros. )pbdoc"); m.def( "ones", [](const std::variant>& shape, std::optional dtype, StreamOrDevice s) { auto t = dtype.value_or(float32); if (auto pv = std::get_if(&shape); pv) { return ones({*pv}, t, s); } else { return ones(std::get>(shape), t, s); } }, "shape"_a, "dtype"_a = std::nullopt, py::kw_only(), "stream"_a = none, R"pbdoc( Construct an array of ones. Args: shape (int or list(int)): The shape of the output array. dtype (Dtype, optional): Data type of the output array. If unspecified the output type defaults to ``float32``. Returns: array: The array of ones with the specified shape. )pbdoc"); m.def( "ones_like", &ones_like, "a"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( An array of ones like the input. Args: a (array): The input to take the shape and type from. Returns: array: The output array filled with ones. 
)pbdoc"); m.def( "allclose", &allclose, "a"_a, "b"_a, py::pos_only(), "rtol"_a = 1e-5, "atol"_a = 1e-8, py::kw_only(), "stream"_a = none, R"pbdoc( Approximate comparison of two arrays. The arrays are considered equal if: .. code-block:: all(abs(a - b) <= (atol + rtol * abs(b))) Note unlike :func:`array_equal`, this function supports numpy-style broadcasting. Args: a (array): Input array. b (array): Input array. rtol (float): Relative tolerance. atol (float): Absolute tolerance. Returns: array: The boolean output scalar indicating if the arrays are close. )pbdoc"); m.def( "all", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( An `and` reduction over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the corresponding axes reduced. )pbdoc"); m.def( "any", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( An `or` reduction over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the corresponding axes reduced. 
)pbdoc"); m.def( "minimum", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return minimum(a, b, s); }, "a"_a, "b"_a, py::pos_only(), "stream"_a = none, R"pbdoc( Element-wise minimum. Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The min of ``a`` and ``b``. )pbdoc"); m.def( "maximum", [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); return maximum(a, b, s); }, "a"_a, "b"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Element-wise maximum. Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Args: a (array): Input array or scalar. b (array): Input array or scalar. Returns: array: The max of ``a`` and ``b``. )pbdoc"); m.def( "transpose", [](const array& a, const std::optional>& axes, StreamOrDevice s) { if (axes.has_value()) { return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s); } else { return transpose(a, s); } }, "a"_a, py::pos_only(), "axes"_a = std::nullopt, py::kw_only(), "stream"_a = none, R"pbdoc( Transpose the dimensions of the array. Args: a (array): Input array. axes (list(int), optional): Specifies the source axis for each axis in the new array. The default is to reverse the axes. Returns: array: The transposed array. )pbdoc"); m.def( "sum", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "array"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( Sum reduce the array over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. 
If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the corresponding axes reduced. )pbdoc"); m.def( "prod", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( A product reduction over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the corresponding axes reduced. )pbdoc"); m.def( "min", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( A `min` reduction over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the corresponding axes reduced. )pbdoc"); m.def( "max", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( A `max` reduction over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the corresponding axes reduced. )pbdoc"); m.def( "logsumexp", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( A `log-sum-exp` reduction over the given axes. The log-sum-exp reduction is a numerically stable version of: .. code-block:: log(sum(exp(a), axis)) Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the corresponding axes reduced. )pbdoc"); m.def( "mean", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( Compute the mean(s) over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array of means. )pbdoc"); m.def( "var", [](const array& a, const IntOrVec& axis, bool keepdims, int ddof, StreamOrDevice s) { return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, "a"_a, py::pos_only(), "axis"_a = none, "keepdims"_a = false, "ddof"_a = 0, py::kw_only(), "stream"_a = none, R"pbdoc( Compute the variance(s) over the given axes. Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to reduce over. 
If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. ddof (int, optional): The divisor to compute the variance is ``N - ddof``, defaults to 0. Returns: array: The output array of variances. )pbdoc"); m.def( "split", [](const array& a, const std::variant>& indices_or_sections, int axis, StreamOrDevice s) { if (auto pv = std::get_if(&indices_or_sections); pv) { return split(a, *pv, axis, s); } else { return split( a, std::get>(indices_or_sections), axis, s); } }, "a"_a, py::pos_only(), "indices_or_sections"_a, "axis"_a = 0, py::kw_only(), "stream"_a = none, R"pbdoc( Split an array along a given axis. Args: a (array): Input array. indices_or_sections (int or list(int)): If ``indices_or_sections`` is an integer the array is split into that many sections of equal size. An error is raised if this is not possible. If ``indices_or_sections`` is a list, the list contains the indices of the start of each subarray along the given axis. axis (int, optional): Axis to split along, defaults to `0`. Returns: list(array): A list of split arrays. )pbdoc"); m.def( "argmin", [](const array& a, std::optional axis, bool keepdims, StreamOrDevice s) { if (axis) { return argmin(a, *axis, keepdims, s); } else { return argmin(a, keepdims, s); } }, "a"_a, py::pos_only(), "axis"_a = std::nullopt, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( Indices of the minimum values along the axis. Args: a (array): Input array. axis (int, optional): Optional axis to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the indices of the minimum values. 
)pbdoc"); m.def( "argmax", [](const array& a, std::optional axis, bool keepdims, StreamOrDevice s) { if (axis) { return argmax(a, *axis, keepdims, s); } else { return argmax(a, keepdims, s); } }, "a"_a, py::pos_only(), "axis"_a = std::nullopt, "keepdims"_a = false, py::kw_only(), "stream"_a = none, R"pbdoc( Indices of the maximum values along the axis. Args: a (array): Input array. axis (int, optional): Optional axis to reduce over. If unspecified this defaults to reducing over the entire array. keepdims (bool, optional): Keep reduced axes as singleton dimensions, defaults to `False`. Returns: array: The output array with the indices of the minimum values. )pbdoc"); m.def( "sort", [](const array& a, std::optional axis, StreamOrDevice s) { if (axis) { return sort(a, *axis, s); } else { return sort(a, s); } }, "a"_a, py::pos_only(), "axis"_a = -1, py::kw_only(), "stream"_a = none, R"pbdoc( Returns a sorted copy of the array. Args: a (array): Input array. axis (int or None, optional): Optional axis to sort over. If ``None``, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis). Returns: array: The sorted array. )pbdoc"); m.def( "argsort", [](const array& a, std::optional axis, StreamOrDevice s) { if (axis) { return argsort(a, *axis, s); } else { return argsort(a, s); } }, "a"_a, py::pos_only(), "axis"_a = -1, py::kw_only(), "stream"_a = none, R"pbdoc( Returns the indices that sort the array. Args: a (array): Input array. axis (int or None, optional): Optional axis to sort over. If ``None``, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis). Returns: array: The indices that sort the input array. 
)pbdoc"); m.def( "partition", [](const array& a, int kth, std::optional axis, StreamOrDevice s) { if (axis) { return partition(a, kth, *axis, s); } else { return partition(a, kth, s); } }, "a"_a, py::pos_only(), "kth"_a, "axis"_a = -1, py::kw_only(), "stream"_a = none, R"pbdoc( Returns a partitioned copy of the array such that the smaller ``kth`` elements are first. The ordering of the elements in partitions is undefined. Args: a (array): Input array. kth (int): Element at the ``kth`` index will be in its sorted position in the output. All elements before the kth index will be less or equal to the ``kth`` element and all elements after will be greater or equal to the ``kth`` element in the output. axis (int or None, optional): Optional axis to partition over. If ``None``, this partitions over the flattened array. If unspecified, it defaults to ``-1``. Returns: array: The partitioned array. )pbdoc"); m.def( "argpartition", [](const array& a, int kth, std::optional axis, StreamOrDevice s) { if (axis) { return argpartition(a, kth, *axis, s); } else { return argpartition(a, kth, s); } }, "a"_a, py::pos_only(), "kth"_a, "axis"_a = -1, py::kw_only(), "stream"_a = none, R"pbdoc( Returns the indices that partition the array. The ordering of the elements within a partition in given by the indices is undefined. Args: a (array): Input array. kth (int): Element index at the ``kth`` position in the output will give the sorted position. All indices before the ``kth`` position will be of elements less or equal to the element at the ``kth`` index and all indices after will be of elements greater or equal to the element at the ``kth`` index. axis (int or None, optional): Optional axis to partiton over. If ``None``, this partitions over the flattened array. If unspecified, it defaults to ``-1``. Returns: array: The indices that partition the input array. 
)pbdoc"); m.def( "topk", [](const array& a, int k, std::optional axis, StreamOrDevice s) { if (axis) { return topk(a, k, *axis, s); } else { return topk(a, k, s); } }, "a"_a, py::pos_only(), "k"_a, "axis"_a = -1, py::kw_only(), "stream"_a = none, R"pbdoc( Returns the ``k`` largest elements from the input along a given axis. The elements will not necessarily be in sorted order. Args: a (array): Input array. k (int): ``k`` top elements to be returned axis (int or None, optional): Optional axis to select over. If ``None``, this selects the top ``k`` elements over the flattened array. If unspecified, it defaults to ``-1``. Returns: array: The top ``k`` elements from the input. )pbdoc"); m.def( "broadcast_to", [](const ScalarOrArray& a, const std::vector& shape, StreamOrDevice s) { return broadcast_to(to_array(a), shape, s); }, "a"_a, py::pos_only(), "shape"_a, py::kw_only(), "stream"_a = none, R"pbdoc( Broadcast an array to the given shape. The broadcasting semantics are the same as Numpy. Args: a (array): Input array. shape (list(int)): The shape to broadcast to. Returns: array: The output array with the new shape. )pbdoc"); m.def( "softmax", [](const array& a, const IntOrVec& axis, StreamOrDevice s) { return softmax(a, get_reduce_axes(axis, a.ndim()), s); }, "a"_a, py::pos_only(), "axis"_a = none, py::kw_only(), "stream"_a = none, R"pbdoc( Perform the softmax along the given axis. This operation is a numerically stable version of: .. code-block:: exp(a) / sum(exp(a), axis, keepdims=True) Args: a (array): Input array. axis (int or list(int), optional): Optional axis or axes to compute the softmax over. If unspecified this performs the softmax over the full array. Returns: array: The output of the softmax. 
)pbdoc"); m.def( "concatenate", [](const std::vector& arrays, std::optional axis, StreamOrDevice s) { if (axis) { return concatenate(arrays, *axis, s); } else { return concatenate(arrays, s); } }, "arrays"_a, py::pos_only(), "axis"_a = 0, py::kw_only(), "stream"_a = none, R"pbdoc( Concatenate the arrays along the given axis. Args: arrays (list(array)): Input :obj:`list` or :obj:`tuple` of arrays. axis (int, optional): Optional axis to concatenate along. If unspecified defaults to ``0``. Returns: array: The concatenated array. )pbdoc"); m.def( "pad", [](const array& a, const std::variant< int, std::tuple, std::pair, std::vector>>& pad_width, const ScalarOrArray& constant_value, StreamOrDevice s) { if (auto pv = std::get_if(&pad_width); pv) { return pad(a, *pv, to_array(constant_value), s); } else if (auto pv = std::get_if>(&pad_width); pv) { return pad(a, std::get<0>(*pv), to_array(constant_value), s); } else if (auto pv = std::get_if>(&pad_width); pv) { return pad(a, *pv, to_array(constant_value), s); } else { auto v = std::get>>(pad_width); if (v.size() == 1) { return pad(a, v[0], to_array(constant_value), s); } else { return pad(a, v, to_array(constant_value), s); } } }, "a"_a, py::pos_only(), "pad_width"_a, "constant_values"_a = 0, py::kw_only(), "stream"_a = none, R"pbdoc( Pad an array with a constant value Args: a (array): Input array. pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded values to add to the edges of each axis:``((before_1, after_1), (before_2, after_2), ..., (before_N, after_N))``. If a single pair of integers is passed then ``(before_i, after_i)`` are all the same. If a single integer or tuple with a single integer is passed then all axes are extended by the same number on each side. constant_value (array or scalar, optional): Optional constant value to pad the edges of the array with. Returns: array: The padded array. 
)pbdoc"); m.def( "as_strided", [](const array& a, std::optional> shape, std::optional> strides, size_t offset, StreamOrDevice s) { std::vector a_shape = (shape) ? *shape : a.shape(); std::vector a_strides; if (strides) { a_strides = *strides; } else { std::fill_n(std::back_inserter(a_strides), a_shape.size(), 1); for (int i = a_shape.size() - 1; i > 0; i--) { a_strides[i - 1] = a_shape[i] * a_strides[i]; } } return as_strided(a, a_shape, a_strides, offset, s); }, "a"_a, py::pos_only(), "shape"_a = none, "strides"_a = none, "offset"_a = 0, py::kw_only(), "stream"_a = none, R"pbdoc( Create a view into the array with the given shape and strides. The resulting array will always be as if the provided array was row contiguous regardless of the provided arrays storage order and current strides. .. note:: Note that this function should be used with caution as it changes the shape and strides of the array directly. This can lead to the resulting array pointing to invalid memory locations which can result into crashes. Args: a (array): Input array shape (list(int), optional): The shape of the resulting array. If None it defaults to ``a.shape()``. strides (list(int), optional): The strides of the resulting array. If None it defaults to the reverse exclusive cumulative product of ``a.shape()``. offset (int): Skip that many elements from the beginning of the input array. Returns: array: The output array which is the strided view of the input. )pbdoc"); m.def( "cumsum", [](const array& a, std::optional axis, bool reverse, bool inclusive, StreamOrDevice s) { if (axis) { return cumsum(a, *axis, reverse, inclusive, s); } else { return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "a"_a, py::pos_only(), "axis"_a = std::nullopt, py::kw_only(), "reverse"_a = false, "inclusive"_a = true, "stream"_a = none, R"pbdoc( Return the cumulative sum of the elements along the given axis. 
Args: a (array): Input array axis (int, optional): Optional axis to compute the cumulative sum over. If unspecified the cumulative sum of the flattened array is returned. reverse (bool): Perform the cumulative sum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. )pbdoc"); m.def( "cumprod", [](const array& a, std::optional axis, bool reverse, bool inclusive, StreamOrDevice s) { if (axis) { return cumprod(a, *axis, reverse, inclusive, s); } else { return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "a"_a, py::pos_only(), "axis"_a = std::nullopt, py::kw_only(), "reverse"_a = false, "inclusive"_a = true, "stream"_a = none, R"pbdoc( Return the cumulative product of the elements along the given axis. Args: a (array): Input array axis (int, optional): Optional axis to compute the cumulative product over. If unspecified the cumulative product of the flattened array is returned. reverse (bool): Perform the cumulative product in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. )pbdoc"); m.def( "cummax", [](const array& a, std::optional axis, bool reverse, bool inclusive, StreamOrDevice s) { if (axis) { return cummax(a, *axis, reverse, inclusive, s); } else { return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "a"_a, py::pos_only(), "axis"_a = std::nullopt, py::kw_only(), "reverse"_a = false, "inclusive"_a = true, "stream"_a = none, R"pbdoc( Return the cumulative maximum of the elements along the given axis. Args: a (array): Input array axis (int, optional): Optional axis to compute the cumulative maximum over. If unspecified the cumulative maximum of the flattened array is returned. reverse (bool): Perform the cumulative maximum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. 
)pbdoc"); m.def( "cummin", [](const array& a, std::optional axis, bool reverse, bool inclusive, StreamOrDevice s) { if (axis) { return cummin(a, *axis, reverse, inclusive, s); } else { return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "a"_a, py::pos_only(), "axis"_a = std::nullopt, py::kw_only(), "reverse"_a = false, "inclusive"_a = true, "stream"_a = none, R"pbdoc( Return the cumulative minimum of the elements along the given axis. Args: a (array): Input array axis (int, optional): Optional axis to compute the cumulative minimum over. If unspecified the cumulative minimum of the flattened array is returned. reverse (bool): Perform the cumulative minimum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. )pbdoc"); m.def( "convolve", [](const array& a, const array& v, const std::string& mode, StreamOrDevice s) { if (a.ndim() != 1 || v.ndim() != 1) { throw std::invalid_argument("[convolve] Inputs must be 1D."); } array in = a.size() < v.size() ? v : a; array wt = a.size() < v.size() ? a : v; wt = slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s); in = reshape(in, {1, -1, 1}, s); wt = reshape(wt, {1, -1, 1}, s); int padding = 0; if (mode == "full") { padding = wt.size() - 1; } else if (mode == "valid") { padding = 0; } else if (mode == "same") { // Odd sizes use symmetric padding if (wt.size() % 2) { padding = wt.size() / 2; } else { // Even sizes use asymmetric padding int pad_l = wt.size() / 2; int pad_r = std::max(0, pad_l - 1); in = pad(in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), s); } } else { throw std::invalid_argument("[convolve] Invalid mode."); } array out = conv1d( in, wt, /*stride = */ 1, /*padding = */ padding, /*dilation = */ 1, /*groups = */ 1, s); return reshape(out, {-1}, s); }, "a"_a, "v"_a, py::pos_only(), "mode"_a = "full", py::kw_only(), "stream"_a = none, R"pbdoc( The discrete convolution of 1D arrays. If ``v`` is longer than ``a``, then they are swapped. 
The conv filter is flipped following signal processing convention. Args: a (array): 1D Input array. v (array): 1D Input array. mode (str, optional): {'full', 'valid', 'same'} Returns: array: The convolved array. )pbdoc"); m.def( "conv1d", &conv1d, "input"_a, "weight"_a, py::pos_only(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, "groups"_a = 1, py::kw_only(), "stream"_a = none, R"pbdoc( 1D convolution over an input with several channels Note: Only the default ``groups=1`` is currently supported. Args: input (array): input array of shape (``N``, ``H``, ``C_in``) weight (array): weight array of shape (``C_out``, ``H``, ``C_in``) stride (int, optional): kernel stride. Default: ``1``. padding (int, optional): input padding. Default: ``0``. dilation (int, optional): kernel dilation. Default: ``1``. groups (int, optional): input feature groups. Default: ``1``. Returns: array: The convolved array. )pbdoc"); m.def( "conv2d", [](const array& input, const array& weight, const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, int groups, StreamOrDevice s) { std::pair stride_pair{1, 1}; std::pair padding_pair{0, 0}; std::pair dilation_pair{1, 1}; if (auto pv = std::get_if(&stride); pv) { stride_pair = std::pair{*pv, *pv}; } else { stride_pair = std::get>(stride); } if (auto pv = std::get_if(&padding); pv) { padding_pair = std::pair{*pv, *pv}; } else { padding_pair = std::get>(padding); } if (auto pv = std::get_if(&dilation); pv) { dilation_pair = std::pair{*pv, *pv}; } else { dilation_pair = std::get>(dilation); } return conv2d( input, weight, stride_pair, padding_pair, dilation_pair, groups, s); }, "input"_a, "weight"_a, py::pos_only(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, "groups"_a = 1, py::kw_only(), "stream"_a = none, R"pbdoc( 2D convolution over an input with several channels Note: Only the default ``groups=1`` is currently supported. 
Args: input (array): input array of shape ``(N, H, W, C_in)`` weight (array): weight array of shape ``(C_out, H, W, C_in)`` stride (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: ``1``. padding (int or tuple(int), optional): :obj:`tuple` of size 2 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: ``0``. dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` groups (int, optional): input feature groups. Default: ``1``. Returns: array: The convolved array. )pbdoc"); m.def( "save", &mlx_save_helper, "file"_a, "arr"_a, py::pos_only(), "retain_graph"_a = true, py::kw_only(), R"pbdoc( Save the array to a binary file in ``.npy`` format. Args: file (str): File to which the array is saved arr (array): Array to be saved. retain_graph(bool): Optional argument to retain graph during array evaluation before saving. Default: True )pbdoc"); m.def( "savez", [](py::object file, py::args args, const py::kwargs& kwargs) { mlx_savez_helper(file, args, kwargs, /*compressed=*/false); }, "file"_a, py::pos_only(), py::kw_only(), R"pbdoc( Save several arrays to a binary file in uncompressed ``.npz`` format. .. code-block:: python import mlx.core as mx x = mx.ones((10, 10)) mx.savez("my_path.npz", x=x) import mlx.nn as nn from mlx.utils import tree_flatten model = nn.TransformerEncoder(6, 128, 4) flat_params = tree_flatten(model.parameters()) mx.savez("model.npz", **dict(flat_params)) Args: file (file, str): Path to file to which the arrays are saved. args (arrays): Arrays to be saved. kwargs (arrays): Arrays to be saved. Each array will be saved with the associated keyword as the output file name. 
)pbdoc"); m.def( "savez_compressed", [](py::object file, py::args args, const py::kwargs& kwargs) { mlx_savez_helper(file, args, kwargs, /*compressed=*/true); }, "file"_a, py::pos_only(), py::kw_only(), R"pbdoc( Save several arrays to a binary file in compressed ``.npz`` format. Args: file (file, str): Path to file to which the arrays are saved. args (arrays): Arrays to be saved. kwargs (arrays): Arrays to be saved. Each array will be saved with the associated keyword as the output file name. )pbdoc"); m.def( "load", &mlx_load_helper, "file"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Load array(s) from a binary file in ``.npy`` or ``.npz`` format. Args: file (file, str): File in which the array is saved Returns: result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file )pbdoc"); m.def( "where", [](const ScalarOrArray& condition, const ScalarOrArray& x_, const ScalarOrArray& y_, StreamOrDevice s) { auto [x, y] = to_arrays(x_, y_); return where(to_array(condition), x, y, s); }, "condition"_a, "x"_a, "y"_a, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( Select from ``x`` or ``y`` according to ``condition``. The condition and input arrays must be the same shape or broadcastable with each another. Args: condition (array): The condition array. x (array): The input selected from where condition is ``True``. y (array): The input selected from where condition is ``False``. Returns: result (array): The output containing elements selected from ``x`` and ``y``. )pbdoc"); }