2024-03-02 22:09:29 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-11-30 02:42:59 +08:00
|
|
|
#include <cstdint>
|
|
|
|
#include <cstring>
|
|
|
|
#include <sstream>
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
#include <nanobind/ndarray.h>
|
|
|
|
#include <nanobind/stl/complex.h>
|
|
|
|
#include <nanobind/stl/optional.h>
|
|
|
|
#include <nanobind/stl/string.h>
|
|
|
|
#include <nanobind/stl/variant.h>
|
|
|
|
#include <nanobind/stl/vector.h>
|
2023-11-30 02:42:59 +08:00
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
#include "python/src/buffer.h"
|
|
|
|
#include "python/src/convert.h"
|
2023-11-30 02:42:59 +08:00
|
|
|
#include "python/src/indexing.h"
|
|
|
|
#include "python/src/utils.h"
|
|
|
|
|
|
|
|
#include "mlx/ops.h"
|
|
|
|
#include "mlx/transforms.h"
|
|
|
|
#include "mlx/utils.h"
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
namespace nb = nanobind;
|
|
|
|
using namespace nb::literals;
|
|
|
|
using namespace mlx::core;
|
2023-11-30 02:42:59 +08:00
|
|
|
|
|
|
|
// Kind of Python scalar encountered while validating a nested list/tuple.
// Enumerators are ordered by promotion rank: validate_shape folds element
// kinds together with std::max, so a later enumerator wins over an earlier
// one (bool < int < float < complex).
enum PyScalarT { pybool = 0, pyint = 1, pyfloat = 2, pycomplex = 3 };
|
|
|
|
|
2024-02-20 01:44:27 +08:00
|
|
|
template <typename T, typename U = T>
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::list to_list(array& a, size_t index, int dim) {
|
|
|
|
nb::list pl;
|
2023-11-30 02:42:59 +08:00
|
|
|
auto stride = a.strides()[dim];
|
|
|
|
for (int i = 0; i < a.shape(dim); ++i) {
|
|
|
|
if (dim == a.ndim() - 1) {
|
2024-02-20 01:44:27 +08:00
|
|
|
pl.append(static_cast<U>(a.data<T>()[index]));
|
2023-11-30 02:42:59 +08:00
|
|
|
} else {
|
2024-02-20 01:44:27 +08:00
|
|
|
pl.append(to_list<T, U>(a, index, dim + 1));
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
index += stride;
|
|
|
|
}
|
|
|
|
return pl;
|
|
|
|
}
|
|
|
|
|
|
|
|
auto to_scalar(array& a) {
|
2024-01-27 14:03:52 +08:00
|
|
|
{
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::gil_scoped_release nogil;
|
2024-01-27 14:03:52 +08:00
|
|
|
a.eval();
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
switch (a.dtype()) {
|
|
|
|
case bool_:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<bool>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case uint8:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<uint8_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case uint16:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<uint16_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case uint32:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<uint32_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case uint64:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<uint64_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case int8:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<int8_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case int16:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<int16_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case int32:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<int32_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case int64:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<int64_t>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case float16:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(static_cast<float>(a.item<float16_t>()));
|
2023-11-30 02:42:59 +08:00
|
|
|
case float32:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<float>());
|
2023-11-30 02:42:59 +08:00
|
|
|
case bfloat16:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
2023-11-30 02:42:59 +08:00
|
|
|
case complex64:
|
2024-03-19 11:12:25 +08:00
|
|
|
return nb::cast(a.item<std::complex<float>>());
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
// Converts an array into a (nested) Python list.
//
// 0-d arrays are returned as a plain Python scalar via to_scalar. Otherwise
// the array is evaluated with the GIL released and converted with to_list;
// float16 and bfloat16 elements are widened to float.
//
// Throws std::invalid_argument if the array's dtype has no Python mapping.
nb::object tolist(array& a) {
  if (a.ndim() == 0) {
    return to_scalar(a);
  }
  {
    nb::gil_scoped_release nogil;
    a.eval();
  }
  switch (a.dtype()) {
    case bool_:
      return to_list<bool>(a, 0, 0);
    case uint8:
      return to_list<uint8_t>(a, 0, 0);
    case uint16:
      return to_list<uint16_t>(a, 0, 0);
    case uint32:
      return to_list<uint32_t>(a, 0, 0);
    case uint64:
      return to_list<uint64_t>(a, 0, 0);
    case int8:
      return to_list<int8_t>(a, 0, 0);
    case int16:
      return to_list<int16_t>(a, 0, 0);
    case int32:
      return to_list<int32_t>(a, 0, 0);
    case int64:
      return to_list<int64_t>(a, 0, 0);
    case float16:
      return to_list<float16_t, float>(a, 0, 0);
    case float32:
      return to_list<float>(a, 0, 0);
    case bfloat16:
      return to_list<bfloat16_t, float>(a, 0, 0);
    case complex64:
      return to_list<std::complex<float>>(a, 0, 0);
  }
  // All dtypes exposed by this module are handled above. This guards against
  // a dtype being added without a mapping here; previously control would
  // flow off the end of a non-void function (undefined behavior).
  std::ostringstream msg;
  msg << "[tolist] Unsupported dtype " << a.dtype() << ".";
  throw std::invalid_argument(msg.str());
}
|
|
|
|
|
|
|
|
template <typename T, typename U>
|
|
|
|
void fill_vector(T list, std::vector<U>& vals) {
|
|
|
|
for (auto l : list) {
|
2024-03-19 11:12:25 +08:00
|
|
|
if (nb::isinstance<nb::list>(l)) {
|
|
|
|
fill_vector(nb::cast<nb::list>(l), vals);
|
|
|
|
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
|
|
|
|
fill_vector(nb::cast<nb::tuple>(l), vals);
|
2023-11-30 02:42:59 +08:00
|
|
|
} else {
|
2024-03-19 11:12:25 +08:00
|
|
|
vals.push_back(nb::cast<U>(l));
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
2024-01-05 10:53:33 +08:00
|
|
|
PyScalarT validate_shape(
|
|
|
|
T list,
|
|
|
|
const std::vector<int>& shape,
|
|
|
|
int idx,
|
|
|
|
bool& all_python_primitive_elements) {
|
2023-11-30 02:42:59 +08:00
|
|
|
if (idx >= shape.size()) {
|
|
|
|
throw std::invalid_argument("Initialization encountered extra dimension.");
|
|
|
|
}
|
|
|
|
auto s = shape[idx];
|
2024-03-19 11:12:25 +08:00
|
|
|
if (nb::len(list) != s) {
|
2023-11-30 02:42:59 +08:00
|
|
|
throw std::invalid_argument(
|
|
|
|
"Initialization encountered non-uniform length.");
|
|
|
|
}
|
|
|
|
|
|
|
|
if (s == 0) {
|
|
|
|
return pyfloat;
|
|
|
|
}
|
|
|
|
|
|
|
|
PyScalarT type = pybool;
|
|
|
|
for (auto l : list) {
|
|
|
|
PyScalarT t;
|
2024-03-19 11:12:25 +08:00
|
|
|
if (nb::isinstance<nb::list>(l)) {
|
2024-01-05 10:53:33 +08:00
|
|
|
t = validate_shape(
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::cast<nb::list>(l), shape, idx + 1, all_python_primitive_elements);
|
|
|
|
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
|
2024-01-05 10:53:33 +08:00
|
|
|
t = validate_shape(
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::cast<nb::tuple>(l),
|
2024-01-05 10:53:33 +08:00
|
|
|
shape,
|
|
|
|
idx + 1,
|
|
|
|
all_python_primitive_elements);
|
2024-03-19 11:12:25 +08:00
|
|
|
} else if (nb::isinstance<array>(l)) {
|
2024-01-05 10:53:33 +08:00
|
|
|
all_python_primitive_elements = false;
|
2024-03-19 11:12:25 +08:00
|
|
|
auto arr = nb::cast<array>(l);
|
2024-01-05 10:53:33 +08:00
|
|
|
if (arr.ndim() + idx + 1 == shape.size() &&
|
|
|
|
std::equal(
|
|
|
|
arr.shape().cbegin(),
|
|
|
|
arr.shape().cend(),
|
|
|
|
shape.cbegin() + idx + 1)) {
|
|
|
|
t = pybool;
|
|
|
|
} else {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"Initialization encountered non-uniform length.");
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
} else {
|
2024-04-01 21:27:52 +08:00
|
|
|
if (nb::isinstance<nb::bool_>(l)) {
|
|
|
|
t = pybool;
|
|
|
|
} else if (nb::isinstance<nb::int_>(l)) {
|
|
|
|
t = pyint;
|
|
|
|
} else if (nb::isinstance<nb::float_>(l)) {
|
|
|
|
t = pyfloat;
|
|
|
|
} else if (PyComplex_Check(l.ptr())) {
|
|
|
|
t = pycomplex;
|
|
|
|
} else {
|
|
|
|
std::ostringstream msg;
|
|
|
|
msg << "Invalid type " << nb::type_name(l.type()).c_str()
|
|
|
|
<< " received in array initialization.";
|
|
|
|
throw std::invalid_argument(msg.str());
|
|
|
|
}
|
|
|
|
|
|
|
|
if (idx + 1 != shape.size()) {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"Initialization encountered non-uniform length.");
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
type = std::max(type, t);
|
|
|
|
}
|
|
|
|
return type;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
void get_shape(T list, std::vector<int>& shape) {
|
2024-03-26 04:29:45 +08:00
|
|
|
shape.push_back(check_shape_dim(nb::len(list)));
|
2023-11-30 02:42:59 +08:00
|
|
|
if (shape.back() > 0) {
|
2024-03-19 11:12:25 +08:00
|
|
|
auto l = list.begin();
|
|
|
|
if (nb::isinstance<nb::list>(*l)) {
|
|
|
|
return get_shape(nb::cast<nb::list>(*l), shape);
|
|
|
|
} else if (nb::isinstance<nb::tuple>(*l)) {
|
|
|
|
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
|
|
|
} else if (nb::isinstance<array>(*l)) {
|
|
|
|
auto arr = nb::cast<array>(*l);
|
2024-03-26 04:29:45 +08:00
|
|
|
for (int i = 0; i < arr.ndim(); i++) {
|
|
|
|
shape.push_back(check_shape_dim(arr.shape(i)));
|
|
|
|
}
|
2024-01-05 10:53:33 +08:00
|
|
|
return;
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
using ArrayInitType = std::variant<
|
|
|
|
nb::bool_,
|
|
|
|
nb::int_,
|
|
|
|
nb::float_,
|
|
|
|
// Must be above ndarray
|
2024-01-06 10:17:44 +08:00
|
|
|
array,
|
2024-03-19 11:12:25 +08:00
|
|
|
// Must be above complex
|
|
|
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
|
|
|
std::complex<float>,
|
|
|
|
nb::list,
|
|
|
|
nb::tuple,
|
|
|
|
nb::object>;
|
2023-11-30 02:42:59 +08:00
|
|
|
|
2024-01-05 10:53:33 +08:00
|
|
|
// Forward declaration
|
2024-03-19 11:12:25 +08:00
|
|
|
array create_array(ArrayInitType v, std::optional<Dtype> t);
|
2023-11-30 02:42:59 +08:00
|
|
|
|
2024-01-05 10:53:33 +08:00
|
|
|
template <typename T>
|
|
|
|
array array_from_list(
|
|
|
|
T pl,
|
|
|
|
const PyScalarT& inferred_type,
|
|
|
|
std::optional<Dtype> specified_type,
|
|
|
|
const std::vector<int>& shape) {
|
2023-11-30 02:42:59 +08:00
|
|
|
// Make the array
|
2024-01-05 10:53:33 +08:00
|
|
|
switch (inferred_type) {
|
2023-11-30 02:42:59 +08:00
|
|
|
case pybool: {
|
|
|
|
std::vector<bool> vals;
|
|
|
|
fill_vector(pl, vals);
|
2024-01-05 10:53:33 +08:00
|
|
|
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
case pyint: {
|
2024-01-20 07:49:25 +08:00
|
|
|
auto dtype = specified_type.value_or(int32);
|
|
|
|
if (dtype == int64) {
|
|
|
|
std::vector<int64_t> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype);
|
|
|
|
} else if (dtype == uint64) {
|
|
|
|
std::vector<uint64_t> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype);
|
|
|
|
} else if (dtype == uint32) {
|
|
|
|
std::vector<uint32_t> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype);
|
2024-03-26 03:32:59 +08:00
|
|
|
} else if (issubdtype(dtype, inexact)) {
|
2024-01-20 07:49:25 +08:00
|
|
|
std::vector<float> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype);
|
|
|
|
} else {
|
|
|
|
std::vector<int> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(vals.begin(), shape, dtype);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
case pyfloat: {
|
|
|
|
std::vector<float> vals;
|
|
|
|
fill_vector(pl, vals);
|
2024-01-05 10:53:33 +08:00
|
|
|
return array(vals.begin(), shape, specified_type.value_or(float32));
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
case pycomplex: {
|
|
|
|
std::vector<std::complex<float>> vals;
|
|
|
|
fill_vector(pl, vals);
|
|
|
|
return array(
|
|
|
|
reinterpret_cast<complex64_t*>(vals.data()),
|
|
|
|
shape,
|
2024-01-05 10:53:33 +08:00
|
|
|
specified_type.value_or(complex64));
|
|
|
|
}
|
|
|
|
default: {
|
|
|
|
std::ostringstream msg;
|
|
|
|
msg << "Should not happen, inferred: " << inferred_type
|
|
|
|
<< " on subarray made of only python primitive types.";
|
|
|
|
throw std::runtime_error(msg.str());
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-01-05 10:53:33 +08:00
|
|
|
template <typename T>
|
|
|
|
array array_from_list(T pl, std::optional<Dtype> dtype) {
|
|
|
|
// Compute the shape
|
|
|
|
std::vector<int> shape;
|
|
|
|
get_shape(pl, shape);
|
|
|
|
|
|
|
|
// Validate the shape and type
|
|
|
|
bool all_python_primitive_elements = true;
|
|
|
|
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
|
|
|
|
|
|
|
|
if (all_python_primitive_elements) {
|
|
|
|
// `pl` does not contain mlx arrays
|
|
|
|
return array_from_list(pl, type, dtype, shape);
|
|
|
|
}
|
|
|
|
|
|
|
|
// `pl` contains mlx arrays
|
|
|
|
std::vector<array> arrays;
|
|
|
|
for (auto l : pl) {
|
2024-03-19 11:12:25 +08:00
|
|
|
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
|
2024-01-05 10:53:33 +08:00
|
|
|
}
|
|
|
|
return stack(arrays);
|
|
|
|
}
|
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Module
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
// Converts any supported Python value into an mlx array, optionally cast to
// dtype `t`.
//
// Defaults when `t` is not given: bool -> bool_, int -> int32,
// float -> float32, complex -> complex64; lists/tuples go through
// array_from_list, buffers through nd_array_to_mlx, and existing arrays keep
// their dtype.
//
// A Python int requested as int64 or uint64 is read through int64_t rather
// than int. This matches array_from_list, which stages such lists through
// 64-bit types, so e.g. `array(2**40, dtype=int64)` no longer fails the int
// cast. Values that previously converted keep the same result.
array create_array(ArrayInitType v, std::optional<Dtype> t) {
  if (auto pv = std::get_if<nb::bool_>(&v); pv) {
    return array(nb::cast<bool>(*pv), t.value_or(bool_));
  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
    auto dtype = t.value_or(int32);
    if (dtype == int64 || dtype == uint64) {
      return array(nb::cast<int64_t>(*pv), dtype);
    }
    return array(nb::cast<int>(*pv), dtype);
  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
    return array(nb::cast<float>(*pv), t.value_or(float32));
  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
    return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
  } else if (auto pv = std::get_if<nb::list>(&v); pv) {
    return array_from_list(*pv, t);
  } else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
    return array_from_list(*pv, t);
  } else if (auto pv = std::get_if<
                 nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
             pv) {
    return nd_array_to_mlx(*pv, t);
  } else if (auto pv = std::get_if<array>(&v); pv) {
    return astype(*pv, t.value_or((*pv).dtype()));
  } else {
    auto arr = to_array_with_accessor(std::get<nb::object>(v));
    return astype(arr, t.value_or(arr.dtype()));
  }
}
|
|
|
|
|
2024-01-10 05:36:51 +08:00
|
|
|
class ArrayAt {
|
|
|
|
public:
|
|
|
|
ArrayAt(array x) : x_(std::move(x)) {}
|
2024-03-19 11:12:25 +08:00
|
|
|
ArrayAt& set_indices(nb::object indices) {
|
2024-01-10 05:36:51 +08:00
|
|
|
indices_ = indices;
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
array add(const ScalarOrArray& v) {
|
|
|
|
return mlx_add_item(x_, indices_, v);
|
|
|
|
}
|
|
|
|
array subtract(const ScalarOrArray& v) {
|
|
|
|
return mlx_subtract_item(x_, indices_, v);
|
|
|
|
}
|
|
|
|
array multiply(const ScalarOrArray& v) {
|
|
|
|
return mlx_multiply_item(x_, indices_, v);
|
|
|
|
}
|
|
|
|
array divide(const ScalarOrArray& v) {
|
|
|
|
return mlx_divide_item(x_, indices_, v);
|
|
|
|
}
|
|
|
|
array maximum(const ScalarOrArray& v) {
|
|
|
|
return mlx_maximum_item(x_, indices_, v);
|
|
|
|
}
|
|
|
|
array minimum(const ScalarOrArray& v) {
|
|
|
|
return mlx_minimum_item(x_, indices_, v);
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
array x_;
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::object indices_;
|
2024-01-10 05:36:51 +08:00
|
|
|
};
|
|
|
|
|
2024-01-18 21:50:25 +08:00
|
|
|
class ArrayPythonIterator {
|
|
|
|
public:
|
|
|
|
ArrayPythonIterator(array x) : idx_(0), x_(std::move(x)) {
|
|
|
|
if (x_.shape(0) > 0 && x_.shape(0) < 10) {
|
|
|
|
splits_ = split(x_, x_.shape(0));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
array next() {
|
|
|
|
if (idx_ >= x_.shape(0)) {
|
2024-03-19 11:12:25 +08:00
|
|
|
throw nb::stop_iteration();
|
2024-01-18 21:50:25 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (idx_ >= 0 && idx_ < splits_.size()) {
|
|
|
|
return squeeze(splits_[idx_++], 0);
|
|
|
|
}
|
|
|
|
|
|
|
|
return *(x_.begin() + idx_++);
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
int idx_;
|
|
|
|
array x_;
|
|
|
|
std::vector<array> splits_;
|
|
|
|
};
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
void init_array(nb::module_& m) {
|
2024-01-18 23:49:41 +08:00
|
|
|
// Set Python print formatting options
|
|
|
|
mlx::core::global_formatter.capitalize_bool = true;
|
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
// Types
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::class_<Dtype>(
|
2023-11-30 02:42:59 +08:00
|
|
|
m,
|
|
|
|
"Dtype",
|
|
|
|
R"pbdoc(
|
|
|
|
An object to hold the type of a :class:`array`.
|
|
|
|
|
|
|
|
See the :ref:`list of types <data_types>` for more details
|
|
|
|
on available data types.
|
|
|
|
)pbdoc")
|
2024-03-19 11:12:25 +08:00
|
|
|
.def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc")
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__repr__",
|
|
|
|
[](const Dtype& t) {
|
|
|
|
std::ostringstream os;
|
2023-12-10 01:35:28 +08:00
|
|
|
os << "mlx.core.";
|
2023-11-30 02:42:59 +08:00
|
|
|
os << t;
|
|
|
|
return os.str();
|
|
|
|
})
|
2024-03-19 11:12:25 +08:00
|
|
|
.def(
|
|
|
|
"__eq__",
|
|
|
|
[](const Dtype& t, const nb::object& other) {
|
|
|
|
return nb::isinstance<Dtype>(other) && t == nb::cast<Dtype>(other);
|
|
|
|
})
|
2023-12-10 01:35:28 +08:00
|
|
|
.def("__hash__", [](const Dtype& t) {
|
|
|
|
return static_cast<int64_t>(t.val);
|
|
|
|
});
|
2024-03-19 11:12:25 +08:00
|
|
|
m.attr("bool_") = nb::cast(bool_);
|
|
|
|
m.attr("uint8") = nb::cast(uint8);
|
|
|
|
m.attr("uint16") = nb::cast(uint16);
|
|
|
|
m.attr("uint32") = nb::cast(uint32);
|
|
|
|
m.attr("uint64") = nb::cast(uint64);
|
|
|
|
m.attr("int8") = nb::cast(int8);
|
|
|
|
m.attr("int16") = nb::cast(int16);
|
|
|
|
m.attr("int32") = nb::cast(int32);
|
|
|
|
m.attr("int64") = nb::cast(int64);
|
|
|
|
m.attr("float16") = nb::cast(float16);
|
|
|
|
m.attr("float32") = nb::cast(float32);
|
|
|
|
m.attr("bfloat16") = nb::cast(bfloat16);
|
|
|
|
m.attr("complex64") = nb::cast(complex64);
|
2024-03-26 03:32:59 +08:00
|
|
|
nb::class_<Dtype::Category>(
|
|
|
|
m,
|
|
|
|
"DtypeCategory",
|
|
|
|
R"pbdoc(
|
|
|
|
Type to hold categories of :class:`dtypes <Dtype>`.
|
|
|
|
|
|
|
|
* :attr:`~mlx.core.generic`
|
|
|
|
|
|
|
|
* :ref:`bool_ <data_types>`
|
|
|
|
* :attr:`~mlx.core.number`
|
|
|
|
|
|
|
|
* :attr:`~mlx.core.integer`
|
|
|
|
|
|
|
|
* :attr:`~mlx.core.unsignedinteger`
|
|
|
|
|
|
|
|
* :ref:`uint8 <data_types>`
|
|
|
|
* :ref:`uint16 <data_types>`
|
|
|
|
* :ref:`uint32 <data_types>`
|
|
|
|
* :ref:`uint64 <data_types>`
|
|
|
|
|
|
|
|
* :attr:`~mlx.core.signedinteger`
|
|
|
|
|
|
|
|
* :ref:`int8 <data_types>`
|
|
|
|
* :ref:`int32 <data_types>`
|
|
|
|
* :ref:`int64 <data_types>`
|
|
|
|
|
|
|
|
* :attr:`~mlx.core.inexact`
|
|
|
|
|
|
|
|
* :attr:`~mlx.core.floating`
|
|
|
|
|
|
|
|
* :ref:`float16 <data_types>`
|
|
|
|
* :ref:`bfloat16 <data_types>`
|
|
|
|
* :ref:`float32 <data_types>`
|
|
|
|
|
|
|
|
* :attr:`~mlx.core.complexfloating`
|
|
|
|
|
|
|
|
* :ref:`complex128 <data_types>`
|
|
|
|
|
|
|
|
See also :func:`~mlx.core.issubdtype`.
|
|
|
|
)pbdoc");
|
|
|
|
m.attr("complexfloating") = nb::cast(complexfloating);
|
|
|
|
m.attr("floating") = nb::cast(floating);
|
|
|
|
m.attr("inexact") = nb::cast(inexact);
|
|
|
|
m.attr("signedinteger") = nb::cast(signedinteger);
|
|
|
|
m.attr("unsignedinteger") = nb::cast(unsignedinteger);
|
|
|
|
m.attr("integer") = nb::cast(integer);
|
|
|
|
m.attr("number") = nb::cast(number);
|
|
|
|
m.attr("generic") = nb::cast(generic);
|
2024-03-19 11:12:25 +08:00
|
|
|
|
|
|
|
nb::class_<ArrayAt>(
|
2024-01-10 05:36:51 +08:00
|
|
|
m,
|
|
|
|
"_ArrayAt",
|
|
|
|
R"pbdoc(
|
|
|
|
A helper object to apply updates at specific indices.
|
2024-03-19 11:12:25 +08:00
|
|
|
)pbdoc")
|
2024-01-15 06:06:16 +08:00
|
|
|
.def(
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::init<const array&>(),
|
2024-01-15 06:06:16 +08:00
|
|
|
"x"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::sig("def __init__(self, x: array)"))
|
|
|
|
.def("__getitem__", &ArrayAt::set_indices, "indices"_a.none())
|
2024-01-15 06:06:16 +08:00
|
|
|
.def("add", &ArrayAt::add, "value"_a)
|
|
|
|
.def("subtract", &ArrayAt::subtract, "value"_a)
|
|
|
|
.def("multiply", &ArrayAt::multiply, "value"_a)
|
|
|
|
.def("divide", &ArrayAt::divide, "value"_a)
|
|
|
|
.def("maximum", &ArrayAt::maximum, "value"_a)
|
|
|
|
.def("minimum", &ArrayAt::minimum, "value"_a);
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::class_<ArrayPythonIterator>(
|
|
|
|
m,
|
|
|
|
"_ArrayIterator",
|
|
|
|
R"pbdoc(
|
|
|
|
A helper object to iterate over the 1st dimension of an array.
|
|
|
|
)pbdoc")
|
2024-01-18 21:50:25 +08:00
|
|
|
.def(
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::init<const array&>(),
|
2024-01-18 21:50:25 +08:00
|
|
|
"x"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::sig("def __init__(self, x: array)"))
|
2024-01-18 21:50:25 +08:00
|
|
|
.def("__next__", &ArrayPythonIterator::next)
|
|
|
|
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
// Install buffer protocol functions
|
|
|
|
PyType_Slot array_slots[] = {
|
|
|
|
{Py_bf_getbuffer, (void*)getbuffer},
|
|
|
|
{Py_bf_releasebuffer, (void*)releasebuffer},
|
|
|
|
{0, nullptr}};
|
|
|
|
|
|
|
|
nb::class_<array>(
|
|
|
|
m,
|
|
|
|
"array",
|
|
|
|
R"pbdoc(An N-dimensional array object.)pbdoc",
|
|
|
|
nb::type_slots(array_slots),
|
|
|
|
nb::is_weak_referenceable())
|
|
|
|
.def(
|
|
|
|
"__init__",
|
|
|
|
[](array* aptr, ArrayInitType v, std::optional<Dtype> t) {
|
|
|
|
new (aptr) array(create_array(v, t));
|
|
|
|
},
|
|
|
|
"val"_a,
|
|
|
|
"dtype"_a = nb::none(),
|
|
|
|
nb::sig(
|
|
|
|
"def __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)"))
|
|
|
|
.def_prop_ro(
|
2024-01-02 13:08:17 +08:00
|
|
|
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
|
2024-03-19 11:12:25 +08:00
|
|
|
.def_prop_ro("ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
|
|
|
|
.def_prop_ro(
|
2023-12-26 02:34:28 +08:00
|
|
|
"itemsize",
|
|
|
|
&array::itemsize,
|
|
|
|
R"pbdoc(The size of the array's datatype in bytes.)pbdoc")
|
2024-03-19 11:12:25 +08:00
|
|
|
.def_prop_ro(
|
2023-12-26 02:34:28 +08:00
|
|
|
"nbytes",
|
|
|
|
&array::nbytes,
|
|
|
|
R"pbdoc(The number of bytes in the array.)pbdoc")
|
2024-03-19 11:12:25 +08:00
|
|
|
.def_prop_ro(
|
2023-11-30 02:42:59 +08:00
|
|
|
"shape",
|
2024-03-19 11:12:25 +08:00
|
|
|
[](const array& a) { return nb::tuple(nb::cast(a.shape())); },
|
2023-11-30 02:42:59 +08:00
|
|
|
R"pbdoc(
|
2024-02-05 01:21:22 +08:00
|
|
|
The shape of the array as a Python tuple.
|
2023-11-30 02:42:59 +08:00
|
|
|
|
|
|
|
Returns:
|
2024-01-31 05:11:01 +08:00
|
|
|
tuple(int): A tuple containing the sizes of each dimension.
|
2023-11-30 02:42:59 +08:00
|
|
|
)pbdoc")
|
2024-03-19 11:12:25 +08:00
|
|
|
.def_prop_ro(
|
2023-11-30 02:42:59 +08:00
|
|
|
"dtype",
|
|
|
|
&array::dtype,
|
|
|
|
R"pbdoc(
|
|
|
|
The array's :class:`Dtype`.
|
|
|
|
)pbdoc")
|
|
|
|
.def(
|
|
|
|
"item",
|
|
|
|
&to_scalar,
|
|
|
|
R"pbdoc(
|
|
|
|
Access the value of a scalar array.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Standard Python scalar.
|
|
|
|
)pbdoc")
|
|
|
|
.def(
|
|
|
|
"tolist",
|
|
|
|
&tolist,
|
|
|
|
R"pbdoc(
|
|
|
|
Convert the array to a Python :class:`list`.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
list: The Python list.
|
|
|
|
|
|
|
|
If the array is a scalar then a standard Python scalar is returned.
|
|
|
|
|
|
|
|
If the array has more than one dimension then the result is a nested
|
|
|
|
list of lists.
|
|
|
|
|
2024-01-02 13:08:17 +08:00
|
|
|
The value type of the list corresponding to the last dimension is either
|
2023-11-30 02:42:59 +08:00
|
|
|
``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
|
|
|
|
)pbdoc")
|
|
|
|
.def(
|
|
|
|
"astype",
|
|
|
|
&astype,
|
|
|
|
"dtype"_a,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Cast the array to a specified type.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
dtype (Dtype): Type to which the array is cast.
|
|
|
|
stream (Stream): Stream (or device) for the operation.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
array: The array with type ``dtype``.
|
|
|
|
)pbdoc")
|
2024-03-19 11:12:25 +08:00
|
|
|
.def("__getitem__", mlx_get_item, nb::arg().none())
|
|
|
|
.def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg())
|
|
|
|
.def_prop_ro(
|
2024-01-10 05:36:51 +08:00
|
|
|
"at",
|
|
|
|
[](const array& a) { return ArrayAt(a); },
|
|
|
|
R"pbdoc(
|
|
|
|
Used to apply updates at the given indices.
|
|
|
|
|
|
|
|
.. note::
|
|
|
|
|
2024-03-25 06:03:27 +08:00
|
|
|
Regular in-place updates map to assignment. For instance ``x[idx] += y``
|
|
|
|
maps to ``x[idx] = x[idx] + y``. As a result, assigning to the
|
|
|
|
same index ignores all but one update. Using ``x.at[idx].add(y)``
|
|
|
|
will correctly apply all updates to all indices.
|
2024-01-10 05:36:51 +08:00
|
|
|
|
|
|
|
.. list-table::
|
|
|
|
:header-rows: 1
|
|
|
|
|
|
|
|
* - array.at syntax
|
|
|
|
- In-place syntax
|
2024-01-15 06:06:16 +08:00
|
|
|
* - ``x = x.at[idx].add(y)``
|
2024-01-10 05:36:51 +08:00
|
|
|
- ``x[idx] += y``
|
2024-01-15 06:06:16 +08:00
|
|
|
* - ``x = x.at[idx].subtract(y)``
|
2024-01-10 05:36:51 +08:00
|
|
|
- ``x[idx] -= y``
|
2024-01-15 06:06:16 +08:00
|
|
|
* - ``x = x.at[idx].multiply(y)``
|
2024-01-10 05:36:51 +08:00
|
|
|
- ``x[idx] *= y``
|
2024-01-15 06:06:16 +08:00
|
|
|
* - ``x = x.at[idx].divide(y)``
|
2024-01-10 05:36:51 +08:00
|
|
|
- ``x[idx] /= y``
|
2024-01-15 06:06:16 +08:00
|
|
|
* - ``x = x.at[idx].maximum(y)``
|
2024-01-10 05:36:51 +08:00
|
|
|
- ``x[idx] = mx.maximum(x[idx], y)``
|
2024-01-15 06:06:16 +08:00
|
|
|
* - ``x = x.at[idx].minimum(y)``
|
2024-03-25 06:03:27 +08:00
|
|
|
- ``x[idx] = mx.minimum(x[idx], y)``
|
|
|
|
|
|
|
|
Example:
|
|
|
|
>>> a = mx.array([0, 0])
|
|
|
|
>>> idx = mx.array([0, 1, 0, 1])
|
|
|
|
>>> a[idx] += 1
|
|
|
|
>>> a
|
|
|
|
array([1, 1], dtype=int32)
|
|
|
|
>>>
|
|
|
|
>>> a = mx.array([0, 0])
|
|
|
|
>>> a.at[idx].add(1)
|
|
|
|
array([2, 2], dtype=int32)
|
2024-01-10 05:36:51 +08:00
|
|
|
)pbdoc")
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__len__",
|
|
|
|
[](const array& a) {
|
|
|
|
if (a.ndim() == 0) {
|
2024-03-19 11:12:25 +08:00
|
|
|
throw nb::type_error("len() 0-dimensional array.");
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
|
|
|
return a.shape(0);
|
|
|
|
})
|
2024-01-18 21:50:25 +08:00
|
|
|
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
2024-03-19 11:12:25 +08:00
|
|
|
.def(
|
|
|
|
"__getstate__",
|
|
|
|
[](const array& a) {
|
2024-03-07 00:02:41 +08:00
|
|
|
if (a.dtype() == bfloat16) {
|
|
|
|
}
|
2024-03-19 11:12:25 +08:00
|
|
|
return mlx_to_np_array(a);
|
|
|
|
})
|
|
|
|
.def(
|
|
|
|
"__setstate__",
|
|
|
|
[](array& arr,
|
|
|
|
const nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>& state) {
|
|
|
|
new (&arr) array(nd_array_to_mlx(state, std::nullopt));
|
|
|
|
})
|
2024-03-07 00:02:41 +08:00
|
|
|
.def("__copy__", [](const array& self) { return array(self); })
|
|
|
|
.def(
|
|
|
|
"__deepcopy__",
|
2024-03-19 11:12:25 +08:00
|
|
|
[](const array& self, nb::dict) { return array(self); },
|
2024-03-07 00:02:41 +08:00
|
|
|
"memo"_a)
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__add__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("addition", v);
|
|
|
|
}
|
2024-03-19 11:12:25 +08:00
|
|
|
auto b = to_array(v, a.dtype());
|
|
|
|
return add(a, b);
|
2023-11-30 02:42:59 +08:00
|
|
|
},
|
|
|
|
"other"_a)
|
2024-01-10 08:05:38 +08:00
|
|
|
.def(
|
|
|
|
"__iadd__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace addition", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
a.overwrite_descriptor(add(a, to_array(v, a.dtype())));
|
|
|
|
return a;
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__radd__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("addition", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return add(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__sub__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("subtraction", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return subtract(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
2024-01-10 08:05:38 +08:00
|
|
|
.def(
|
|
|
|
"__isub__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace subtraction", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
a.overwrite_descriptor(subtract(a, to_array(v, a.dtype())));
|
|
|
|
return a;
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__rsub__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("subtraction", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return subtract(to_array(v, a.dtype()), a);
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__mul__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("multiplication", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return multiply(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
2024-01-10 08:05:38 +08:00
|
|
|
.def(
|
|
|
|
"__imul__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace multiplication", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
a.overwrite_descriptor(multiply(a, to_array(v, a.dtype())));
|
|
|
|
return a;
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__rmul__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("multiplication", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return multiply(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__truediv__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("division", v);
|
|
|
|
}
|
2023-12-12 12:20:58 +08:00
|
|
|
return divide(a, to_array(v, a.dtype()));
|
2023-11-30 02:42:59 +08:00
|
|
|
},
|
|
|
|
"other"_a)
|
2024-01-10 08:05:38 +08:00
|
|
|
.def(
|
|
|
|
"__itruediv__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace division", v);
|
|
|
|
}
|
2024-03-26 03:32:59 +08:00
|
|
|
if (!issubdtype(a.dtype(), inexact)) {
|
2024-01-10 08:05:38 +08:00
|
|
|
throw std::invalid_argument(
|
|
|
|
"In place division cannot cast to non-floating point type.");
|
|
|
|
}
|
|
|
|
a.overwrite_descriptor(divide(a, to_array(v, a.dtype())));
|
|
|
|
return a;
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2024-01-10 08:05:38 +08:00
|
|
|
.def(
|
|
|
|
"__rtruediv__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("division", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
return divide(to_array(v, a.dtype()), a);
|
|
|
|
},
|
|
|
|
"other"_a)
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__div__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("division", v);
|
|
|
|
}
|
2023-12-12 12:20:58 +08:00
|
|
|
return divide(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
2024-01-10 08:05:38 +08:00
|
|
|
"__rdiv__",
|
2023-12-12 12:20:58 +08:00
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("division", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
return divide(to_array(v, a.dtype()), a);
|
2023-11-30 02:42:59 +08:00
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
2024-01-10 08:05:38 +08:00
|
|
|
"__floordiv__",
|
2023-11-30 02:42:59 +08:00
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("floor division", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
return floor_divide(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__ifloordiv__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace floor division", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype())));
|
|
|
|
return a;
|
2023-12-12 12:20:58 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2023-12-12 12:20:58 +08:00
|
|
|
.def(
|
|
|
|
"__rfloordiv__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("floor division", v);
|
|
|
|
}
|
2023-12-12 12:20:58 +08:00
|
|
|
auto b = to_array(v, a.dtype());
|
2023-12-20 12:12:19 +08:00
|
|
|
return floor_divide(b, a);
|
2023-11-30 02:42:59 +08:00
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
2024-01-10 08:05:38 +08:00
|
|
|
"__mod__",
|
2023-11-30 02:42:59 +08:00
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("modulus", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
return remainder(a, to_array(v, a.dtype()));
|
2023-11-30 02:42:59 +08:00
|
|
|
},
|
|
|
|
"other"_a)
|
2023-12-09 07:08:52 +08:00
|
|
|
.def(
|
2024-01-10 08:05:38 +08:00
|
|
|
"__imod__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace modulus", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
a.overwrite_descriptor(remainder(a, to_array(v, a.dtype())));
|
|
|
|
return a;
|
2023-12-09 07:08:52 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2023-12-09 07:08:52 +08:00
|
|
|
.def(
|
|
|
|
"__rmod__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("modulus", v);
|
|
|
|
}
|
2023-12-09 07:08:52 +08:00
|
|
|
return remainder(to_array(v, a.dtype()), a);
|
|
|
|
},
|
|
|
|
"other"_a)
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__eq__",
|
2024-03-29 21:52:30 +08:00
|
|
|
[](const array& a,
|
|
|
|
const ScalarOrArray& v) -> std::variant<array, bool> {
|
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
return false;
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return equal(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__lt__",
|
2024-03-29 21:52:30 +08:00
|
|
|
[](const array& a, const ScalarOrArray v) -> array {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("less than", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return less(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__le__",
|
2024-03-29 21:52:30 +08:00
|
|
|
[](const array& a, const ScalarOrArray v) -> array {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("less than or equal", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return less_equal(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__gt__",
|
2024-03-29 21:52:30 +08:00
|
|
|
[](const array& a, const ScalarOrArray v) -> array {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("greater than", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return greater(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__ge__",
|
2024-03-29 21:52:30 +08:00
|
|
|
[](const array& a, const ScalarOrArray v) -> array {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("greater than or equal", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return greater_equal(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__ne__",
|
2024-03-29 21:52:30 +08:00
|
|
|
[](const array& a,
|
|
|
|
const ScalarOrArray v) -> std::variant<array, bool> {
|
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
return true;
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return not_equal(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
// Unary negation, truthiness and repr.
.def(
    "__neg__",
    [](const array& a) { return -a; })
.def(
    "__bool__",
    // to_scalar evaluates `a` and converts it to a Python scalar first.
    [](array& a) { return nb::bool_(to_scalar(a)); })
.def(
    "__repr__",
    [](array& a) {
      // Formatting runs without holding the GIL (presumably because printing
      // may evaluate the array -- confirm against operator<<).
      nb::gil_scoped_release nogil;
      std::ostringstream text;
      text << a;
      return text.str();
    })
|
|
|
|
.def(
|
2024-01-10 08:05:38 +08:00
|
|
|
"__matmul__",
|
|
|
|
[](const array& a, array& other) { return matmul(a, other); },
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__imatmul__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, array& other) -> array& {
|
2024-01-10 08:05:38 +08:00
|
|
|
a.overwrite_descriptor(matmul(a, other));
|
|
|
|
return a;
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"__pow__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("power", v);
|
|
|
|
}
|
2023-11-30 02:42:59 +08:00
|
|
|
return power(a, to_array(v, a.dtype()));
|
|
|
|
},
|
|
|
|
"other"_a)
|
2024-02-16 03:26:20 +08:00
|
|
|
.def(
|
|
|
|
"__rpow__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("power", v);
|
|
|
|
}
|
2024-02-16 03:26:20 +08:00
|
|
|
return power(to_array(v, a.dtype()), a);
|
|
|
|
},
|
|
|
|
"other"_a)
|
2024-01-10 08:05:38 +08:00
|
|
|
.def(
|
|
|
|
"__ipow__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace power", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
a.overwrite_descriptor(power(a, to_array(v, a.dtype())));
|
|
|
|
return a;
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2024-01-10 08:05:38 +08:00
|
|
|
// `~a`: only boolean arrays are supported for now, where inversion is
// implemented as a logical NOT.
.def(
    "__invert__",
    [](const array& a) {
      if (issubdtype(a.dtype(), inexact)) {
        throw std::invalid_argument(
            "Floating point types not allowed with bitwise inversion.");
      }
      if (a.dtype() != bool_) {
        throw std::invalid_argument(
            "Bitwise inversion not yet supported for integer types.");
      }
      return logical_not(a);
    })
|
2024-01-08 23:00:05 +08:00
|
|
|
.def(
|
|
|
|
"__and__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("bitwise and", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
auto b = to_array(v, a.dtype());
|
2024-03-26 03:32:59 +08:00
|
|
|
if (issubdtype(a.dtype(), inexact) ||
|
|
|
|
issubdtype(b.dtype(), inexact)) {
|
2024-01-10 08:05:38 +08:00
|
|
|
throw std::invalid_argument(
|
|
|
|
"Floating point types not allowed with bitwise and.");
|
|
|
|
}
|
|
|
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"Bitwise and not yet supported for integer types.");
|
|
|
|
}
|
|
|
|
return logical_and(a, b);
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__iand__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace bitwise and", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
auto b = to_array(v, a.dtype());
|
2024-03-26 03:32:59 +08:00
|
|
|
if (issubdtype(a.dtype(), inexact) ||
|
|
|
|
issubdtype(b.dtype(), inexact)) {
|
2024-01-10 08:05:38 +08:00
|
|
|
throw std::invalid_argument(
|
|
|
|
"Floating point types not allowed with bitwise and.");
|
|
|
|
}
|
|
|
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"Bitwise and not yet supported for integer types.");
|
|
|
|
}
|
|
|
|
a.overwrite_descriptor(logical_and(a, b));
|
|
|
|
return a;
|
2024-01-08 23:00:05 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2024-01-08 23:00:05 +08:00
|
|
|
.def(
|
|
|
|
"__or__",
|
|
|
|
[](const array& a, const ScalarOrArray v) {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("bitwise or", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
auto b = to_array(v, a.dtype());
|
2024-03-26 03:32:59 +08:00
|
|
|
if (issubdtype(a.dtype(), inexact) ||
|
|
|
|
issubdtype(b.dtype(), inexact)) {
|
2024-01-10 08:05:38 +08:00
|
|
|
throw std::invalid_argument(
|
|
|
|
"Floating point types not allowed with or bitwise or.");
|
|
|
|
}
|
|
|
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"Bitwise or not yet supported for integer types.");
|
|
|
|
}
|
|
|
|
return logical_or(a, b);
|
|
|
|
},
|
|
|
|
"other"_a)
|
|
|
|
.def(
|
|
|
|
"__ior__",
|
2024-03-08 01:34:11 +08:00
|
|
|
[](array& a, const ScalarOrArray v) -> array& {
|
2024-04-03 12:11:24 +08:00
|
|
|
if (!is_comparable_with_array(v)) {
|
|
|
|
throw_invalid_operation("inplace bitwise or", v);
|
|
|
|
}
|
2024-01-10 08:05:38 +08:00
|
|
|
auto b = to_array(v, a.dtype());
|
2024-03-26 03:32:59 +08:00
|
|
|
if (issubdtype(a.dtype(), inexact) ||
|
|
|
|
issubdtype(b.dtype(), inexact)) {
|
2024-01-10 08:05:38 +08:00
|
|
|
throw std::invalid_argument(
|
|
|
|
"Floating point types not allowed with or bitwise or.");
|
|
|
|
}
|
|
|
|
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
|
|
|
throw std::invalid_argument(
|
|
|
|
"Bitwise or not yet supported for integer types.");
|
|
|
|
}
|
|
|
|
a.overwrite_descriptor(logical_or(a, b));
|
|
|
|
return a;
|
2024-01-08 23:00:05 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"other"_a,
|
|
|
|
nb::rv_policy::none)
|
2023-12-17 13:54:37 +08:00
|
|
|
.def(
|
|
|
|
"flatten",
|
|
|
|
[](const array& a,
|
|
|
|
int start_axis,
|
|
|
|
int end_axis,
|
|
|
|
const StreamOrDevice& s) {
|
|
|
|
return flatten(a, start_axis, end_axis);
|
|
|
|
},
|
|
|
|
"start_axis"_a = 0,
|
|
|
|
"end_axis"_a = -1,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-12-17 13:54:37 +08:00
|
|
|
R"pbdoc(
|
|
|
|
See :func:`flatten`.
|
|
|
|
)pbdoc")
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"reshape",
|
2024-03-19 11:12:25 +08:00
|
|
|
[](const array& a, nb::args shape_, StreamOrDevice s) {
|
|
|
|
std::vector<int> shape;
|
|
|
|
if (!nb::isinstance<int>(shape_[0])) {
|
|
|
|
shape = nb::cast<std::vector<int>>(shape_[0]);
|
|
|
|
} else {
|
|
|
|
shape = nb::cast<std::vector<int>>(shape_);
|
2023-11-30 02:42:59 +08:00
|
|
|
}
|
2024-03-19 11:12:25 +08:00
|
|
|
return reshape(a, shape, s);
|
2023-11-30 02:42:59 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"shape"_a,
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Equivalent to :func:`reshape` but the shape can be passed either as a
|
2024-03-19 11:12:25 +08:00
|
|
|
:obj:`tuple` or as separate arguments.
|
2023-11-30 02:42:59 +08:00
|
|
|
|
|
|
|
See :func:`reshape` for full documentation.
|
|
|
|
)pbdoc")
|
|
|
|
.def(
|
|
|
|
"squeeze",
|
|
|
|
[](const array& a, const IntOrVec& v, const StreamOrDevice& s) {
|
|
|
|
if (std::holds_alternative<std::monostate>(v)) {
|
|
|
|
return squeeze(a, s);
|
|
|
|
} else if (auto pv = std::get_if<int>(&v); pv) {
|
|
|
|
return squeeze(a, *pv, s);
|
|
|
|
} else {
|
|
|
|
return squeeze(a, std::get<std::vector<int>>(v), s);
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
R"pbdoc(
|
|
|
|
See :func:`squeeze`.
|
|
|
|
)pbdoc")
|
|
|
|
// Element-wise unary math methods. Each forwards directly to the free function
// of the same name, with an optional keyword-only `stream`.
.def("abs", &mlx::core::abs, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`abs`.")
.def("__abs__", [](const array& a) { return mlx::core::abs(a); },
     "See :func:`abs`.")
.def("square", &mlx::core::square, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`square`.")
.def("sqrt", &mlx::core::sqrt, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`sqrt`.")
.def("rsqrt", &mlx::core::rsqrt, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`rsqrt`.")
.def("reciprocal", &mlx::core::reciprocal, nb::kw_only(),
     "stream"_a = nb::none(), "See :func:`reciprocal`.")
.def("exp", &mlx::core::exp, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`exp`.")
.def("log", &mlx::core::log, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`log`.")
.def("log2", &mlx::core::log2, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`log2`.")
.def("log10", &mlx::core::log10, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`log10`.")
.def("sin", &mlx::core::sin, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`sin`.")
.def("cos", &mlx::core::cos, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`cos`.")
.def("log1p", &mlx::core::log1p, nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`log1p`.")
|
|
|
|
.def(
|
|
|
|
"all",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`all`.")
|
|
|
|
.def(
|
|
|
|
"any",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`any`.")
|
2023-12-15 04:59:12 +08:00
|
|
|
// Axis permutation helpers bound directly to the free functions.
.def("moveaxis", &moveaxis,
     "source"_a, "destination"_a,
     nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`moveaxis`.")
.def("swapaxes", &swapaxes,
     "axis1"_a, "axis2"_a,
     nb::kw_only(), "stream"_a = nb::none(),
     "See :func:`swapaxes`.")
|
2023-11-30 02:42:59 +08:00
|
|
|
.def(
|
|
|
|
"transpose",
|
2024-03-19 11:12:25 +08:00
|
|
|
[](const array& a, nb::args axes_, StreamOrDevice s) {
|
|
|
|
if (axes_.size() == 0) {
|
2023-11-30 02:42:59 +08:00
|
|
|
return transpose(a, s);
|
|
|
|
}
|
2024-03-19 11:12:25 +08:00
|
|
|
std::vector<int> axes;
|
|
|
|
if (!nb::isinstance<int>(axes_[0])) {
|
|
|
|
axes = nb::cast<std::vector<int>>(axes_[0]);
|
|
|
|
} else {
|
|
|
|
axes = nb::cast<std::vector<int>>(axes_);
|
|
|
|
}
|
|
|
|
return transpose(a, axes, s);
|
2023-11-30 02:42:59 +08:00
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axes"_a,
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Equivalent to :func:`transpose` but the axes can be passed either as
|
|
|
|
a tuple or as separate arguments.
|
|
|
|
|
|
|
|
See :func:`transpose` for full documentation.
|
|
|
|
)pbdoc")
|
2024-03-19 11:12:25 +08:00
|
|
|
.def_prop_ro(
|
2023-11-30 02:42:59 +08:00
|
|
|
"T",
|
|
|
|
[](const array& a) { return transpose(a); },
|
|
|
|
"Equivalent to calling ``self.transpose()`` with no arguments.")
|
|
|
|
.def(
|
|
|
|
"sum",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`sum`.")
|
|
|
|
.def(
|
|
|
|
"prod",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`prod`.")
|
|
|
|
.def(
|
|
|
|
"min",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`min`.")
|
|
|
|
.def(
|
|
|
|
"max",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`max`.")
|
|
|
|
.def(
|
|
|
|
"logsumexp",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`logsumexp`.")
|
|
|
|
.def(
|
|
|
|
"mean",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`mean`.")
|
|
|
|
.def(
|
|
|
|
"var",
|
|
|
|
[](const array& a,
|
|
|
|
const IntOrVec& axis,
|
|
|
|
bool keepdims,
|
|
|
|
int ddof,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
|
|
|
"ddof"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`var`.")
|
|
|
|
.def(
|
|
|
|
"split",
|
|
|
|
[](const array& a,
|
|
|
|
const std::variant<int, std::vector<int>>& indices_or_sections,
|
|
|
|
int axis,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
|
|
|
|
return split(a, *pv, axis, s);
|
|
|
|
} else {
|
|
|
|
return split(
|
|
|
|
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"indices_or_sections"_a,
|
|
|
|
"axis"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`split`.")
|
|
|
|
.def(
|
|
|
|
"argmin",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return argmin(a, *axis, keepdims, s);
|
|
|
|
} else {
|
|
|
|
return argmin(a, keepdims, s);
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"axis"_a = std::nullopt,
|
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`argmin`.")
|
|
|
|
.def(
|
|
|
|
"argmax",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool keepdims,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return argmax(a, *axis, keepdims, s);
|
|
|
|
} else {
|
|
|
|
return argmax(a, keepdims, s);
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"keepdims"_a = false,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`argmax`.")
|
|
|
|
.def(
|
|
|
|
"cumsum",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cumsum(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
// TODO: Implement that in the C++ API as well. See concatenate
|
|
|
|
// above.
|
|
|
|
return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`cumsum`.")
|
|
|
|
.def(
|
|
|
|
"cumprod",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cumprod(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
// TODO: Implement that in the C++ API as well. See concatenate
|
|
|
|
// above.
|
|
|
|
return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`cumprod`.")
|
|
|
|
.def(
|
|
|
|
"cummax",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cummax(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
// TODO: Implement that in the C++ API as well. See concatenate
|
|
|
|
// above.
|
|
|
|
return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"See :func:`cummax`.")
|
|
|
|
.def(
|
|
|
|
"cummin",
|
|
|
|
[](const array& a,
|
|
|
|
std::optional<int> axis,
|
|
|
|
bool reverse,
|
|
|
|
bool inclusive,
|
|
|
|
StreamOrDevice s) {
|
|
|
|
if (axis) {
|
|
|
|
return cummin(a, *axis, reverse, inclusive, s);
|
|
|
|
} else {
|
|
|
|
// TODO: Implement that in the C++ API as well. See concatenate
|
|
|
|
// above.
|
|
|
|
return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s);
|
|
|
|
}
|
|
|
|
},
|
2024-03-19 11:12:25 +08:00
|
|
|
"axis"_a = nb::none(),
|
|
|
|
nb::kw_only(),
|
2023-11-30 02:42:59 +08:00
|
|
|
"reverse"_a = false,
|
|
|
|
"inclusive"_a = true,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
2023-12-19 03:32:48 +08:00
|
|
|
"See :func:`cummin`.")
|
|
|
|
.def(
|
|
|
|
"round",
|
|
|
|
[](const array& a, int decimals, StreamOrDevice s) {
|
|
|
|
return round(a, decimals, s);
|
|
|
|
},
|
|
|
|
"decimals"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2024-01-31 01:45:48 +08:00
|
|
|
"See :func:`round`.")
|
|
|
|
.def(
|
|
|
|
"diagonal",
|
|
|
|
[](const array& a,
|
|
|
|
int offset,
|
|
|
|
int axis1,
|
|
|
|
int axis2,
|
|
|
|
StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); },
|
|
|
|
"offset"_a = 0,
|
|
|
|
"axis1"_a = 0,
|
|
|
|
"axis2"_a = 1,
|
2024-03-19 11:12:25 +08:00
|
|
|
"stream"_a = nb::none(),
|
2024-01-31 01:45:48 +08:00
|
|
|
"See :func:`diagonal`.")
|
|
|
|
.def(
|
|
|
|
"diag",
|
|
|
|
[](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); },
|
|
|
|
"k"_a = 0,
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::kw_only(),
|
|
|
|
"stream"_a = nb::none(),
|
2024-01-31 01:45:48 +08:00
|
|
|
R"pbdoc(
|
|
|
|
Extract a diagonal or construct a diagonal matrix.
|
|
|
|
)pbdoc");
|
2024-04-01 21:27:52 +08:00
|
|
|
}
|