// Copyright © 2023 Apple Inc.

#include <cstdint>
#include <sstream>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "python/src/indexing.h"
#include "python/src/utils.h"

#include "mlx/ops.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

enum PyScalarT {
  pybool = 0,
  pyint = 1,
  pyfloat = 2,
  pycomplex = 3,
};

template <typename T>
py::list to_list(array& a, size_t index, int dim) {
  py::list pl;
  auto stride = a.strides()[dim];
  for (int i = 0; i < a.shape(dim); ++i) {
    if (dim == a.ndim() - 1) {
      pl.append((a.data<T>()[index]));
    } else {
      pl.append(to_list<T>(a, index, dim + 1));
    }
    index += stride;
  }
  return pl;
}

auto to_scalar(array& a) {
  bool retain_graph = a.is_tracer();
  switch (a.dtype()) {
    case bool_:
      return py::cast(a.item<bool>(retain_graph));
    case uint8:
      return py::cast(a.item<uint8_t>(retain_graph));
    case uint16:
      return py::cast(a.item<uint16_t>(retain_graph));
    case uint32:
      return py::cast(a.item<uint32_t>(retain_graph));
    case uint64:
      return py::cast(a.item<uint64_t>(retain_graph));
    case int8:
      return py::cast(a.item<int8_t>(retain_graph));
    case int16:
      return py::cast(a.item<int16_t>(retain_graph));
    case int32:
      return py::cast(a.item<int32_t>(retain_graph));
    case int64:
      return py::cast(a.item<int64_t>(retain_graph));
    case float16:
      return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
    case float32:
      return py::cast(a.item<float>(retain_graph));
    case bfloat16:
      return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
    case complex64:
      return py::cast(a.item<std::complex<float>>(retain_graph));
  }
}

py::object tolist(array& a) {
  if (a.ndim() == 0) {
    return to_scalar(a);
  }
  a.eval(a.is_tracer());
  py::object pl;
  switch (a.dtype()) {
    case bool_:
      return to_list<bool>(a, 0, 0);
    case uint8:
      return to_list<uint8_t>(a, 0, 0);
    case uint16:
      return to_list<uint16_t>(a, 0, 0);
    case uint32:
      return to_list<uint32_t>(a, 0, 0);
    case uint64:
      return to_list<uint64_t>(a, 0, 0);
    case int8:
      return to_list<int8_t>(a, 0, 0);
    case int16:
      return to_list<int16_t>(a, 0, 0);
    case int32:
      return to_list<int32_t>(a, 0, 0);
    case int64:
      return to_list<int64_t>(a, 0, 0);
    case float16:
      return to_list<float16_t>(a, 0, 0);
    case float32:
      return to_list<float>(a, 0, 0);
    case bfloat16:
      return to_list<bfloat16_t>(a, 0, 0);
    case complex64:
      return to_list<std::complex<float>>(a, 0, 0);
  }
}

template <typename T, typename U>
void fill_vector(T list, std::vector<U>& vals) {
  for (auto l : list) {
    if (py::isinstance<py::list>(l)) {
      fill_vector(l.template cast<py::list>(), vals);
    } else if (py::isinstance<py::tuple>(*list.begin())) {
      fill_vector(l.template cast<py::tuple>(), vals);
    } else {
      vals.push_back(l.template cast<U>());
    }
  }
}

template <typename T>
PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
  if (idx >= shape.size()) {
    throw std::invalid_argument("Initialization encountered extra dimension.");
  }
  auto s = shape[idx];
  if (py::len(list) != s) {
    throw std::invalid_argument(
        "Initialization encountered non-uniform length.");
  }

  if (s == 0) {
    return pyfloat;
  }

  PyScalarT type = pybool;
  for (auto l : list) {
    PyScalarT t;
    if (py::isinstance<py::list>(l)) {
      t = validate_shape(l.template cast<py::list>(), shape, idx + 1);
    } else if (py::isinstance<py::tuple>(*list.begin())) {
      t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1);
    } else if (py::isinstance<py::bool_>(l)) {
      t = pybool;
    } else if (py::isinstance<py::int_>(l)) {
      t = pyint;
    } else if (py::isinstance<py::float_>(l)) {
      t = pyfloat;
    } else if (PyComplex_Check(l.ptr())) {
      t = pycomplex;
    } else {
      std::ostringstream msg;
      msg << "Invalid type in array initialization: " << l.get_type() << ".";
      throw std::invalid_argument(msg.str());
    }
    type = std::max(type, t);
  }
  return type;
}

template <typename T>
void get_shape(T list, std::vector<int>& shape) {
  shape.push_back(py::len(list));
  if (shape.back() > 0) {
    auto& l = *list.begin();
    if (py::isinstance<py::list>(l)) {
      return get_shape(l.template cast<py::list>(), shape);
    } else if (py::isinstance<py::tuple>(l)) {
      return get_shape(l.template cast<py::tuple>(), shape);
    }
  }
}
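// For illustration (hypothetical example, not exercised anywhere in this
// file): validate_shape returns the widest scalar category found in the
// nested list, promoted via std::max over the PyScalarT ordering
// bool < int < float < complex. A list such as [[1, 2.0], [3, 4]] is
// therefore classified as pyfloat and, absent an explicit dtype,
// array_from_list below produces a float32 array of shape {2, 2}.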
template <typename T>
array array_from_list(T pl, std::optional<Dtype> dtype) {
  // Compute the shape
  std::vector<int> shape;
  get_shape(pl, shape);

  // Validate the shape and type
  auto type = validate_shape(pl, shape, 0);

  size_t size = 1;
  for (auto s : shape) {
    size *= s;
  }

  // Make the array
  switch (type) {
    case pybool: {
      std::vector<bool> vals;
      fill_vector(pl, vals);
      return array(vals.begin(), shape, dtype.value_or(bool_));
    }
    case pyint: {
      std::vector<int> vals;
      fill_vector(pl, vals);
      return array(vals.begin(), shape, dtype.value_or(int32));
    }
    case pyfloat: {
      std::vector<float> vals;
      fill_vector(pl, vals);
      return array(vals.begin(), shape, dtype.value_or(float32));
    }
    case pycomplex: {
      std::vector<std::complex<float>> vals;
      fill_vector(pl, vals);
      return array(
          reinterpret_cast<complex64_t*>(vals.data()),
          shape,
          dtype.value_or(complex64));
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// MLX -> Numpy
///////////////////////////////////////////////////////////////////////////////

size_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    auto q_and_r = ldiv(elem, shape[i]);
    loc += q_and_r.rem * strides[i];
    elem = q_and_r.quot;
  }
  return loc;
}

struct PyArrayPayload {
  array a;
};

template <typename T>
py::array_t<T> mlx_array_to_np_t(const array& src) {
  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
  // the data
  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
    delete reinterpret_cast<PyArrayPayload*>(payload);
  });

  // Collect strides (in bytes)
  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
  for (int i = 0; i < src.ndim(); i++) {
    strides[i] *= src.itemsize();
  }

  // Pack the capsule with the array
  py::array_t<T> out(src.shape(), strides, src.data<T>(), freeWhenDone);

  // Mark array as read-only
  py::detail::array_proxy(out.ptr())->flags &=
      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;

  // Return array
  return py::array_t<T>(src.shape(), strides, src.data<T>(), out);
}

template <typename T>
py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) {
  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
  // the data
  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
    delete reinterpret_cast<PyArrayPayload*>(payload);
  });

  // Collect strides (in bytes)
  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
  for (int i = 0; i < src.ndim(); i++) {
    strides[i] *= src.itemsize();
  }

  // Pack the capsule with the array
  py::array out(dt, src.shape(), strides, src.data<T>(), freeWhenDone);

  // Mark array as read-only
  py::detail::array_proxy(out.ptr())->flags &=
      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;

  // Return array
  return py::array(dt, src.shape(), strides, src.data<T>(), out);
}

py::array mlx_array_to_np(const array& src) {
  // Eval if not already evaled
  if (!src.is_evaled()) {
    eval({src}, src.is_tracer());
  }
  switch (src.dtype()) {
    case bool_:
      return mlx_array_to_np_t<bool>(src);
    case uint8:
      return mlx_array_to_np_t<uint8_t>(src);
    case uint16:
      return mlx_array_to_np_t<uint16_t>(src);
    case uint32:
      return mlx_array_to_np_t<uint32_t>(src);
    case uint64:
      return mlx_array_to_np_t<uint64_t>(src);
    case int8:
      return mlx_array_to_np_t<int8_t>(src);
    case int16:
      return mlx_array_to_np_t<int16_t>(src);
    case int32:
      return mlx_array_to_np_t<int32_t>(src);
    case int64:
      return mlx_array_to_np_t<int64_t>(src);
    case float16:
      return mlx_array_to_np_t<float16_t>(src, py::dtype("float16"));
    case float32:
      return mlx_array_to_np_t<float>(src);
    case bfloat16: {
      // NumPy has no native bfloat16, so upcast to float32 first
      auto a = astype(src, float32);
      eval({a}, src.is_tracer());
      return mlx_array_to_np_t<float>(a);
    }
    case complex64:
      return mlx_array_to_np_t<complex64_t>(src, py::dtype("complex64"));
  }
}
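// For illustration (hypothetical usage, assuming `mx` and `np` imports on the
// Python side): `np.asarray(mx.array([1.0, 2.0]))` is expected to return a
// read-only NumPy view over the MLX buffer rather than a copy. The capsule
// above keeps a copy of the mlx array alive for as long as the NumPy array
// exists, and clearing NPY_ARRAY_WRITEABLE_ keeps NumPy code from mutating
// MLX-owned memory.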
///////////////////////////////////////////////////////////////////////////////
// Numpy -> MLX
///////////////////////////////////////////////////////////////////////////////

template <typename T>
array np_array_to_mlx_contiguous(
    py::array_t<T, py::array::c_style | py::array::forcecast> np_array,
    const std::vector<int>& shape,
    Dtype dtype) {
  // Make a copy of the numpy buffer
  // Get buffer ptr pass to array constructor
  py::buffer_info buf = np_array.request();
  const T* data_ptr = static_cast<const T*>(buf.ptr);
  return array(data_ptr, shape, dtype);

  // Note: Leaving the following memoryless copy from np to mx commented
  // out for the time being since it is unsafe given that the incoming
  // numpy array may change the underlying data
  //
  // // Share underlying numpy buffer
  // // Copy to increase ref count
  // auto deleter = [np_array](void*) {};
  // void* data_ptr = np_array.mutable_data();
  // // Use buffer from numpy
  // return array(data_ptr, deleter, shape, dtype);
}

template <>
array np_array_to_mlx_contiguous(
    py::array_t<std::complex<float>, py::array::c_style | py::array::forcecast>
        np_array,
    const std::vector<int>& shape,
    Dtype dtype) {
  // Get buffer ptr pass to array constructor
  py::buffer_info buf = np_array.request();
  auto data_ptr = static_cast<std::complex<float>*>(buf.ptr);
  return array(reinterpret_cast<complex64_t*>(data_ptr), shape, dtype);
}

array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
  // Compute the shape and size
  std::vector<int> shape;
  for (int i = 0; i < np_array.ndim(); i++) {
    shape.push_back(np_array.shape(i));
  }

  // Get dtype
  auto type = np_array.dtype();

  // Copy data and make array
  if (type.is(py::dtype::of<int>())) {
    return np_array_to_mlx_contiguous<int>(
        np_array, shape, dtype.value_or(int32));
  } else if (type.is(py::dtype::of<uint32_t>())) {
    return np_array_to_mlx_contiguous<uint32_t>(
        np_array, shape, dtype.value_or(uint32));
  } else if (type.is(py::dtype::of<bool>())) {
    return np_array_to_mlx_contiguous<bool>(
        np_array, shape, dtype.value_or(bool_));
  } else if (type.is(py::dtype::of<double>())) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float32));
  } else if (type.is(py::dtype::of<float>())) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float32));
  } else if (type.is(py::dtype("float16"))) {
    return np_array_to_mlx_contiguous<float16_t>(
        np_array, shape, dtype.value_or(float16));
  } else if (type.is(py::dtype::of<uint8_t>())) {
    return np_array_to_mlx_contiguous<uint8_t>(
        np_array, shape, dtype.value_or(uint8));
  } else if (type.is(py::dtype::of<uint16_t>())) {
    return np_array_to_mlx_contiguous<uint16_t>(
        np_array, shape, dtype.value_or(uint16));
  } else if (type.is(py::dtype::of<uint64_t>())) {
    return np_array_to_mlx_contiguous<uint64_t>(
        np_array, shape, dtype.value_or(uint64));
  } else if (type.is(py::dtype::of<int8_t>())) {
    return np_array_to_mlx_contiguous<int8_t>(
        np_array, shape, dtype.value_or(int8));
  } else if (type.is(py::dtype::of<int16_t>())) {
    return np_array_to_mlx_contiguous<int16_t>(
        np_array, shape, dtype.value_or(int16));
  } else if (type.is(py::dtype::of<int64_t>())) {
    return np_array_to_mlx_contiguous<int64_t>(
        np_array, shape, dtype.value_or(int64));
  } else if (type.is(py::dtype::of<std::complex<float>>())) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
  } else if (type.is(py::dtype::of<std::complex<double>>())) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
  } else {
    std::ostringstream msg;
    msg << "Cannot convert numpy array of type " << type << " to mlx array.";
    throw std::invalid_argument(msg.str());
  }
}
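// For illustration (hypothetical usage): conversion in this direction always
// copies the NumPy buffer, and a float64 input with no explicit dtype lands
// on float32, since no float64 dtype is bound in this module. The
// py::array::forcecast flag additionally lets non-contiguous or differently
// typed NumPy inputs be accepted by materializing a C-contiguous copy first.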
///////////////////////////////////////////////////////////////////////////////
// Module
///////////////////////////////////////////////////////////////////////////////

void init_array(py::module_& m) {
  // Types
  py::class_<Dtype>(
      m,
      "Dtype",
      R"pbdoc(
      An object to hold the type of a :class:`array`.

      See the :ref:`list of types <data_types>` for more details
      on available data types.
      )pbdoc")
      .def_readonly(
          "size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
      .def(
          "__repr__",
          [](const Dtype& t) {
            std::ostringstream os;
            os << t;
            return os.str();
          })
      .def(
          "__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; });
  m.attr("bool_") = py::cast(bool_);
  m.attr("uint8") = py::cast(uint8);
  m.attr("uint16") = py::cast(uint16);
  m.attr("uint32") = py::cast(uint32);
  m.attr("uint64") = py::cast(uint64);
  m.attr("int8") = py::cast(int8);
  m.attr("int16") = py::cast(int16);
  m.attr("int32") = py::cast(int32);
  m.attr("int64") = py::cast(int64);
  m.attr("float16") = py::cast(float16);
  m.attr("float32") = py::cast(float32);
  m.attr("bfloat16") = py::cast(bfloat16);
  m.attr("complex64") = py::cast(complex64);

  py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc")
      .def(
          py::init([](ScalarOrArray v, std::optional<Dtype> t) {
            auto arr = to_array(v, t);
            return astype(arr, t.value_or(arr.dtype()));
          }),
          "val"_a,
          "dtype"_a = std::nullopt)
      .def(
          py::init([](std::variant<py::list, py::tuple> pl,
                      std::optional<Dtype> dtype) {
            if (auto pv = std::get_if<py::list>(&pl); pv) {
              return array_from_list(*pv, dtype);
            } else {
              auto v = std::get<py::tuple>(pl);
              return array_from_list(v, dtype);
            }
          }),
          "vals"_a,
          "dtype"_a = std::nullopt)
      .def(
          py::init([](py::array np_array, std::optional<Dtype> dtype) {
            return np_array_to_mlx(np_array, dtype);
          }),
          "vals"_a,
          "dtype"_a = std::nullopt)
      .def(
          py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
            return np_array_to_mlx(np_buffer, dtype);
          }),
          "vals"_a,
          "dtype"_a = std::nullopt)
      .def_property_readonly(
          "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
      .def_property_readonly(
          "ndim",
          &array::ndim,
          R"pbdoc(The array's number of dimensions.)pbdoc")
      // TODO, this makes a deep copy of the shape
      // implement alternatives to use reference
      // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
      .def_property_readonly(
          "shape",
          [](const array& a) { return a.shape(); },
          R"pbdoc(
          The shape of the array as a Python list.

          Returns:
            list(int): A list containing the sizes of each dimension.
          )pbdoc")
      .def_property_readonly(
          "dtype",
          &array::dtype,
          R"pbdoc(
            The array's :class:`Dtype`.
          )pbdoc")
      .def(
          "item",
          &to_scalar,
          R"pbdoc(
            Access the value of a scalar array.

            Returns:
                Standard Python scalar.
          )pbdoc")
      .def(
          "tolist",
          &tolist,
          R"pbdoc(
            Convert the array to a Python :class:`list`.

            Returns:
                list: The Python list.

                If the array is a scalar then a standard Python scalar is
                returned. If the array has more than one dimension then the
                result is a nested list of lists. The value type of the list
                corresponding to the last dimension is either ``bool``,
                ``int`` or ``float`` depending on the ``dtype`` of the array.
          )pbdoc")
      .def("__array__", &mlx_array_to_np)
      .def(
          "astype",
          &astype,
          "dtype"_a,
          "stream"_a = none,
          R"pbdoc(
            Cast the array to a specified type.

            Args:
                dtype (Dtype): Type to which the array is cast.
                stream (Stream): Stream (or device) for the operation.

            Returns:
                array: The array with type ``dtype``.
)pbdoc") .def("__getitem__", mlx_get_item) .def("__setitem__", mlx_set_item) .def( "__len__", [](const array& a) { if (a.ndim() == 0) { throw py::type_error("len() 0-dimensional array."); } return a.shape(0); }) .def( "__iter__", [](const array& a) { return py::make_iterator(a); }, py::keep_alive<0, 1>()) .def( "__add__", [](const array& a, const ScalarOrArray v) { return add(a, to_array(v, a.dtype())); }, "other"_a) .def( "__radd__", [](const array& a, const ScalarOrArray v) { return add(a, to_array(v, a.dtype())); }, "other"_a) .def( "__sub__", [](const array& a, const ScalarOrArray v) { return subtract(a, to_array(v, a.dtype())); }, "other"_a) .def( "__rsub__", [](const array& a, const ScalarOrArray v) { return subtract(to_array(v, a.dtype()), a); }, "other"_a) .def( "__mul__", [](const array& a, const ScalarOrArray v) { return multiply(a, to_array(v, a.dtype())); }, "other"_a) .def( "__rmul__", [](const array& a, const ScalarOrArray v) { return multiply(a, to_array(v, a.dtype())); }, "other"_a) .def( "__truediv__", [](const array& a, const ScalarOrArray v) { return divide(a, to_array(v, float32)); }, "other"_a) .def( "__div__", [](const array& a, const ScalarOrArray v) { return divide(a, to_array(v, float32)); }, "other"_a) .def( "__rtruediv__", [](const array& a, const ScalarOrArray v) { return divide(to_array(v, float32), a); }, "other"_a) .def( "__rdiv__", [](const array& a, const ScalarOrArray v) { return divide(to_array(v, float32), a); }, "other"_a) .def( "__mod__", [](const array& a, const ScalarOrArray v) { return remainder(a, to_array(v, a.dtype())); }, "other"_a) .def( "__rmod__", [](const array& a, const ScalarOrArray v) { return remainder(to_array(v, a.dtype()), a); }, "other"_a) .def( "__eq__", [](const array& a, const ScalarOrArray v) { return equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__lt__", [](const array& a, const ScalarOrArray v) { return less(a, to_array(v, a.dtype())); }, "other"_a) .def( "__le__", [](const array& a, const ScalarOrArray v) { return less_equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__gt__", [](const array& a, const ScalarOrArray v) { return greater(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ge__", [](const array& a, const ScalarOrArray v) { return greater_equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ne__", [](const array& a, const ScalarOrArray v) { return not_equal(a, to_array(v, a.dtype())); }, "other"_a) .def("__neg__", [](const array& a) { return -a; }) .def("__bool__", [](array& a) { return py::bool_(to_scalar(a)); }) .def( "__repr__", [](array& a) { if (!a.is_evaled()) { a.eval(a.is_tracer()); } std::ostringstream os; os << a; return os.str(); }) .def( "__matmul__", [](array& a, array& other) { return matmul(a, other); }) .def( "__pow__", [](const array& a, const ScalarOrArray v) { return power(a, to_array(v, a.dtype())); }, "other"_a) .def( "reshape", [](const array& a, py::args shape, StreamOrDevice s) { if (shape.size() == 1) { py::object arg = shape[0]; if (!py::isinstance(arg)) { return reshape(a, py::cast>(arg), s); } } return reshape(a, py::cast>(shape), s); }, py::kw_only(), "stream"_a = none, R"pbdoc( Equivalent to :func:`reshape` but the shape can be passed either as a tuple or as separate arguments. See :func:`reshape` for full documentation. 
)pbdoc") .def( "squeeze", [](const array& a, const IntOrVec& v, const StreamOrDevice& s) { if (std::holds_alternative(v)) { return squeeze(a, s); } else if (auto pv = std::get_if(&v); pv) { return squeeze(a, *pv, s); } else { return squeeze(a, std::get>(v), s); } }, "axis"_a = none, py::kw_only(), "stream"_a = none, R"pbdoc( See :func:`squeeze`. )pbdoc") .def( "abs", &mlx::core::abs, py::kw_only(), "stream"_a = none, "See :func:`abs`.") .def( "square", &square, py::kw_only(), "stream"_a = none, "See :func:`square`.") .def( "sqrt", &mlx::core::sqrt, py::kw_only(), "stream"_a = none, "See :func:`sqrt`.") .def( "rsqrt", &rsqrt, py::kw_only(), "stream"_a = none, "See :func:`rsqrt`.") .def( "reciprocal", &reciprocal, py::kw_only(), "stream"_a = none, "See :func:`reciprocal`.") .def( "exp", &mlx::core::exp, py::kw_only(), "stream"_a = none, "See :func:`exp`.") .def( "log", &mlx::core::log, py::kw_only(), "stream"_a = none, "See :func:`log`.") .def( "log2", &mlx::core::log2, py::kw_only(), "stream"_a = none, "See :func:`log2`.") .def( "log10", &mlx::core::log10, py::kw_only(), "stream"_a = none, "See :func:`log10`.") .def( "sin", &mlx::core::sin, py::kw_only(), "stream"_a = none, "See :func:`sin`.") .def( "cos", &mlx::core::cos, py::kw_only(), "stream"_a = none, "See :func:`cos`.") .def( "log1p", &mlx::core::log1p, py::kw_only(), "stream"_a = none, "See :func:`log1p`.") .def( "all", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`all`.") .def( "any", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`any`.") .def( "transpose", [](const array& a, py::args axes, StreamOrDevice s) { if (axes.size() > 0) { if (axes.size() == 1) { py::object arg = axes[0]; if (!py::isinstance(arg)) { return transpose(a, py::cast>(arg), s); } } return transpose(a, py::cast>(axes), s); } else { return transpose(a, s); } }, py::kw_only(), "stream"_a = none, R"pbdoc( Equivalent to :func:`transpose` but the axes can be passed either as a tuple or as separate arguments. See :func:`transpose` for full documentation. 
)pbdoc") .def_property_readonly( "T", [](const array& a) { return transpose(a); }, "Equivalent to calling ``self.transpose()`` with no arguments.") .def( "sum", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`sum`.") .def( "prod", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`prod`.") .def( "min", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`min`.") .def( "max", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`max`.") .def( "logsumexp", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`logsumexp`.") .def( "mean", [](const array& a, const IntOrVec& axis, bool keepdims, StreamOrDevice s) { return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`mean`.") .def( "var", [](const array& a, const IntOrVec& axis, bool keepdims, int ddof, StreamOrDevice s) { return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, "axis"_a = none, "keepdims"_a = false, "ddof"_a = 0, py::kw_only(), "stream"_a = none, "See :func:`var`.") .def( "split", [](const array& a, const std::variant>& indices_or_sections, int axis, StreamOrDevice s) { if (auto pv = std::get_if(&indices_or_sections); pv) { return split(a, *pv, axis, s); } else { return split( a, std::get>(indices_or_sections), axis, s); } }, "indices_or_sections"_a, "axis"_a = 0, py::kw_only(), "stream"_a = none, "See :func:`split`.") .def( "argmin", [](const array& a, std::optional axis, bool keepdims, StreamOrDevice s) { if (axis) { return argmin(a, *axis, keepdims, s); } else { return argmin(a, keepdims, s); } }, "axis"_a = std::nullopt, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`argmin`.") .def( "argmax", [](const array& a, std::optional axis, bool keepdims, StreamOrDevice s) { if (axis) { return argmax(a, *axis, keepdims, s); } else { return argmax(a, keepdims, s); } }, "axis"_a = none, "keepdims"_a = false, py::kw_only(), "stream"_a = none, "See :func:`argmax`.") .def( "cumsum", [](const array& a, std::optional axis, bool reverse, bool inclusive, StreamOrDevice s) { if (axis) { return cumsum(a, *axis, reverse, inclusive, s); } else { // TODO: Implement that in the C++ API as well. See concatenate // above. return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "axis"_a = none, py::kw_only(), "reverse"_a = false, "inclusive"_a = true, "stream"_a = none, "See :func:`cumsum`.") .def( "cumprod", [](const array& a, std::optional axis, bool reverse, bool inclusive, StreamOrDevice s) { if (axis) { return cumprod(a, *axis, reverse, inclusive, s); } else { // TODO: Implement that in the C++ API as well. See concatenate // above. 
      .def(
          "cumsum",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cumsum(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cumsum`.")
      .def(
          "cumprod",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cumprod(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cumprod`.")
      .def(
          "cummax",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cummax(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cummax`.")
      .def(
          "cummin",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cummin(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cummin`.");
}