2023-12-01 03:12:53 +08:00
|
|
|
// Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
#include <cstdint>
|
|
|
|
#include <cstring>
|
|
|
|
#include <sstream>
|
|
|
|
|
|
|
|
#include <pybind11/numpy.h>
|
|
|
|
|
|
|
|
#include "python/src/indexing.h"
|
|
|
|
#include "python/src/utils.h"
|
|
|
|
|
|
|
|
#include "mlx/ops.h"
|
|
|
|
#include "mlx/transforms.h"
|
|
|
|
#include "mlx/utils.h"
|
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
using namespace py::literals;
|
|
|
|
|
|
|
|
// Kind of Python scalar found while scanning a list/tuple initializer.
// The enumerators are ordered from narrowest to widest: validate_shape()
// merges the kinds of sibling elements with std::max, so the numeric order
// must stay bool < int < float < complex.
enum PyScalarT {
  pybool = 0,
  pyint,
  pyfloat,
  pycomplex,
};
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
py::list to_list(array& a, size_t index, int dim) {
|
|
|
|
py::list pl;
|
|
|
|
auto stride = a.strides()[dim];
|
|
|
|
for (int i = 0; i < a.shape(dim); ++i) {
|
|
|
|
if (dim == a.ndim() - 1) {
|
|
|
|
pl.append((a.data<T>()[index]));
|
|
|
|
} else {
|
|
|
|
pl.append(to_list<T>(a, index, dim + 1));
|
|
|
|
}
|
|
|
|
index += stride;
|
|
|
|
}
|
|
|
|
return pl;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Converts `a` to a Python scalar object by reading a single item of the
// C++ type matching the array's dtype.
//
// The graph is retained while evaluating when `a` is a tracer. Half
// precision values (float16 and bfloat16) are widened to float before being
// handed to Python.
//
// Throws std::invalid_argument if the dtype is not one of the handled cases.
auto to_scalar(array& a) {
  bool retain_graph = a.is_tracer();
  switch (a.dtype()) {
    case bool_:
      return py::cast(a.item<bool>(retain_graph));
    case uint8:
      return py::cast(a.item<uint8_t>(retain_graph));
    case uint16:
      return py::cast(a.item<uint16_t>(retain_graph));
    case uint32:
      return py::cast(a.item<uint32_t>(retain_graph));
    case uint64:
      return py::cast(a.item<uint64_t>(retain_graph));
    case int8:
      return py::cast(a.item<int8_t>(retain_graph));
    case int16:
      return py::cast(a.item<int16_t>(retain_graph));
    case int32:
      return py::cast(a.item<int32_t>(retain_graph));
    case int64:
      return py::cast(a.item<int64_t>(retain_graph));
    case float16:
      return py::cast(static_cast<float>(a.item<float16_t>(retain_graph)));
    case float32:
      return py::cast(a.item<float>(retain_graph));
    case bfloat16:
      // Must be read as bfloat16_t: reading the storage as float16_t would
      // reinterpret the bits and yield a wrong value.
      return py::cast(static_cast<float>(a.item<bfloat16_t>(retain_graph)));
    case complex64:
      return py::cast(a.item<std::complex<float>>(retain_graph));
  }
  // Every known dtype returns above; never fall off the end of the function.
  throw std::invalid_argument("Cannot convert array of unknown dtype to a scalar.");
}
|
|
|
|
|
|
|
|
// Converts `a` to a (possibly nested) Python list; a 0-dimensional array is
// returned as a Python scalar via to_scalar().
//
// The array is evaluated first (retaining the graph when it is a tracer) so
// that its data buffer can be read by to_list().
//
// Throws std::invalid_argument if the dtype is not one of the handled cases.
py::object tolist(array& a) {
  if (a.ndim() == 0) {
    return to_scalar(a);
  }
  a.eval(a.is_tracer());
  switch (a.dtype()) {
    case bool_:
      return to_list<bool>(a, 0, 0);
    case uint8:
      return to_list<uint8_t>(a, 0, 0);
    case uint16:
      return to_list<uint16_t>(a, 0, 0);
    case uint32:
      return to_list<uint32_t>(a, 0, 0);
    case uint64:
      return to_list<uint64_t>(a, 0, 0);
    case int8:
      return to_list<int8_t>(a, 0, 0);
    case int16:
      return to_list<int16_t>(a, 0, 0);
    case int32:
      return to_list<int32_t>(a, 0, 0);
    case int64:
      return to_list<int64_t>(a, 0, 0);
    case float16:
      return to_list<float16_t>(a, 0, 0);
    case float32:
      return to_list<float>(a, 0, 0);
    case bfloat16: {
      // Walking a bfloat16 buffer as float16_t reinterprets the bits and
      // produces wrong values. Widen to float32 first, the same way
      // mlx_array_to_np() handles bfloat16.
      auto widened = astype(a, float32);
      widened.eval(a.is_tracer());
      return to_list<float>(widened, 0, 0);
    }
    case complex64:
      return to_list<std::complex<float>>(a, 0, 0);
  }
  // Every known dtype returns above; never fall off the end of the function.
  throw std::invalid_argument("Cannot convert array of unknown dtype to a list.");
}
|
|
|
|
|
|
|
|
template <typename T, typename U>
|
|
|
|
void fill_vector(T list, std::vector<U>& vals) {
|
|
|
|
for (auto l : list) {
|
|
|
|
if (py::isinstance<py::list>(l)) {
|
|
|
|
fill_vector(l.template cast<py::list>(), vals);
|
|
|
|
} else if (py::isinstance<py::tuple>(*list.begin())) {
|
|
|
|
fill_vector(l.template cast<py::tuple>(), vals);
|
|
|
|
} else {
|
|
|
|
vals.push_back(l.template cast<U>());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
|
|
|
|
if (idx >= shape.size()) {
|
|
|
|
throw std::invalid_argument("Initialization encountered extra dimension.");
|
|
|
|
}
|
|
|
|
auto s = shape[idx];
|
|
|
|
if (py::len(list) != s) {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"Initialization encountered non-uniform length.");
|
|
|
|
}
|
|
|
|
|
|
|
|
if (s == 0) {
|
|
|
|
return pyfloat;
|
|
|
|
}
|
|
|
|
|
|
|
|
PyScalarT type = pybool;
|
|
|
|
for (auto l : list) {
|
|
|
|
PyScalarT t;
|
|
|
|
if (py::isinstance<py::list>(l)) {
|
|
|
|
t = validate_shape(l.template cast<py::list>(), shape, idx + 1);
|
|
|
|
} else if (py::isinstance<py::tuple>(*list.begin())) {
|
|
|
|
t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1);
|
|
|
|
} else if (py::isinstance<py::bool_>(l)) {
|
|
|
|
t = pybool;
|
|
|
|
} else if (py::isinstance<py::int_>(l)) {
|
|
|
|
t = pyint;
|
|
|
|
} else if (py::isinstance<py::float_>(l)) {
|
|
|
|
t = pyfloat;
|
|
|
|
} else if (PyComplex_Check(l.ptr())) {
|
|
|
|
t = pycomplex;
|
|
|
|
} else {
|
|
|
|
std::ostringstream msg;
|
|
|
|
msg << "Invalid type in array initialization" << l.get_type() << ".";
|
|
|
|
throw std::invalid_argument(msg.str());
|
|
|
|
}
|
|
|
|
type = std::max(type, t);
|
|
|
|
}
|
|
|
|
return type;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
void get_shape(T list, std::vector<int>& shape) {
|
|
|
|
shape.push_back(py::len(list));
|
|
|
|
if (shape.back() > 0) {
|
|
|
|
auto& l = *list.begin();
|
|
|
|
if (py::isinstance<py::list>(l)) {
|
|
|
|
return get_shape(l.template cast<py::list>(), shape);
|
|
|
|
} else if (py::isinstance<py::tuple>(l)) {
|
|
|
|
return get_shape(l.template cast<py::tuple>(), shape);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
array array_from_list(T pl, std::optional<Dtype> dtype) {
|
|
|
|
// Compute the shape
|
|
|
|
std::vector<int> shape;
|
|
|
|
get_shape(pl, shape);
|
|
|
|
|
|
|
|
// Validate the shape and type
|
|
|
|
auto type = validate_shape(pl, shape, 0);
|
|
|
|
|
|
|
|
size_t size = 1;
|
|
|
|
for (auto s : shape) {
|
|
|
|
size *= s;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Make the array
|
|
|
|
switch (type) {
|
|
|
|
case pybool: {
|
|
|
|
std::vector<bool> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype.value_or(bool_));
|
|
|
|
}
|
|
|
|
case pyint: {
|
|
|
|
std::vector<int> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype.value_or(int32));
|
|
|
|
}
|
|
|
|
case pyfloat: {
|
|
|
|
std::vector<float> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype.value_or(float32));
|
|
|
|
}
|
|
|
|
case pycomplex: {
|
|
|
|
std::vector<std::complex<float>> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(
|
|
|
|
reinterpret_cast<complex64_t*>(vals.data()),
|
|
|
|
shape,
|
|
|
|
dtype.value_or(complex64));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// MLX -> Numpy
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
// Maps the linear element index `elem` to a storage offset.
//
// The index is decomposed from the last dimension to the first: at each
// dimension the remainder of dividing by that dimension's size selects the
// coordinate, which is multiplied by the dimension's stride, and the quotient
// carries on to the next dimension.
size_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t offset = 0;
  for (size_t axis = shape.size(); axis-- > 0;) {
    const auto division = ldiv(elem, shape[axis]);
    offset += division.rem * strides[axis];
    elem = division.quot;
  }
  return offset;
}
|
|
|
|
|
|
|
|
struct PyArrayPayload {
|
|
|
|
array a;
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
py::array_t<T> mlx_array_to_np_t(const array& src) {
|
|
|
|
// Let py::capsule hold onto a copy of the array which holds a shared ptr to
|
|
|
|
// the data
|
|
|
|
const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
|
|
|
|
delete reinterpret_cast<PyArrayPayload*>(payload);
|
|
|
|
});
|
|
|
|
// Collect strides
|
|
|
|
std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
|
|
|
|
for (int i = 0; i < src.ndim(); i++) {
|
|
|
|
strides[i] *= src.itemsize();
|
|
|
|
}
|
|
|
|
// Pack the capsule with the array
|
|
|
|
py::array_t<T> out(src.shape(), strides, src.data<T>(), freeWhenDone);
|
|
|
|
// Mark array as read-only
|
|
|
|
py::detail::array_proxy(out.ptr())->flags &=
|
|
|
|
~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
|
|
|
// Return array
|
|
|
|
return py::array_t(src.shape(), strides, src.data<T>(), out);
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) {
|
|
|
|
// Let py::capsule hold onto a copy of the array which holds a shared ptr to
|
|
|
|
// the data
|
|
|
|
const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
|
|
|
|
delete reinterpret_cast<PyArrayPayload*>(payload);
|
|
|
|
});
|
|
|
|
// Collect strides
|
|
|
|
std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
|
|
|
|
for (int i = 0; i < src.ndim(); i++) {
|
|
|
|
strides[i] *= src.itemsize();
|
|
|
|
}
|
|
|
|
// Pack the capsule with the array
|
|
|
|
py::array out(dt, src.shape(), strides, src.data<T>(), freeWhenDone);
|
|
|
|
// Mark array as read-only
|
|
|
|
py::detail::array_proxy(out.ptr())->flags &=
|
|
|
|
~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
|
|
|
// Return array
|
|
|
|
return py::array(dt, src.shape(), strides, src.data<T>(), out);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Converts `src` to a numpy array, dispatching on its dtype.
//
// The array is evaluated first if needed (retaining the graph when it is a
// tracer). float16 and complex64 go through the overload that takes an
// explicit numpy dtype; bfloat16 is cast to float32 and exported as float.
py::array mlx_array_to_np(const array& src) {
  // The data buffer is only readable once the array has been evaluated.
  if (!src.is_evaled()) {
    eval({src}, src.is_tracer());
  }

  switch (src.dtype()) {
    // Boolean and integer types.
    case bool_:
      return mlx_array_to_np_t<bool>(src);
    case uint8:
      return mlx_array_to_np_t<uint8_t>(src);
    case uint16:
      return mlx_array_to_np_t<uint16_t>(src);
    case uint32:
      return mlx_array_to_np_t<uint32_t>(src);
    case uint64:
      return mlx_array_to_np_t<uint64_t>(src);
    case int8:
      return mlx_array_to_np_t<int8_t>(src);
    case int16:
      return mlx_array_to_np_t<int16_t>(src);
    case int32:
      return mlx_array_to_np_t<int32_t>(src);
    case int64:
      return mlx_array_to_np_t<int64_t>(src);

    // Floating point types.
    case float16:
      return mlx_array_to_np_t<float16_t>(src, py::dtype("float16"));
    case float32:
      return mlx_array_to_np_t<float>(src);
    case bfloat16: {
      // Exported through a float32 copy of the array.
      auto as_float = astype(src, float32);
      eval({as_float}, src.is_tracer());
      return mlx_array_to_np_t<float>(as_float);
    }

    // Complex types.
    case complex64:
      return mlx_array_to_np_t<complex64_t>(src, py::dtype("complex64"));
  }
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Numpy -> MLX
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
array np_array_to_mlx_contiguous(
|
|
|
|
py::array_t<T, py::array::c_style | py::array::forcecast> np_array,
|
|
|
|
const std::vector<int>& shape,
|
|
|
|
Dtype dtype) {
|
|
|
|
// Make a copy of the numpy buffer
|
|
|
|
// Get buffer ptr pass to array constructor
|
|
|
|
py::buffer_info buf = np_array.request();
|
|
|
|
const T* data_ptr = static_cast<T*>(buf.ptr);
|
|
|
|
return array(data_ptr, shape, dtype);
|
|
|
|
|
|
|
|
// Note: Leaving the following memoryless copy from np to mx commented
|
|
|
|
// out for the time being since it is unsafe given that the incoming
|
|
|
|
// numpy array may change the underlying data
|
|
|
|
|
|
|
|
// // Share underlying numpy buffer
|
|
|
|
// // Copy to increase ref count
|
|
|
|
// auto deleter = [np_array](void*) {};
|
|
|
|
// void* data_ptr = np_array.mutable_data();
|
|
|
|
// // Use buffer from numpy
|
|
|
|
// return array(data_ptr, deleter, shape, dtype);
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
array np_array_to_mlx_contiguous(
|
|
|
|
py::array_t<std::complex<float>, py::array::c_style | py::array::forcecast>
|
|
|
|
np_array,
|
|
|
|
const std::vector<int>& shape,
|
|
|
|
Dtype dtype) {
|
|
|
|
// Get buffer ptr pass to array constructor
|
|
|
|
py::buffer_info buf = np_array.request();
|
|
|
|
auto data_ptr = static_cast<std::complex<float>*>(buf.ptr);
|
|
|
|
return array(reinterpret_cast<complex64_t*>(data_ptr), shape, dtype);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Creates an array from a numpy array, choosing the conversion from the numpy
// dtype. `dtype`, when given, overrides the default MLX dtype for that input.
//
// double input defaults to float32, float16 input is read as float and
// defaults to float16, and complex<double> input is read as complex<float>.
//
// Throws std::invalid_argument for any numpy dtype not listed below.
array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
  // Collect the numpy dimensions as the shape.
  std::vector<int> shape;
  for (int axis = 0; axis < np_array.ndim(); axis++) {
    shape.push_back(np_array.shape(axis));
  }

  const auto type = np_array.dtype();

  // Copy the data and make the array; the first matching dtype wins.
  if (type.is(py::dtype::of<int>())) {
    return np_array_to_mlx_contiguous<int32_t>(
        np_array, shape, dtype.value_or(int32));
  }
  if (type.is(py::dtype::of<uint32_t>())) {
    return np_array_to_mlx_contiguous<uint32_t>(
        np_array, shape, dtype.value_or(uint32));
  }
  if (type.is(py::dtype::of<bool>())) {
    return np_array_to_mlx_contiguous<bool>(
        np_array, shape, dtype.value_or(bool_));
  }
  if (type.is(py::dtype::of<double>())) {
    return np_array_to_mlx_contiguous<double>(
        np_array, shape, dtype.value_or(float32));
  }
  if (type.is(py::dtype::of<float>())) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float32));
  }
  if (type.is(py::dtype("float16"))) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float16));
  }
  if (type.is(py::dtype::of<uint8_t>())) {
    return np_array_to_mlx_contiguous<uint8_t>(
        np_array, shape, dtype.value_or(uint8));
  }
  if (type.is(py::dtype::of<uint16_t>())) {
    return np_array_to_mlx_contiguous<uint16_t>(
        np_array, shape, dtype.value_or(uint16));
  }
  if (type.is(py::dtype::of<uint64_t>())) {
    return np_array_to_mlx_contiguous<uint64_t>(
        np_array, shape, dtype.value_or(uint64));
  }
  if (type.is(py::dtype::of<int8_t>())) {
    return np_array_to_mlx_contiguous<int8_t>(
        np_array, shape, dtype.value_or(int8));
  }
  if (type.is(py::dtype::of<int16_t>())) {
    return np_array_to_mlx_contiguous<int16_t>(
        np_array, shape, dtype.value_or(int16));
  }
  if (type.is(py::dtype::of<int64_t>())) {
    return np_array_to_mlx_contiguous<int64_t>(
        np_array, shape, dtype.value_or(int64));
  }
  if (type.is(py::dtype::of<std::complex<float>>())) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
  }
  if (type.is(py::dtype::of<std::complex<double>>())) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
  }

  std::ostringstream msg;
  msg << "Cannot convert numpy array of type " << type << " to mlx array.";
  throw std::invalid_argument(msg.str());
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Module
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
// Registers the `Dtype` class, the dtype constants and the `array` class
// (constructors, properties, operators and methods) on the Python module `m`.
void init_array(py::module_& m) {
  // Types
  py::class_<Dtype>(
      m,
      "Dtype",
      R"pbdoc(
      An object to hold the type of a :class:`array`.

      See the :ref:`list of types <data_types>` for more details
      on available data types.
      )pbdoc")
      .def_readonly(
          "size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
      .def(
          "__repr__",
          [](const Dtype& t) {
            std::ostringstream os;
            os << t;
            return os.str();
          })
      .def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; });
  m.attr("bool_") = py::cast(bool_);
  m.attr("uint8") = py::cast(uint8);
  m.attr("uint16") = py::cast(uint16);
  m.attr("uint32") = py::cast(uint32);
  m.attr("uint64") = py::cast(uint64);
  m.attr("int8") = py::cast(int8);
  m.attr("int16") = py::cast(int16);
  m.attr("int32") = py::cast(int32);
  m.attr("int64") = py::cast(int64);
  m.attr("float16") = py::cast(float16);
  m.attr("float32") = py::cast(float32);
  m.attr("bfloat16") = py::cast(bfloat16);
  m.attr("complex64") = py::cast(complex64);

  py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc")
      // Constructors: scalar or array, list or tuple, numpy array, buffer.
      .def(
          py::init([](ScalarOrArray v, std::optional<Dtype> t) {
            auto arr = to_array(v, t);
            return astype(arr, t.value_or(arr.dtype()));
          }),
          "val"_a,
          "dtype"_a = std::nullopt)
      .def(
          py::init([](std::variant<py::list, py::tuple> pl,
                      std::optional<Dtype> dtype) {
            if (auto pv = std::get_if<py::list>(&pl); pv) {
              return array_from_list(*pv, dtype);
            } else {
              auto v = std::get<py::tuple>(pl);
              return array_from_list(v, dtype);
            }
          }),
          "vals"_a,
          "dtype"_a = std::nullopt)
      .def(
          py::init([](py::array np_array, std::optional<Dtype> dtype) {
            return np_array_to_mlx(np_array, dtype);
          }),
          "vals"_a,
          "dtype"_a = std::nullopt)
      .def(
          py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
            return np_array_to_mlx(np_buffer, dtype);
          }),
          "vals"_a,
          "dtype"_a = std::nullopt)
      // Properties and conversions.
      .def_property_readonly(
          "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
      .def_property_readonly(
          "ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
      // TODO, this makes a deep copy of the shape
      // implement alternatives to use reference
      // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
      .def_property_readonly(
          "shape",
          [](const array& a) { return a.shape(); },
          R"pbdoc(
          The shape of the array as a Python list.

          Returns:
            list(int): A list containing the sizes of each dimension.
          )pbdoc")
      .def_property_readonly(
          "dtype",
          &array::dtype,
          R"pbdoc(
            The array's :class:`Dtype`.
          )pbdoc")
      .def(
          "item",
          &to_scalar,
          R"pbdoc(
            Access the value of a scalar array.

            Returns:
                Standard Python scalar.
          )pbdoc")
      .def(
          "tolist",
          &tolist,
          R"pbdoc(
            Convert the array to a Python :class:`list`.

            Returns:
                list: The Python list.

                If the array is a scalar then a standard Python scalar is returned.

                If the array has more than one dimension then the result is a nested
                list of lists.

                The value type of the list corresponding to the last dimension is either
                ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
          )pbdoc")
      .def("__array__", &mlx_array_to_np)
      .def(
          "astype",
          &astype,
          "dtype"_a,
          "stream"_a = none,
          R"pbdoc(
            Cast the array to a specified type.

            Args:
                dtype (Dtype): Type to which the array is cast.
                stream (Stream): Stream (or device) for the operation.

            Returns:
                array: The array with type ``dtype``.
          )pbdoc")
      // Indexing, length and iteration.
      .def("__getitem__", mlx_get_item)
      .def("__setitem__", mlx_set_item)
      .def(
          "__len__",
          [](const array& a) {
            if (a.ndim() == 0) {
              throw py::type_error("len() 0-dimensional array.");
            }
            return a.shape(0);
          })
      .def(
          "__iter__",
          [](const array& a) { return py::make_iterator(a); },
          py::keep_alive<0, 1>())
      // Arithmetic operators. The other operand is converted with the dtype
      // of `a`, except for division which converts it with float32.
      .def(
          "__add__",
          [](const array& a, const ScalarOrArray v) {
            return add(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__radd__",
          [](const array& a, const ScalarOrArray v) {
            return add(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__sub__",
          [](const array& a, const ScalarOrArray v) {
            return subtract(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__rsub__",
          [](const array& a, const ScalarOrArray v) {
            return subtract(to_array(v, a.dtype()), a);
          },
          "other"_a)
      .def(
          "__mul__",
          [](const array& a, const ScalarOrArray v) {
            return multiply(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__rmul__",
          [](const array& a, const ScalarOrArray v) {
            return multiply(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__truediv__",
          [](const array& a, const ScalarOrArray v) {
            return divide(a, to_array(v, float32));
          },
          "other"_a)
      .def(
          "__div__",
          [](const array& a, const ScalarOrArray v) {
            return divide(a, to_array(v, float32));
          },
          "other"_a)
      .def(
          "__rtruediv__",
          [](const array& a, const ScalarOrArray v) {
            return divide(to_array(v, float32), a);
          },
          "other"_a)
      .def(
          "__rdiv__",
          [](const array& a, const ScalarOrArray v) {
            return divide(to_array(v, float32), a);
          },
          "other"_a)
      .def(
          "__mod__",
          [](const array& a, const ScalarOrArray v) {
            return remainder(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__rmod__",
          [](const array& a, const ScalarOrArray v) {
            return remainder(to_array(v, a.dtype()), a);
          },
          "other"_a)
      // Comparison operators.
      .def(
          "__eq__",
          [](const array& a, const ScalarOrArray v) {
            return equal(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__lt__",
          [](const array& a, const ScalarOrArray v) {
            return less(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__le__",
          [](const array& a, const ScalarOrArray v) {
            return less_equal(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__gt__",
          [](const array& a, const ScalarOrArray v) {
            return greater(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__ge__",
          [](const array& a, const ScalarOrArray v) {
            return greater_equal(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__ne__",
          [](const array& a, const ScalarOrArray v) {
            return not_equal(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def("__neg__", [](const array& a) { return -a; })
      .def("__bool__", [](array& a) { return py::bool_(to_scalar(a)); })
      .def(
          "__repr__",
          [](array& a) {
            if (!a.is_evaled()) {
              a.eval(a.is_tracer());
            }
            std::ostringstream os;
            os << a;
            return os.str();
          })
      .def(
          "__matmul__", [](array& a, array& other) { return matmul(a, other); })
      .def(
          "__pow__",
          [](const array& a, const ScalarOrArray v) {
            return power(a, to_array(v, a.dtype()));
          },
          "other"_a)
      // Shape manipulation.
      .def(
          "reshape",
          [](const array& a, py::args shape, StreamOrDevice s) {
            // A single non-int argument is taken to be the whole shape.
            if (shape.size() == 1) {
              py::object arg = shape[0];
              if (!py::isinstance<py::int_>(arg)) {
                return reshape(a, py::cast<std::vector<int>>(arg), s);
              }
            }
            return reshape(a, py::cast<std::vector<int>>(shape), s);
          },
          py::kw_only(),
          "stream"_a = none,
          R"pbdoc(
            Equivalent to :func:`reshape` but the shape can be passed either as a
            tuple or as separate arguments.

            See :func:`reshape` for full documentation.
          )pbdoc")
      .def(
          "squeeze",
          [](const array& a, const IntOrVec& v, const StreamOrDevice& s) {
            if (std::holds_alternative<std::monostate>(v)) {
              return squeeze(a, s);
            } else if (auto pv = std::get_if<int>(&v); pv) {
              return squeeze(a, *pv, s);
            } else {
              return squeeze(a, std::get<std::vector<int>>(v), s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "stream"_a = none,
          R"pbdoc(
            See :func:`squeeze`.
          )pbdoc")
      // Element-wise unary operations.
      .def(
          "abs",
          &mlx::core::abs,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`abs`.")
      .def(
          "square",
          &square,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`square`.")
      .def(
          "sqrt",
          &mlx::core::sqrt,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`sqrt`.")
      .def(
          "rsqrt",
          &rsqrt,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`rsqrt`.")
      .def(
          "reciprocal",
          &reciprocal,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`reciprocal`.")
      .def(
          "exp",
          &mlx::core::exp,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`exp`.")
      .def(
          "log",
          &mlx::core::log,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`log`.")
      .def(
          "log2",
          &mlx::core::log2,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`log2`.")
      .def(
          "log10",
          &mlx::core::log10,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`log10`.")
      .def(
          "sin",
          &mlx::core::sin,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`sin`.")
      .def(
          "cos",
          &mlx::core::cos,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`cos`.")
      .def(
          "log1p",
          &mlx::core::log1p,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`log1p`.")
      // Reductions.
      .def(
          "all",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`all`.")
      .def(
          "any",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`any`.")
      .def(
          "transpose",
          [](const array& a, py::args axes, StreamOrDevice s) {
            if (axes.size() > 0) {
              // A single non-int argument is taken to be the whole axes list.
              if (axes.size() == 1) {
                py::object arg = axes[0];
                if (!py::isinstance<py::int_>(arg)) {
                  return transpose(a, py::cast<std::vector<int>>(arg), s);
                }
              }
              return transpose(a, py::cast<std::vector<int>>(axes), s);
            } else {
              return transpose(a, s);
            }
          },
          py::kw_only(),
          "stream"_a = none,
          R"pbdoc(
            Equivalent to :func:`transpose` but the axes can be passed either as
            a tuple or as separate arguments.

            See :func:`transpose` for full documentation.
          )pbdoc")
      .def_property_readonly(
          "T",
          [](const array& a) { return transpose(a); },
          "Equivalent to calling ``self.transpose()`` with no arguments.")
      .def(
          "sum",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`sum`.")
      .def(
          "prod",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`prod`.")
      .def(
          "min",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`min`.")
      .def(
          "max",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`max`.")
      .def(
          "logsumexp",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`logsumexp`.")
      .def(
          "mean",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             StreamOrDevice s) {
            return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`mean`.")
      .def(
          "var",
          [](const array& a,
             const IntOrVec& axis,
             bool keepdims,
             int ddof,
             StreamOrDevice s) {
            return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
          },
          "axis"_a = none,
          "keepdims"_a = false,
          "ddof"_a = 0,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`var`.")
      .def(
          "split",
          [](const array& a,
             const std::variant<int, std::vector<int>>& indices_or_sections,
             int axis,
             StreamOrDevice s) {
            if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
              return split(a, *pv, axis, s);
            } else {
              return split(
                  a, std::get<std::vector<int>>(indices_or_sections), axis, s);
            }
          },
          "indices_or_sections"_a,
          "axis"_a = 0,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`split`.")
      .def(
          "argmin",
          [](const array& a,
             std::optional<int> axis,
             bool keepdims,
             StreamOrDevice s) {
            if (axis) {
              return argmin(a, *axis, keepdims, s);
            } else {
              return argmin(a, keepdims, s);
            }
          },
          "axis"_a = std::nullopt,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`argmin`.")
      .def(
          "argmax",
          [](const array& a,
             std::optional<int> axis,
             bool keepdims,
             StreamOrDevice s) {
            if (axis) {
              return argmax(a, *axis, keepdims, s);
            } else {
              return argmax(a, keepdims, s);
            }
          },
          "axis"_a = none,
          "keepdims"_a = false,
          py::kw_only(),
          "stream"_a = none,
          "See :func:`argmax`.")
      // Cumulative operations. Without an axis the array is first reshaped
      // to one dimension and the scan runs over axis 0.
      .def(
          "cumsum",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cumsum(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cumsum`.")
      .def(
          "cumprod",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cumprod(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cumprod`.")
      .def(
          "cummax",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cummax(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cummax`.")
      .def(
          "cummin",
          [](const array& a,
             std::optional<int> axis,
             bool reverse,
             bool inclusive,
             StreamOrDevice s) {
            if (axis) {
              return cummin(a, *axis, reverse, inclusive, s);
            } else {
              // TODO: Implement that in the C++ API as well. See concatenate
              // above.
              return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s);
            }
          },
          "axis"_a = none,
          py::kw_only(),
          "reverse"_a = false,
          "inclusive"_a = true,
          "stream"_a = none,
          "See :func:`cummin`.");
}
|