Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng 2024-12-12 08:45:39 +09:00 committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1423 additions and 1302 deletions

File diff suppressed because it is too large Load Diff

View File

@ -14,37 +14,37 @@
#define Py_bf_releasebuffer 2 #define Py_bf_releasebuffer 2
#endif #endif
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace mlx::core;
std::string buffer_format(const array& a) { std::string buffer_format(const mx::array& a) {
// https://docs.python.org/3.10/library/struct.html#format-characters // https://docs.python.org/3.10/library/struct.html#format-characters
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case mx::bool_:
return "?"; return "?";
case uint8: case mx::uint8:
return "B"; return "B";
case uint16: case mx::uint16:
return "H"; return "H";
case uint32: case mx::uint32:
return "I"; return "I";
case uint64: case mx::uint64:
return "Q"; return "Q";
case int8: case mx::int8:
return "b"; return "b";
case int16: case mx::int16:
return "h"; return "h";
case int32: case mx::int32:
return "i"; return "i";
case int64: case mx::int64:
return "q"; return "q";
case float16: case mx::float16:
return "e"; return "e";
case float32: case mx::float32:
return "f"; return "f";
case bfloat16: case mx::bfloat16:
return "B"; return "B";
case complex64: case mx::complex64:
return "Zf\0"; return "Zf\0";
default: { default: {
std::ostringstream os; std::ostringstream os;
@ -84,7 +84,7 @@ struct buffer_info {
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
std::memset(view, 0, sizeof(Py_buffer)); std::memset(view, 0, sizeof(Py_buffer));
auto a = nb::cast<array>(nb::handle(obj)); auto a = nb::cast<mx::array>(nb::handle(obj));
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;

View File

@ -16,7 +16,7 @@ enum PyScalarT {
namespace nanobind { namespace nanobind {
template <> template <>
struct ndarray_traits<float16_t> { struct ndarray_traits<mx::float16_t> {
static constexpr bool is_complex = false; static constexpr bool is_complex = false;
static constexpr bool is_float = true; static constexpr bool is_float = true;
static constexpr bool is_bool = false; static constexpr bool is_bool = false;
@ -36,21 +36,21 @@ int check_shape_dim(int64_t dim) {
} }
// Builds an array from a read-only, C-contiguous, CPU ndarray.
//
// T is the element type used to interpret the ndarray's raw data pointer.
// `shape` and `dtype` are passed through unchanged to the array constructor.
// NOTE(review): the existing comment below says the buffer is copied; that
// depends on the array constructor, which is not visible here -- confirm.
template <typename T> template <typename T>
array nd_array_to_mlx_contiguous( mx::array nd_array_to_mlx_contiguous(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
const Shape& shape, const mx::Shape& shape,
Dtype dtype) { mx::Dtype dtype) {
// Make a copy of the numpy buffer // Make a copy of the numpy buffer
// Get buffer ptr pass to array constructor // Get buffer ptr pass to array constructor
auto data_ptr = nd_array.data(); auto data_ptr = nd_array.data();
// Reinterpret the raw data pointer as a pointer to const T before handing it
// to the array constructor.
return array(static_cast<const T*>(data_ptr), shape, dtype); return mx::array(static_cast<const T*>(data_ptr), shape, dtype);
} }
array nd_array_to_mlx( mx::array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype) { std::optional<mx::Dtype> dtype) {
// Compute the shape and size // Compute the shape and size
Shape shape; mx::Shape shape;
for (int i = 0; i < nd_array.ndim(); i++) { for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i))); shape.push_back(check_shape_dim(nd_array.shape(i)));
} }
@ -59,49 +59,49 @@ array nd_array_to_mlx(
// Copy data and make array // Copy data and make array
if (type == nb::dtype<bool>()) { if (type == nb::dtype<bool>()) {
return nd_array_to_mlx_contiguous<bool>( return nd_array_to_mlx_contiguous<bool>(
nd_array, shape, dtype.value_or(bool_)); nd_array, shape, dtype.value_or(mx::bool_));
} else if (type == nb::dtype<uint8_t>()) { } else if (type == nb::dtype<uint8_t>()) {
return nd_array_to_mlx_contiguous<uint8_t>( return nd_array_to_mlx_contiguous<uint8_t>(
nd_array, shape, dtype.value_or(uint8)); nd_array, shape, dtype.value_or(mx::uint8));
} else if (type == nb::dtype<uint16_t>()) { } else if (type == nb::dtype<uint16_t>()) {
return nd_array_to_mlx_contiguous<uint16_t>( return nd_array_to_mlx_contiguous<uint16_t>(
nd_array, shape, dtype.value_or(uint16)); nd_array, shape, dtype.value_or(mx::uint16));
} else if (type == nb::dtype<uint32_t>()) { } else if (type == nb::dtype<uint32_t>()) {
return nd_array_to_mlx_contiguous<uint32_t>( return nd_array_to_mlx_contiguous<uint32_t>(
nd_array, shape, dtype.value_or(uint32)); nd_array, shape, dtype.value_or(mx::uint32));
} else if (type == nb::dtype<uint64_t>()) { } else if (type == nb::dtype<uint64_t>()) {
return nd_array_to_mlx_contiguous<uint64_t>( return nd_array_to_mlx_contiguous<uint64_t>(
nd_array, shape, dtype.value_or(uint64)); nd_array, shape, dtype.value_or(mx::uint64));
} else if (type == nb::dtype<int8_t>()) { } else if (type == nb::dtype<int8_t>()) {
return nd_array_to_mlx_contiguous<int8_t>( return nd_array_to_mlx_contiguous<int8_t>(
nd_array, shape, dtype.value_or(int8)); nd_array, shape, dtype.value_or(mx::int8));
} else if (type == nb::dtype<int16_t>()) { } else if (type == nb::dtype<int16_t>()) {
return nd_array_to_mlx_contiguous<int16_t>( return nd_array_to_mlx_contiguous<int16_t>(
nd_array, shape, dtype.value_or(int16)); nd_array, shape, dtype.value_or(mx::int16));
} else if (type == nb::dtype<int32_t>()) { } else if (type == nb::dtype<int32_t>()) {
return nd_array_to_mlx_contiguous<int32_t>( return nd_array_to_mlx_contiguous<int32_t>(
nd_array, shape, dtype.value_or(int32)); nd_array, shape, dtype.value_or(mx::int32));
} else if (type == nb::dtype<int64_t>()) { } else if (type == nb::dtype<int64_t>()) {
return nd_array_to_mlx_contiguous<int64_t>( return nd_array_to_mlx_contiguous<int64_t>(
nd_array, shape, dtype.value_or(int64)); nd_array, shape, dtype.value_or(mx::int64));
} else if (type == nb::dtype<float16_t>()) { } else if (type == nb::dtype<mx::float16_t>()) {
return nd_array_to_mlx_contiguous<float16_t>( return nd_array_to_mlx_contiguous<mx::float16_t>(
nd_array, shape, dtype.value_or(float16)); nd_array, shape, dtype.value_or(mx::float16));
} else if (type == nb::bfloat16) { } else if (type == nb::bfloat16) {
return nd_array_to_mlx_contiguous<bfloat16_t>( return nd_array_to_mlx_contiguous<mx::bfloat16_t>(
nd_array, shape, dtype.value_or(bfloat16)); nd_array, shape, dtype.value_or(mx::bfloat16));
} else if (type == nb::dtype<float>()) { } else if (type == nb::dtype<float>()) {
return nd_array_to_mlx_contiguous<float>( return nd_array_to_mlx_contiguous<float>(
nd_array, shape, dtype.value_or(float32)); nd_array, shape, dtype.value_or(mx::float32));
} else if (type == nb::dtype<double>()) { } else if (type == nb::dtype<double>()) {
return nd_array_to_mlx_contiguous<double>( return nd_array_to_mlx_contiguous<double>(
nd_array, shape, dtype.value_or(float32)); nd_array, shape, dtype.value_or(mx::float32));
} else if (type == nb::dtype<std::complex<float>>()) { } else if (type == nb::dtype<std::complex<float>>()) {
return nd_array_to_mlx_contiguous<complex64_t>( return nd_array_to_mlx_contiguous<mx::complex64_t>(
nd_array, shape, dtype.value_or(complex64)); nd_array, shape, dtype.value_or(mx::complex64));
} else if (type == nb::dtype<std::complex<double>>()) { } else if (type == nb::dtype<std::complex<double>>()) {
return nd_array_to_mlx_contiguous<complex128_t>( return nd_array_to_mlx_contiguous<mx::complex128_t>(
nd_array, shape, dtype.value_or(complex64)); nd_array, shape, dtype.value_or(mx::complex64));
} else { } else {
throw std::invalid_argument("Cannot convert numpy array to mlx array."); throw std::invalid_argument("Cannot convert numpy array to mlx array.");
} }
@ -109,7 +109,7 @@ array nd_array_to_mlx(
template <typename T, typename... NDParams> template <typename T, typename... NDParams>
nb::ndarray<NDParams...> mlx_to_nd_array_impl( nb::ndarray<NDParams...> mlx_to_nd_array_impl(
array a, mx::array a,
std::optional<nb::dlpack::dtype> t = {}) { std::optional<nb::dlpack::dtype> t = {}) {
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
@ -126,48 +126,48 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
} }
// Converts an array to an nb::ndarray<NDParams...> by dispatching on the
// array's dtype.
//
// Each supported dtype forwards to mlx_to_nd_array_impl instantiated with the
// matching C++ element type. bfloat16 is rejected explicitly with an
// nb::type_error, and any dtype without a case falls through to the default
// branch, which also throws nb::type_error.
template <typename... NDParams> template <typename... NDParams>
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) { nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case mx::bool_:
return mlx_to_nd_array_impl<bool, NDParams...>(a); return mlx_to_nd_array_impl<bool, NDParams...>(a);
case uint8: case mx::uint8:
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a); return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
case uint16: case mx::uint16:
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a); return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
case uint32: case mx::uint32:
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a); return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
case uint64: case mx::uint64:
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a); return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
case int8: case mx::int8:
return mlx_to_nd_array_impl<int8_t, NDParams...>(a); return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
case int16: case mx::int16:
return mlx_to_nd_array_impl<int16_t, NDParams...>(a); return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
case int32: case mx::int32:
return mlx_to_nd_array_impl<int32_t, NDParams...>(a); return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
case int64: case mx::int64:
return mlx_to_nd_array_impl<int64_t, NDParams...>(a); return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
case float16: case mx::float16:
return mlx_to_nd_array_impl<float16_t, NDParams...>(a); return mlx_to_nd_array_impl<mx::float16_t, NDParams...>(a);
// bfloat16 gets its own case so the error message names the dtype.
case bfloat16: case mx::bfloat16:
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy."); throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
case float32: case mx::float32:
return mlx_to_nd_array_impl<float, NDParams...>(a); return mlx_to_nd_array_impl<float, NDParams...>(a);
case complex64: case mx::complex64:
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a); return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
default: default:
throw nb::type_error("type cannot be converted to NumPy."); throw nb::type_error("type cannot be converted to NumPy.");
} }
} }
// Converts `a` to an nb::ndarray tagged with nb::numpy by forwarding to
// mlx_to_nd_array<nb::numpy>; any exception raised there propagates.
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) { nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a) {
return mlx_to_nd_array<nb::numpy>(a); return mlx_to_nd_array<nb::numpy>(a);
} }
// Converts `a` to an nb::ndarray<> with no template parameters by forwarding
// to mlx_to_nd_array<>; any exception raised there propagates.
nb::ndarray<> mlx_to_dlpack(const array& a) { nb::ndarray<> mlx_to_dlpack(const mx::array& a) {
return mlx_to_nd_array<>(a); return mlx_to_nd_array<>(a);
} }
nb::object to_scalar(array& a) { nb::object to_scalar(mx::array& a) {
if (a.size() != 1) { if (a.size() != 1) {
throw std::invalid_argument( throw std::invalid_argument(
"[convert] Only length-1 arrays can be converted to Python scalars."); "[convert] Only length-1 arrays can be converted to Python scalars.");
@ -177,31 +177,31 @@ nb::object to_scalar(array& a) {
a.eval(); a.eval();
} }
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case mx::bool_:
return nb::cast(a.item<bool>()); return nb::cast(a.item<bool>());
case uint8: case mx::uint8:
return nb::cast(a.item<uint8_t>()); return nb::cast(a.item<uint8_t>());
case uint16: case mx::uint16:
return nb::cast(a.item<uint16_t>()); return nb::cast(a.item<uint16_t>());
case uint32: case mx::uint32:
return nb::cast(a.item<uint32_t>()); return nb::cast(a.item<uint32_t>());
case uint64: case mx::uint64:
return nb::cast(a.item<uint64_t>()); return nb::cast(a.item<uint64_t>());
case int8: case mx::int8:
return nb::cast(a.item<int8_t>()); return nb::cast(a.item<int8_t>());
case int16: case mx::int16:
return nb::cast(a.item<int16_t>()); return nb::cast(a.item<int16_t>());
case int32: case mx::int32:
return nb::cast(a.item<int32_t>()); return nb::cast(a.item<int32_t>());
case int64: case mx::int64:
return nb::cast(a.item<int64_t>()); return nb::cast(a.item<int64_t>());
case float16: case mx::float16:
return nb::cast(static_cast<float>(a.item<float16_t>())); return nb::cast(static_cast<float>(a.item<mx::float16_t>()));
case float32: case mx::float32:
return nb::cast(a.item<float>()); return nb::cast(a.item<float>());
case bfloat16: case mx::bfloat16:
return nb::cast(static_cast<float>(a.item<bfloat16_t>())); return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
case complex64: case mx::complex64:
return nb::cast(a.item<std::complex<float>>()); return nb::cast(a.item<std::complex<float>>());
default: default:
throw nb::type_error("type cannot be converted to Python scalar."); throw nb::type_error("type cannot be converted to Python scalar.");
@ -209,7 +209,7 @@ nb::object to_scalar(array& a) {
} }
template <typename T, typename U = T> template <typename T, typename U = T>
nb::list to_list(array& a, size_t index, int dim) { nb::list to_list(mx::array& a, size_t index, int dim) {
nb::list pl; nb::list pl;
auto stride = a.strides()[dim]; auto stride = a.strides()[dim];
for (int i = 0; i < a.shape(dim); ++i) { for (int i = 0; i < a.shape(dim); ++i) {
@ -223,7 +223,7 @@ nb::list to_list(array& a, size_t index, int dim) {
return pl; return pl;
} }
nb::object tolist(array& a) { nb::object tolist(mx::array& a) {
if (a.ndim() == 0) { if (a.ndim() == 0) {
return to_scalar(a); return to_scalar(a);
} }
@ -232,31 +232,31 @@ nb::object tolist(array& a) {
a.eval(); a.eval();
} }
switch (a.dtype()) { switch (a.dtype()) {
case bool_: case mx::bool_:
return to_list<bool>(a, 0, 0); return to_list<bool>(a, 0, 0);
case uint8: case mx::uint8:
return to_list<uint8_t>(a, 0, 0); return to_list<uint8_t>(a, 0, 0);
case uint16: case mx::uint16:
return to_list<uint16_t>(a, 0, 0); return to_list<uint16_t>(a, 0, 0);
case uint32: case mx::uint32:
return to_list<uint32_t>(a, 0, 0); return to_list<uint32_t>(a, 0, 0);
case uint64: case mx::uint64:
return to_list<uint64_t>(a, 0, 0); return to_list<uint64_t>(a, 0, 0);
case int8: case mx::int8:
return to_list<int8_t>(a, 0, 0); return to_list<int8_t>(a, 0, 0);
case int16: case mx::int16:
return to_list<int16_t>(a, 0, 0); return to_list<int16_t>(a, 0, 0);
case int32: case mx::int32:
return to_list<int32_t>(a, 0, 0); return to_list<int32_t>(a, 0, 0);
case int64: case mx::int64:
return to_list<int64_t>(a, 0, 0); return to_list<int64_t>(a, 0, 0);
case float16: case mx::float16:
return to_list<float16_t, float>(a, 0, 0); return to_list<mx::float16_t, float>(a, 0, 0);
case float32: case mx::float32:
return to_list<float>(a, 0, 0); return to_list<float>(a, 0, 0);
case bfloat16: case mx::bfloat16:
return to_list<bfloat16_t, float>(a, 0, 0); return to_list<mx::bfloat16_t, float>(a, 0, 0);
case complex64: case mx::complex64:
return to_list<std::complex<float>>(a, 0, 0); return to_list<std::complex<float>>(a, 0, 0);
default: default:
throw nb::type_error("data type cannot be converted to Python list."); throw nb::type_error("data type cannot be converted to Python list.");
@ -279,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
template <typename T> template <typename T>
PyScalarT validate_shape( PyScalarT validate_shape(
T list, T list,
const Shape& shape, const mx::Shape& shape,
int idx, int idx,
bool& all_python_primitive_elements) { bool& all_python_primitive_elements) {
if (idx >= shape.size()) { if (idx >= shape.size()) {
@ -307,9 +307,9 @@ PyScalarT validate_shape(
shape, shape,
idx + 1, idx + 1,
all_python_primitive_elements); all_python_primitive_elements);
} else if (nb::isinstance<array>(l)) { } else if (nb::isinstance<mx::array>(l)) {
all_python_primitive_elements = false; all_python_primitive_elements = false;
auto arr = nb::cast<array>(l); auto arr = nb::cast<mx::array>(l);
if (arr.ndim() + idx + 1 == shape.size() && if (arr.ndim() + idx + 1 == shape.size() &&
std::equal( std::equal(
arr.shape().cbegin(), arr.shape().cbegin(),
@ -347,7 +347,7 @@ PyScalarT validate_shape(
} }
template <typename T> template <typename T>
void get_shape(T list, Shape& shape) { void get_shape(T list, mx::Shape& shape) {
shape.push_back(check_shape_dim(nb::len(list))); shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) { if (shape.back() > 0) {
auto l = list.begin(); auto l = list.begin();
@ -355,8 +355,8 @@ void get_shape(T list, Shape& shape) {
return get_shape(nb::cast<nb::list>(*l), shape); return get_shape(nb::cast<nb::list>(*l), shape);
} else if (nb::isinstance<nb::tuple>(*l)) { } else if (nb::isinstance<nb::tuple>(*l)) {
return get_shape(nb::cast<nb::tuple>(*l), shape); return get_shape(nb::cast<nb::tuple>(*l), shape);
} else if (nb::isinstance<array>(*l)) { } else if (nb::isinstance<mx::array>(*l)) {
auto arr = nb::cast<array>(*l); auto arr = nb::cast<mx::array>(*l);
for (int i = 0; i < arr.ndim(); i++) { for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(arr.shape(i)); shape.push_back(arr.shape(i));
} }
@ -366,54 +366,55 @@ void get_shape(T list, Shape& shape) {
} }
template <typename T> template <typename T>
array array_from_list_impl( mx::array array_from_list_impl(
T pl, T pl,
const PyScalarT& inferred_type, const PyScalarT& inferred_type,
std::optional<Dtype> specified_type, std::optional<mx::Dtype> specified_type,
const Shape& shape) { const mx::Shape& shape) {
// Make the array // Make the array
switch (inferred_type) { switch (inferred_type) {
case pybool: { case pybool: {
std::vector<bool> vals; std::vector<bool> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(bool_)); return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_));
} }
case pyint: { case pyint: {
auto dtype = specified_type.value_or(int32); auto dtype = specified_type.value_or(mx::int32);
if (dtype == int64) { if (dtype == mx::int64) {
std::vector<int64_t> vals; std::vector<int64_t> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, dtype); return mx::array(vals.begin(), shape, dtype);
} else if (dtype == uint64) { } else if (dtype == mx::uint64) {
std::vector<uint64_t> vals; std::vector<uint64_t> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, dtype); return mx::array(vals.begin(), shape, dtype);
} else if (dtype == uint32) { } else if (dtype == mx::uint32) {
std::vector<uint32_t> vals; std::vector<uint32_t> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, dtype); return mx::array(vals.begin(), shape, dtype);
} else if (issubdtype(dtype, inexact)) { } else if (mx::issubdtype(dtype, mx::inexact)) {
std::vector<float> vals; std::vector<float> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, dtype); return mx::array(vals.begin(), shape, dtype);
} else { } else {
std::vector<int> vals; std::vector<int> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, dtype); return mx::array(vals.begin(), shape, dtype);
} }
} }
case pyfloat: { case pyfloat: {
std::vector<float> vals; std::vector<float> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(float32)); return mx::array(
vals.begin(), shape, specified_type.value_or(mx::float32));
} }
case pycomplex: { case pycomplex: {
std::vector<std::complex<float>> vals; std::vector<std::complex<float>> vals;
fill_vector(pl, vals); fill_vector(pl, vals);
return array( return mx::array(
reinterpret_cast<complex64_t*>(vals.data()), reinterpret_cast<mx::complex64_t*>(vals.data()),
shape, shape,
specified_type.value_or(complex64)); specified_type.value_or(mx::complex64));
} }
default: { default: {
std::ostringstream msg; std::ostringstream msg;
@ -425,9 +426,9 @@ array array_from_list_impl(
} }
template <typename T> template <typename T>
array array_from_list_impl(T pl, std::optional<Dtype> dtype) { mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {
// Compute the shape // Compute the shape
Shape shape; mx::Shape shape;
get_shape(pl, shape); get_shape(pl, shape);
// Validate the shape and type // Validate the shape and type
@ -440,30 +441,31 @@ array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
} }
// `pl` contains mlx arrays // `pl` contains mlx arrays
std::vector<array> arrays; std::vector<mx::array> arrays;
for (auto l : pl) { for (auto l : pl) {
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype)); arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
} }
return stack(arrays); return mx::stack(arrays);
} }
// nb::list overload: forwards the list and the optional dtype to the
// array_from_list_impl template.
array array_from_list(nb::list pl, std::optional<Dtype> dtype) { mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype) {
return array_from_list_impl(pl, dtype); return array_from_list_impl(pl, dtype);
} }
// nb::tuple overload: forwards the tuple and the optional dtype to the
// array_from_list_impl template.
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype) { mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {
return array_from_list_impl(pl, dtype); return array_from_list_impl(pl, dtype);
} }
array create_array(ArrayInitType v, std::optional<Dtype> t) { mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) { if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), t.value_or(bool_)); return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) { } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
return array(nb::cast<int>(*pv), t.value_or(int32)); return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) { } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
return array(nb::cast<float>(*pv), t.value_or(float32)); return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64)); return mx::array(
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
} else if (auto pv = std::get_if<nb::list>(&v); pv) { } else if (auto pv = std::get_if<nb::list>(&v); pv) {
return array_from_list(*pv, t); return array_from_list(*pv, t);
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) { } else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
@ -472,10 +474,10 @@ array create_array(ArrayInitType v, std::optional<Dtype> t) {
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v); nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
pv) { pv) {
return nd_array_to_mlx(*pv, t); return nd_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<array>(&v); pv) { } else if (auto pv = std::get_if<mx::array>(&v); pv) {
return astype(*pv, t.value_or((*pv).dtype())); return mx::astype(*pv, t.value_or((*pv).dtype()));
} else { } else {
auto arr = to_array_with_accessor(std::get<nb::object>(v)); auto arr = to_array_with_accessor(std::get<nb::object>(v));
return astype(arr, t.value_or(arr.dtype())); return mx::astype(arr, t.value_or(arr.dtype()));
} }
} }

View File

@ -9,15 +9,15 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/ops.h" #include "mlx/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace mlx::core;
using ArrayInitType = std::variant< using ArrayInitType = std::variant<
nb::bool_, nb::bool_,
nb::int_, nb::int_,
nb::float_, nb::float_,
// Must be above ndarray // Must be above ndarray
array, mx::array,
// Must be above complex // Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>, nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>, std::complex<float>,
@ -25,17 +25,17 @@ using ArrayInitType = std::variant<
nb::tuple, nb::tuple,
nb::object>; nb::object>;
// Conversion entry points shared by the Python bindings.

// Builds an array from a read-only, C-contiguous, CPU ndarray; `dtype` is an
// optional override.
array nd_array_to_mlx( mx::array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array, nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype); std::optional<mx::Dtype> dtype);
// Array -> nb::ndarray conversions (NumPy-tagged and untagged respectively).
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a); nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);
nb::ndarray<> mlx_to_dlpack(const array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a);
// Array -> Python object conversions. Both take a non-const reference.
nb::object to_scalar(array& a); nb::object to_scalar(mx::array& a);
nb::object tolist(array& a); nb::object tolist(mx::array& a);
// Builds an array from any alternative of ArrayInitType; `t` is an optional
// dtype.
array create_array(ArrayInitType v, std::optional<Dtype> t); mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
// Builds an array from a Python list or tuple; `dtype` is optional.
array array_from_list(nb::list pl, std::optional<Dtype> dtype); mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype); mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);

View File

@ -8,51 +8,54 @@
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
// Registers the device-related Python bindings on module `m`:
//   - class "Device" (constructor, read-only `type`, __repr__, __eq__),
//   - enum "DeviceType" with values `cpu` and `gpu` (exported into `m`),
//   - an implicit conversion from DeviceType to Device,
//   - module functions "default_device" and "set_default_device".
void init_device(nb::module_& m) { void init_device(nb::module_& m) {
// The class object is created first and its methods are attached further
// down, after the DeviceType enum has been registered.
auto device_class = nb::class_<Device>( auto device_class = nb::class_<mx::Device>(
m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
// DeviceType.__eq__ returns false for any operand that is neither a Device
// nor a DeviceType; otherwise the operand is cast to Device and compared.
nb::enum_<Device::DeviceType>(m, "DeviceType") nb::enum_<mx::Device::DeviceType>(m, "DeviceType")
.value("cpu", Device::DeviceType::cpu) .value("cpu", mx::Device::DeviceType::cpu)
.value("gpu", Device::DeviceType::gpu) .value("gpu", mx::Device::DeviceType::gpu)
.export_values() .export_values()
.def("__eq__", [](const Device::DeviceType& d, const nb::object& other) { .def(
if (!nb::isinstance<Device>(other) && "__eq__",
!nb::isinstance<Device::DeviceType>(other)) { [](const mx::Device::DeviceType& d, const nb::object& other) {
return false; if (!nb::isinstance<mx::Device>(other) &&
} !nb::isinstance<mx::Device::DeviceType>(other)) {
return d == nb::cast<Device>(other); return false;
}); }
return d == nb::cast<mx::Device>(other);
});
// Device(type, index=0); `type` is exposed read-only. __repr__ streams the
// device into an ostringstream, and __eq__ follows the same rule as
// DeviceType.__eq__ above.
device_class.def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0) device_class
.def_ro("type", &Device::type) .def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_ro("type", &mx::Device::type)
.def( .def(
"__repr__", "__repr__",
[](const Device& d) { [](const mx::Device& d) {
std::ostringstream os; std::ostringstream os;
os << d; os << d;
return os.str(); return os.str();
}) })
.def("__eq__", [](const Device& d, const nb::object& other) { .def("__eq__", [](const mx::Device& d, const nb::object& other) {
if (!nb::isinstance<Device>(other) && if (!nb::isinstance<mx::Device>(other) &&
!nb::isinstance<Device::DeviceType>(other)) { !nb::isinstance<mx::Device::DeviceType>(other)) {
return false; return false;
} }
return d == nb::cast<Device>(other); return d == nb::cast<mx::Device>(other);
}); });
// Declares DeviceType as implicitly convertible to Device for the bindings.
nb::implicitly_convertible<Device::DeviceType, Device>(); nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
// Module-level accessors for the default device.
m.def( m.def(
"default_device", "default_device",
&default_device, &mx::default_device,
R"pbdoc(Get the default device.)pbdoc"); R"pbdoc(Get the default device.)pbdoc");
m.def( m.def(
"set_default_device", "set_default_device",
&set_default_device, &mx::set_default_device,
"device"_a, "device"_a,
R"pbdoc(Set the default device.)pbdoc"); R"pbdoc(Set the default device.)pbdoc");
} }

View File

@ -9,26 +9,27 @@
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h" #include "mlx/distributed/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
void init_distributed(nb::module_& parent_module) { void init_distributed(nb::module_& parent_module) {
auto m = parent_module.def_submodule( auto m = parent_module.def_submodule(
"distributed", "mlx.core.distributed: Communication operations"); "distributed", "mlx.core.distributed: Communication operations");
nb::class_<distributed::Group>( nb::class_<mx::distributed::Group>(
m, m,
"Group", "Group",
R"pbcopy( R"pbcopy(
An :class:`mlx.core.distributed.Group` represents a group of independent mlx An :class:`mlx.core.distributed.Group` represents a group of independent mlx
processes that can communicate. processes that can communicate.
)pbcopy") )pbcopy")
.def("rank", &distributed::Group::rank, "Get the rank of this process") .def(
.def("size", &distributed::Group::size, "Get the size of the group") "rank", &mx::distributed::Group::rank, "Get the rank of this process")
.def("size", &mx::distributed::Group::size, "Get the size of the group")
.def( .def(
"split", "split",
&distributed::Group::split, &mx::distributed::Group::split,
"color"_a, "color"_a,
"key"_a = -1, "key"_a = -1,
nb::sig("def split(self, color: int, key: int = -1) -> Group"), nb::sig("def split(self, color: int, key: int = -1) -> Group"),
@ -48,14 +49,14 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"is_available", "is_available",
&distributed::is_available, &mx::distributed::is_available,
R"pbdoc( R"pbdoc(
Check if a communication backend is available. Check if a communication backend is available.
)pbdoc"); )pbdoc");
m.def( m.def(
"init", "init",
&distributed::init, &mx::distributed::init,
"strict"_a = false, "strict"_a = false,
nb::sig("def init(strict: bool = False) -> Group"), nb::sig("def init(strict: bool = False) -> Group"),
R"pbdoc( R"pbdoc(
@ -72,7 +73,7 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"all_sum", "all_sum",
&distributed::all_sum, &mx::distributed::all_sum,
"x"_a, "x"_a,
nb::kw_only(), nb::kw_only(),
"group"_a = nb::none(), "group"_a = nb::none(),
@ -98,7 +99,7 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"all_gather", "all_gather",
&distributed::all_gather, &mx::distributed::all_gather,
"x"_a, "x"_a,
nb::kw_only(), nb::kw_only(),
"group"_a = nb::none(), "group"_a = nb::none(),
@ -125,7 +126,7 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"send", "send",
&distributed::send, &mx::distributed::send,
"x"_a, "x"_a,
"dst"_a, "dst"_a,
nb::kw_only(), nb::kw_only(),
@ -152,7 +153,7 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"recv", "recv",
&distributed::recv, &mx::distributed::recv,
"shape"_a, "shape"_a,
"dtype"_a, "dtype"_a,
"src"_a, "src"_a,
@ -181,7 +182,7 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"recv_like", "recv_like",
&distributed::recv_like, &mx::distributed::recv_like,
"x"_a, "x"_a,
"src"_a, "src"_a,
nb::kw_only(), nb::kw_only(),

View File

@ -13,9 +13,9 @@
#include "mlx/fast.h" #include "mlx/fast.h"
#include "mlx/ops.h" #include "mlx/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
void init_fast(nb::module_& parent_module) { void init_fast(nb::module_& parent_module) {
auto m = auto m =
@ -23,7 +23,7 @@ void init_fast(nb::module_& parent_module) {
m.def( m.def(
"rms_norm", "rms_norm",
&fast::rms_norm, &mx::fast::rms_norm,
"x"_a, "x"_a,
"weight"_a, "weight"_a,
"eps"_a, "eps"_a,
@ -49,7 +49,7 @@ void init_fast(nb::module_& parent_module) {
m.def( m.def(
"layer_norm", "layer_norm",
&fast::layer_norm, &mx::fast::layer_norm,
"x"_a, "x"_a,
"weight"_a.none(), "weight"_a.none(),
"bias"_a.none(), "bias"_a.none(),
@ -79,7 +79,7 @@ void init_fast(nb::module_& parent_module) {
m.def( m.def(
"rope", "rope",
&fast::rope, &mx::fast::rope,
"a"_a, "a"_a,
"dims"_a, "dims"_a,
nb::kw_only(), nb::kw_only(),
@ -114,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
m.def( m.def(
"scaled_dot_product_attention", "scaled_dot_product_attention",
&fast::scaled_dot_product_attention, &mx::fast::scaled_dot_product_attention,
"q"_a, "q"_a,
"k"_a, "k"_a,
"v"_a, "v"_a,
@ -170,7 +170,7 @@ void init_fast(nb::module_& parent_module) {
const std::string& header, const std::string& header,
bool ensure_row_contiguous, bool ensure_row_contiguous,
bool atomic_outputs) { bool atomic_outputs) {
auto kernel = fast::metal_kernel( auto kernel = mx::fast::metal_kernel(
name, name,
input_names, input_names,
output_names, output_names,
@ -182,7 +182,7 @@ void init_fast(nb::module_& parent_module) {
[kernel = std::move(kernel)]( [kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_, const std::vector<ScalarOrArray>& inputs_,
const std::vector<std::vector<int>>& output_shapes, const std::vector<std::vector<int>>& output_shapes,
const std::vector<Dtype>& output_dtypes, const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid, std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup, std::tuple<int, int, int> threadgroup,
const std::optional< const std::optional<
@ -190,12 +190,12 @@ void init_fast(nb::module_& parent_module) {
template_args_ = std::nullopt, template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt, std::optional<float> init_value = std::nullopt,
bool verbose = false, bool verbose = false,
StreamOrDevice s = {}) { mx::StreamOrDevice s = {}) {
std::vector<array> inputs; std::vector<mx::array> inputs;
for (const auto& value : inputs_) { for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt)); inputs.push_back(to_array(value, std::nullopt));
} }
std::vector<std::pair<std::string, fast::TemplateArg>> std::vector<std::pair<std::string, mx::fast::TemplateArg>>
template_args; template_args;
if (template_args_) { if (template_args_) {
for (const auto& [name, value] : template_args_.value()) { for (const auto& [name, value] : template_args_.value()) {
@ -206,8 +206,8 @@ void init_fast(nb::module_& parent_module) {
} else if (nb::isinstance<int>(value)) { } else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value); int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val); template_args.emplace_back(name, int_val);
} else if (nb::isinstance<Dtype>(value)) { } else if (nb::isinstance<mx::Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value); mx::Dtype dtype = nb::cast<mx::Dtype>(value);
template_args.emplace_back(name, dtype); template_args.emplace_back(name, dtype);
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -9,24 +9,23 @@
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/ops.h" #include "mlx/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
void init_fft(nb::module_& parent_module) { void init_fft(nb::module_& parent_module) {
auto m = parent_module.def_submodule( auto m = parent_module.def_submodule(
"fft", "mlx.core.fft: Fast Fourier Transforms."); "fft", "mlx.core.fft: Fast Fourier Transforms.");
m.def( m.def(
"fft", "fft",
[](const array& a, [](const mx::array& a,
const std::optional<int>& n, const std::optional<int>& n,
int axis, int axis,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (n.has_value()) { if (n.has_value()) {
return fft::fft(a, n.value(), axis, s); return mx::fft::fft(a, n.value(), axis, s);
} else { } else {
return fft::fft(a, axis, s); return mx::fft::fft(a, axis, s);
} }
}, },
"a"_a, "a"_a,
@ -49,14 +48,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"ifft", "ifft",
[](const array& a, [](const mx::array& a,
const std::optional<int>& n, const std::optional<int>& n,
int axis, int axis,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (n.has_value()) { if (n.has_value()) {
return fft::ifft(a, n.value(), axis, s); return mx::fft::ifft(a, n.value(), axis, s);
} else { } else {
return fft::ifft(a, axis, s); return mx::fft::ifft(a, axis, s);
} }
}, },
"a"_a, "a"_a,
@ -79,19 +78,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"fft2", "fft2",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s); return mx::fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s); return mx::fft::fftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[fft2] `axes` should not be `None` if `s` is not `None`."); "[fft2] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::fftn(a, s); return mx::fft::fftn(a, s);
} }
}, },
"a"_a, "a"_a,
@ -115,19 +114,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"ifft2", "ifft2",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s); return mx::fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s); return mx::fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[ifft2] `axes` should not be `None` if `s` is not `None`."); "[ifft2] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::ifftn(a, s); return mx::fft::ifftn(a, s);
} }
}, },
"a"_a, "a"_a,
@ -151,19 +150,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"fftn", "fftn",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s); return mx::fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s); return mx::fft::fftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[fftn] `axes` should not be `None` if `s` is not `None`."); "[fftn] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::fftn(a, s); return mx::fft::fftn(a, s);
} }
}, },
"a"_a, "a"_a,
@ -188,19 +187,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"ifftn", "ifftn",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s); return mx::fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s); return mx::fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[ifftn] `axes` should not be `None` if `s` is not `None`."); "[ifftn] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::ifftn(a, s); return mx::fft::ifftn(a, s);
} }
}, },
"a"_a, "a"_a,
@ -225,14 +224,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"rfft", "rfft",
[](const array& a, [](const mx::array& a,
const std::optional<int>& n, const std::optional<int>& n,
int axis, int axis,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (n.has_value()) { if (n.has_value()) {
return fft::rfft(a, n.value(), axis, s); return mx::fft::rfft(a, n.value(), axis, s);
} else { } else {
return fft::rfft(a, axis, s); return mx::fft::rfft(a, axis, s);
} }
}, },
"a"_a, "a"_a,
@ -260,14 +259,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"irfft", "irfft",
[](const array& a, [](const mx::array& a,
const std::optional<int>& n, const std::optional<int>& n,
int axis, int axis,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (n.has_value()) { if (n.has_value()) {
return fft::irfft(a, n.value(), axis, s); return mx::fft::irfft(a, n.value(), axis, s);
} else { } else {
return fft::irfft(a, axis, s); return mx::fft::irfft(a, axis, s);
} }
}, },
"a"_a, "a"_a,
@ -294,19 +293,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"rfft2", "rfft2",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s); return mx::fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s); return mx::fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[rfft2] `axes` should not be `None` if `s` is not `None`."); "[rfft2] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::rfftn(a, s); return mx::fft::rfftn(a, s);
} }
}, },
"a"_a, "a"_a,
@ -336,19 +335,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"irfft2", "irfft2",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::irfftn(a, n.value(), axes.value(), s); return mx::fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s); return mx::fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[irfft2] `axes` should not be `None` if `s` is not `None`."); "[irfft2] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::irfftn(a, s); return mx::fft::irfftn(a, s);
} }
}, },
"a"_a, "a"_a,
@ -378,19 +377,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"rfftn", "rfftn",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s); return mx::fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s); return mx::fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[rfftn] `axes` should not be `None` if `s` is not `None`."); "[rfftn] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::rfftn(a, s); return mx::fft::rfftn(a, s);
} }
}, },
"a"_a, "a"_a,
@ -420,19 +419,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"irfftn", "irfftn",
[](const array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
return fft::irfftn(a, n.value(), axes.value(), s); return mx::fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) { } else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s); return mx::fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) { } else if (n.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[irfftn] `axes` should not be `None` if `s` is not `None`."); "[irfftn] `axes` should not be `None` if `s` is not `None`.");
} else { } else {
return fft::irfftn(a, s); return mx::fft::irfftn(a, s);
} }
}, },
"a"_a, "a"_a,

View File

@ -43,20 +43,20 @@ void get_slice_params(
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
} }
array get_int_index(nb::object idx, int axis_size) { mx::array get_int_index(nb::object idx, int axis_size) {
int idx_ = nb::cast<int>(idx); int idx_ = nb::cast<int>(idx);
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_; idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
return array(idx_, uint32); return mx::array(idx_, mx::uint32);
} }
bool is_valid_index_type(const nb::object& obj) { bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) || return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) || nb::isinstance<mx::array>(obj) || obj.is_none() ||
nb::isinstance<nb::list>(obj); nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
} }
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) { mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
// Check input and raise error if 0 dim for parity with np // Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
@ -77,14 +77,14 @@ array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
return slice(src, starts, ends, strides); return slice(src, starts, ends, strides);
} }
array mlx_get_item_array(const array& src, const array& indices) { mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
// Check input and raise error if 0 dim for parity with np // Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"too many indices for array: array is 0-dimensional"); "too many indices for array: array is 0-dimensional");
} }
if (indices.dtype() == bool_) { if (indices.dtype() == mx::bool_) {
throw std::invalid_argument("boolean indices are not yet supported"); throw std::invalid_argument("boolean indices are not yet supported");
} }
@ -93,7 +93,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
return take(src, indices, 0); return take(src, indices, 0);
} }
array mlx_get_item_int(const array& src, const nb::int_& idx) { mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) {
// Check input and raise error if 0 dim for parity with np // Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
@ -105,13 +105,13 @@ array mlx_get_item_int(const array& src, const nb::int_& idx) {
return take(src, get_int_index(idx, src.shape(0)), 0); return take(src, get_int_index(idx, src.shape(0)), 0);
} }
array mlx_gather_nd( mx::array mlx_gather_nd(
array src, mx::array src,
const std::vector<nb::object>& indices, const std::vector<nb::object>& indices,
bool gather_first, bool gather_first,
int& max_dims) { int& max_dims) {
max_dims = 0; max_dims = 0;
std::vector<array> gather_indices; std::vector<mx::array> gather_indices;
std::vector<bool> is_slice(indices.size(), false); std::vector<bool> is_slice(indices.size(), false);
int num_slices = 0; int num_slices = 0;
// gather all the arrays // gather all the arrays
@ -127,13 +127,13 @@ array mlx_gather_nd(
start = (start < 0) ? start + src.shape(i) : start; start = (start < 0) ? start + src.shape(i) : start;
end = (end < 0) ? end + src.shape(i) : end; end = (end < 0) ? end + src.shape(i) : end;
gather_indices.push_back(arange(start, end, stride, uint32)); gather_indices.push_back(arange(start, end, stride, mx::uint32));
num_slices++; num_slices++;
is_slice[i] = true; is_slice[i] = true;
} else if (nb::isinstance<nb::int_>(idx)) { } else if (nb::isinstance<nb::int_>(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i))); gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (nb::isinstance<array>(idx)) { } else if (nb::isinstance<mx::array>(idx)) {
auto arr = nb::cast<array>(idx); auto arr = nb::cast<mx::array>(idx);
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims); max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
gather_indices.push_back(arr); gather_indices.push_back(arr);
} }
@ -144,7 +144,7 @@ array mlx_gather_nd(
int slice_index = 0; int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) { for (int i = 0; i < gather_indices.size(); i++) {
if (is_slice[i]) { if (is_slice[i]) {
Shape index_shape(max_dims + num_slices, 1); mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0); index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
slice_index++; slice_index++;
@ -158,7 +158,7 @@ array mlx_gather_nd(
// reshape them so that the int/array indices are last // reshape them so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) { for (int i = 0; i < gather_indices.size(); i++) {
if (i < num_slices) { if (i < num_slices) {
Shape index_shape(max_dims + num_slices, 1); mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0); index_shape[i] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
} }
@ -241,7 +241,7 @@ auto mlx_expand_ellipsis(
return std::make_pair(non_none_indices, indices); return std::make_pair(non_none_indices, indices);
} }
array mlx_get_item_nd(array src, const nb::tuple& entries) { mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// No indices make this a noop // No indices make this a noop
if (entries.size() == 0) { if (entries.size() == 0) {
return src; return src;
@ -281,7 +281,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
bool have_non_array = false; bool have_non_array = false;
bool gather_first = false; bool gather_first = false;
for (auto& idx : indices) { for (auto& idx : indices) {
if (nb::isinstance<array>(idx) || (nb::isinstance<nb::int_>(idx))) { if (nb::isinstance<mx::array>(idx) || (nb::isinstance<nb::int_>(idx))) {
if (have_array && have_non_array) { if (have_array && have_non_array) {
gather_first = true; gather_first = true;
break; break;
@ -294,7 +294,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
int n_arr = 0; int n_arr = 0;
for (auto& idx : indices) { for (auto& idx : indices) {
n_arr += nb::isinstance<array>(idx); n_arr += nb::isinstance<mx::array>(idx);
} }
have_array &= n_arr > 0; have_array &= n_arr > 0;
@ -304,7 +304,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
// Then find the last array // Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) { for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array]; auto& idx = indices[last_array];
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) { if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break; break;
} }
} }
@ -340,7 +340,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
} else { } else {
for (int i = 0; i < indices.size(); i++) { for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i]; auto& idx = indices[i];
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) { if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break; break;
} else if (idx.is_none()) { } else if (idx.is_none()) {
remaining_indices.push_back(idx); remaining_indices.push_back(idx);
@ -426,11 +426,11 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
return src; return src;
} }
array mlx_get_item(const array& src, const nb::object& obj) { mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
if (nb::isinstance<nb::slice>(obj)) { if (nb::isinstance<nb::slice>(obj)) {
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj)); return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
} else if (nb::isinstance<array>(obj)) { } else if (nb::isinstance<mx::array>(obj)) {
return mlx_get_item_array(src, nb::cast<array>(obj)); return mlx_get_item_array(src, nb::cast<mx::array>(obj));
} else if (nb::isinstance<nb::int_>(obj)) { } else if (nb::isinstance<nb::int_>(obj)) {
return mlx_get_item_int(src, nb::cast<nb::int_>(obj)); return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
} else if (nb::isinstance<nb::tuple>(obj)) { } else if (nb::isinstance<nb::tuple>(obj)) {
@ -448,10 +448,11 @@ array mlx_get_item(const array& src, const nb::object& obj) {
throw std::invalid_argument("Cannot index mlx array using the given type."); throw std::invalid_argument("Cannot index mlx array using the given type.");
} }
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int( std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
const array& src, mlx_scatter_args_int(
const mx::array& src,
const nb::int_& idx, const nb::int_& idx,
const array& update) { const mx::array& update) {
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"too many indices for array: array is 0-dimensional"); "too many indices for array: array is 0-dimensional");
@ -473,10 +474,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
{0}}; {0}};
} }
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array( std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
const array& src, mlx_scatter_args_array(
const array& indices, const mx::array& src,
const array& update) { const mx::array& indices,
const mx::array& update) {
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
"too many indices for array: array is 0-dimensional"); "too many indices for array: array is 0-dimensional");
@ -500,10 +502,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
return {{indices}, up, {0}}; return {{indices}, up, {0}};
} }
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice( std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
const array& src, mlx_scatter_args_slice(
const mx::array& src,
const nb::slice& in_slice, const nb::slice& in_slice,
const array& update) { const mx::array& update) {
// Check input and raise error if 0 dim for parity with np // Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
@ -539,7 +542,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
auto up = reshape(update, up_shape); auto up = reshape(update, up_shape);
// Build array to mark start of slice // Build array to mark start of slice
auto idx = array({start}, {1}, uint32); auto idx = mx::array({start}, {1}, mx::uint32);
// Get slice size // Get slice size
int slice_size = (end - start); int slice_size = (end - start);
@ -551,20 +554,21 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
up = broadcast_to(up, up_shape_broadcast); up = broadcast_to(up, up_shape_broadcast);
auto indices = std::vector<array>{idx}; auto indices = std::vector<mx::array>{idx};
auto axes = std::vector<int>{0}; auto axes = std::vector<int>{0};
return {indices, up, axes}; return {indices, up, axes};
} }
return mlx_scatter_args_array( return mlx_scatter_args_array(
src, arange(start, end, stride, uint32), update); src, arange(start, end, stride, mx::uint32), update);
} }
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd( std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
const array& src, mlx_scatter_args_nd(
const mx::array& src,
const nb::tuple& entries, const nb::tuple& entries,
const array& update) { const mx::array& update) {
// Expand ellipses into a series of ':' slices // Expand ellipses into a series of ':' slices
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
@ -623,12 +627,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
num_simple_slices_post++; num_simple_slices_post++;
} }
} else if (nb::isinstance<array>(idx)) { } else if (nb::isinstance<mx::array>(idx)) {
have_array = true; have_array = true;
if (have_array && have_non_array) { if (have_array && have_non_array) {
arrays_first = true; arrays_first = true;
} }
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim); max_dim = std::max(nb::cast<mx::array>(idx).ndim(), max_dim);
num_arrays++; num_arrays++;
num_simple_slices_post = 0; num_simple_slices_post = 0;
} }
@ -643,7 +647,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim; idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
// Go over each index type and translate to the needed scatter args // Go over each index type and translate to the needed scatter args
std::vector<array> arr_indices; std::vector<mx::array> arr_indices;
int slice_num = 0; int slice_num = 0;
int array_num = 0; int array_num = 0;
int ax = 0; int ax = 0;
@ -668,7 +672,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
// If it's a simple slice, we only need to add the start index // If it's a simple slice, we only need to add the start index
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) { if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
auto idx = array({start}, idx_shape, uint32); auto idx = mx::array({start}, idx_shape, mx::uint32);
slice_shapes.push_back(end - start); slice_shapes.push_back(end - start);
arr_indices.push_back(idx); arr_indices.push_back(idx);
@ -677,7 +681,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
} }
// Otherwise we expand the slice into indices using arange // Otherwise we expand the slice into indices using arange
else { else {
auto idx = arange(start, end, stride, uint32); auto idx = arange(start, end, stride, mx::uint32);
auto loc = slice_num + (arrays_first ? max_dim : 0); auto loc = slice_num + (arrays_first ? max_dim : 0);
idx_shape[loc] = idx.size(); idx_shape[loc] = idx.size();
arr_indices.push_back(reshape(idx, idx_shape)); arr_indices.push_back(reshape(idx, idx_shape));
@ -696,9 +700,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
} else if (pyidx.is_none()) { } else if (pyidx.is_none()) {
// We only use the None's for bookkeeping dimensions // We only use the None's for bookkeeping dimensions
slice_num++; slice_num++;
} else if (nb::isinstance<array>(pyidx)) { } else if (nb::isinstance<mx::array>(pyidx)) {
ax++; ax++;
auto idx = nb::cast<array>(pyidx); auto idx = nb::cast<mx::array>(pyidx);
std::vector<int> idx_shape(idx_ndim, 1); std::vector<int> idx_shape(idx_ndim, 1);
// Place the arrays in the correct dimension // Place the arrays in the correct dimension
@ -748,16 +752,16 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
return {arr_indices, up, axes}; return {arr_indices, up, axes};
} }
std::tuple<std::vector<array>, array, std::vector<int>> std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_compute_scatter_args( mlx_compute_scatter_args(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype()); auto vals = to_array(v, src.dtype());
if (nb::isinstance<nb::slice>(obj)) { if (nb::isinstance<nb::slice>(obj)) {
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals); return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
} else if (nb::isinstance<array>(obj)) { } else if (nb::isinstance<mx::array>(obj)) {
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals); return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);
} else if (nb::isinstance<nb::int_>(obj)) { } else if (nb::isinstance<nb::int_>(obj)) {
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals); return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
} else if (nb::isinstance<nb::tuple>(obj)) { } else if (nb::isinstance<nb::tuple>(obj)) {
@ -773,7 +777,7 @@ mlx_compute_scatter_args(
} }
auto mlx_slice_update( auto mlx_slice_update(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
// Can't route to slice update if not slice or tuple // Can't route to slice update if not slice or tuple
@ -784,7 +788,7 @@ auto mlx_slice_update(
if (nb::isinstance<nb::tuple>(obj)) { if (nb::isinstance<nb::tuple>(obj)) {
// Can't route to slice update if any arrays are present // Can't route to slice update if any arrays are present
for (auto idx : nb::cast<nb::tuple>(obj)) { for (auto idx : nb::cast<nb::tuple>(obj)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) { if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::list>(idx)) {
return std::make_pair(false, src); return std::make_pair(false, src);
} }
} }
@ -881,7 +885,10 @@ auto mlx_slice_update(
return std::make_pair(true, out); return std::make_pair(true, out);
} }
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) { void mlx_set_item(
mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [success, out] = mlx_slice_update(src, obj, v); auto [success, out] = mlx_slice_update(src, obj, v);
if (success) { if (success) {
src.overwrite_descriptor(out); src.overwrite_descriptor(out);
@ -897,8 +904,8 @@ void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
} }
} }
array mlx_add_item( mx::array mlx_add_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -909,8 +916,8 @@ array mlx_add_item(
} }
} }
array mlx_subtract_item( mx::array mlx_subtract_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -921,8 +928,8 @@ array mlx_subtract_item(
} }
} }
array mlx_multiply_item( mx::array mlx_multiply_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -933,8 +940,8 @@ array mlx_multiply_item(
} }
} }
array mlx_divide_item( mx::array mlx_divide_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -945,8 +952,8 @@ array mlx_divide_item(
} }
} }
array mlx_maximum_item( mx::array mlx_maximum_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -957,8 +964,8 @@ array mlx_maximum_item(
} }
} }
array mlx_minimum_item( mx::array mlx_minimum_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);

View File

@ -7,32 +7,35 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "python/src/utils.h" #include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace mlx::core;
array mlx_get_item(const array& src, const nb::object& obj); mx::array mlx_get_item(const mx::array& src, const nb::object& obj);
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v); void mlx_set_item(
array mlx_add_item( mx::array& src,
const array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_subtract_item( mx::array mlx_add_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_multiply_item( mx::array mlx_subtract_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_divide_item( mx::array mlx_multiply_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_maximum_item( mx::array mlx_divide_item(
const array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_minimum_item( mx::array mlx_maximum_item(
const array& src, const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);
mx::array mlx_minimum_item(
const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);

View File

@ -10,15 +10,13 @@
#include "mlx/linalg.h" #include "mlx/linalg.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
namespace { namespace {
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) { nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
const auto result = svd(a, s); const auto result = mx::linalg::svd(a, s);
return nb::make_tuple(result.at(0), result.at(1), result.at(2)); return nb::make_tuple(result.at(0), result.at(1), result.at(2));
} }
} // namespace } // namespace
@ -29,11 +27,11 @@ void init_linalg(nb::module_& parent_module) {
m.def( m.def(
"norm", "norm",
[](const array& a, [](const mx::array& a,
const std::variant<std::monostate, int, double, std::string>& ord_, const std::variant<std::monostate, int, double, std::string>& ord_,
const std::variant<std::monostate, int, std::vector<int>>& axis_, const std::variant<std::monostate, int, std::vector<int>>& axis_,
const bool keepdims, const bool keepdims,
const StreamOrDevice stream) { const mx::StreamOrDevice stream) {
std::optional<std::vector<int>> axis = std::nullopt; std::optional<std::vector<int>> axis = std::nullopt;
if (auto pv = std::get_if<int>(&axis_); pv) { if (auto pv = std::get_if<int>(&axis_); pv) {
axis = std::vector<int>{*pv}; axis = std::vector<int>{*pv};
@ -42,10 +40,10 @@ void init_linalg(nb::module_& parent_module) {
} }
if (std::holds_alternative<std::monostate>(ord_)) { if (std::holds_alternative<std::monostate>(ord_)) {
return norm(a, axis, keepdims, stream); return mx::linalg::norm(a, axis, keepdims, stream);
} else { } else {
if (auto pv = std::get_if<std::string>(&ord_); pv) { if (auto pv = std::get_if<std::string>(&ord_); pv) {
return norm(a, *pv, axis, keepdims, stream); return mx::linalg::norm(a, *pv, axis, keepdims, stream);
} }
double ord; double ord;
if (auto pv = std::get_if<int>(&ord_); pv) { if (auto pv = std::get_if<int>(&ord_); pv) {
@ -53,7 +51,7 @@ void init_linalg(nb::module_& parent_module) {
} else { } else {
ord = std::get<double>(ord_); ord = std::get<double>(ord_);
} }
return norm(a, ord, axis, keepdims, stream); return mx::linalg::norm(a, ord, axis, keepdims, stream);
} }
}, },
nb::arg(), nb::arg(),
@ -182,7 +180,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"qr", "qr",
&qr, &mx::linalg::qr,
"a"_a, "a"_a,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -239,7 +237,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"inv", "inv",
&inv, &mx::linalg::inv,
"a"_a, "a"_a,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -262,7 +260,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"tri_inv", "tri_inv",
&tri_inv, &mx::linalg::tri_inv,
"a"_a, "a"_a,
"upper"_a, "upper"_a,
nb::kw_only(), nb::kw_only(),
@ -287,7 +285,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"cholesky", "cholesky",
&cholesky, &mx::linalg::cholesky,
"a"_a, "a"_a,
"upper"_a = false, "upper"_a = false,
nb::kw_only(), nb::kw_only(),
@ -317,7 +315,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"cholesky_inv", "cholesky_inv",
&cholesky_inv, &mx::linalg::cholesky_inv,
"a"_a, "a"_a,
"upper"_a = false, "upper"_a = false,
nb::kw_only(), nb::kw_only(),
@ -355,7 +353,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"pinv", "pinv",
&pinv, &mx::linalg::pinv,
"a"_a, "a"_a,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -379,7 +377,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"cross", "cross",
&cross, &mx::linalg::cross,
"a"_a, "a"_a,
"b"_a, "b"_a,
"axis"_a = -1, "axis"_a = -1,
@ -407,7 +405,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"eigvalsh", "eigvalsh",
&eigvalsh, &mx::linalg::eigvalsh,
"a"_a, "a"_a,
"UPLO"_a = "L", "UPLO"_a = "L",
nb::kw_only(), nb::kw_only(),
@ -442,9 +440,9 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"eigh", "eigh",
[](const array& a, const std::string UPLO, StreamOrDevice s) { [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
// TODO avoid cast? // TODO avoid cast?
auto result = eigh(a, UPLO, s); auto result = mx::linalg::eigh(a, UPLO, s);
return nb::make_tuple(result.first, result.second); return nb::make_tuple(result.first, result.second);
}, },
"a"_a, "a"_a,

View File

@ -14,9 +14,9 @@
#include "python/src/load.h" #include "python/src/load.h"
#include "python/src/utils.h" #include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Helpers // Helpers
@ -86,7 +86,7 @@ class ZipFileWrapper {
// Loading // Loading
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
class PyFileReader : public io::Reader { class PyFileReader : public mx::io::Reader {
public: public:
PyFileReader(nb::object file) PyFileReader(nb::object file)
: pyistream_(file), : pyistream_(file),
@ -168,14 +168,14 @@ class PyFileReader : public io::Reader {
}; };
std::pair< std::pair<
std::unordered_map<std::string, array>, std::unordered_map<std::string, mx::array>,
std::unordered_map<std::string, std::string>> std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) { mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s); return mx::load_safetensors(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s); auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
{ {
nb::gil_scoped_release gil; nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) { for (auto& [key, arr] : std::get<0>(res)) {
@ -189,17 +189,17 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
"[load_safetensors] Input must be a file-like object, or string"); "[load_safetensors] Input must be a file-like object, or string");
} }
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) { mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s); return mx::load_gguf(nb::cast<std::string>(file), s);
} }
throw std::invalid_argument("[load_gguf] Input must be a string"); throw std::invalid_argument("[load_gguf] Input must be a string");
} }
std::unordered_map<std::string, array> mlx_load_npz_helper( std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
nb::object file, nb::object file,
StreamOrDevice s) { mx::StreamOrDevice s) {
bool own_file = nb::isinstance<nb::str>(file); bool own_file = nb::isinstance<nb::str>(file);
nb::module_ zipfile = nb::module_::import_("zipfile"); nb::module_ zipfile = nb::module_::import_("zipfile");
@ -209,7 +209,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
"opened with zipfile.ZipFile"); "opened with zipfile.ZipFile");
} }
// Output dictionary filename in zip -> loaded array // Output dictionary filename in zip -> loaded array
std::unordered_map<std::string, array> array_dict; std::unordered_map<std::string, mx::array> array_dict;
// Create python ZipFile object // Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file); ZipFileWrapper zipfile_object(zipfile, file);
@ -218,7 +218,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
nb::object sub_file = zipfile_object.open(st); nb::object sub_file = zipfile_object.open(st);
// Create array from python file stream // Create array from python file stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s); auto arr = mx::load(std::make_shared<PyFileReader>(sub_file), s);
// Remove .npy from file if it is there // Remove .npy from file if it is there
auto key = st; auto key = st;
@ -240,12 +240,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
return array_dict; return array_dict;
} }
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) { mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
return load(nb::cast<std::string>(file), s); return mx::load(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s); auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
{ {
nb::gil_scoped_release gil; nb::gil_scoped_release gil;
arr.eval(); arr.eval();
@ -260,7 +260,7 @@ LoadOutputTypes mlx_load_helper(
nb::object file, nb::object file,
std::optional<std::string> format, std::optional<std::string> format,
bool return_metadata, bool return_metadata,
StreamOrDevice s) { mx::StreamOrDevice s) {
if (!format.has_value()) { if (!format.has_value()) {
std::string fname; std::string fname;
if (nb::isinstance<nb::str>(file)) { if (nb::isinstance<nb::str>(file)) {
@ -309,7 +309,7 @@ LoadOutputTypes mlx_load_helper(
// Saving // Saving
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
class PyFileWriter : public io::Writer { class PyFileWriter : public mx::io::Writer {
public: public:
PyFileWriter(nb::object file) PyFileWriter(nb::object file)
: pyostream_(file), : pyostream_(file),
@ -382,15 +382,15 @@ class PyFileWriter : public io::Writer {
nb::object tell_func_; nb::object tell_func_;
}; };
void mlx_save_helper(nb::object file, array a) { void mlx_save_helper(nb::object file, mx::array a) {
if (nb::isinstance<nb::str>(file)) { if (nb::isinstance<nb::str>(file)) {
save(nb::cast<std::string>(file), a); mx::save(nb::cast<std::string>(file), a);
return; return;
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
nb::gil_scoped_release gil; nb::gil_scoped_release gil;
save(writer, a); mx::save(writer, a);
} }
return; return;
@ -419,8 +419,9 @@ void mlx_savez_helper(
} }
// Collect args and kwargs // Collect args and kwargs
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs); auto arrays_dict =
auto arrays_list = nb::cast<std::vector<array>>(args); nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
for (int i = 0; i < arrays_list.size(); i++) { for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i); std::string arr_name = "arr_" + std::to_string(i);
@ -447,7 +448,7 @@ void mlx_savez_helper(
auto writer = std::make_shared<PyFileWriter>(py_ostream); auto writer = std::make_shared<PyFileWriter>(py_ostream);
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save(writer, a); mx::save(writer, a);
} }
} }
@ -470,17 +471,18 @@ void mlx_save_safetensor_helper(
} else { } else {
metadata_map = std::unordered_map<std::string, std::string>(); metadata_map = std::unordered_map<std::string, std::string>();
} }
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d); auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
if (nb::isinstance<nb::str>(file)) { if (nb::isinstance<nb::str>(file)) {
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map); mx::save_safetensors(
nb::cast<std::string>(file), arrays_map, metadata_map);
} }
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_safetensors(writer, arrays_map, metadata_map); mx::save_safetensors(writer, arrays_map, metadata_map);
} }
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
@ -492,19 +494,20 @@ void mlx_save_gguf_helper(
nb::object file, nb::object file,
nb::dict a, nb::dict a,
std::optional<nb::dict> m) { std::optional<nb::dict> m) {
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a); auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
if (nb::isinstance<nb::str>(file)) { if (nb::isinstance<nb::str>(file)) {
if (m) { if (m) {
auto metadata_map = auto metadata_map =
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value()); nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
m.value());
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map); mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
} }
} else { } else {
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map); mx::save_gguf(nb::cast<std::string>(file), arrays_map);
} }
} }
} else { } else {

View File

@ -14,22 +14,24 @@
#include <variant> #include <variant>
#include "mlx/io.h" #include "mlx/io.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace mlx::core;
using LoadOutputTypes = std::variant< using LoadOutputTypes = std::variant<
array, mx::array,
std::unordered_map<std::string, array>, std::unordered_map<std::string, mx::array>,
SafetensorsLoad, mx::SafetensorsLoad,
GGUFLoad>; mx::GGUFLoad>;
SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s); mx::SafetensorsLoad mlx_load_safetensor_helper(
nb::object file,
mx::StreamOrDevice s);
void mlx_save_safetensor_helper( void mlx_save_safetensor_helper(
nb::object file, nb::object file,
nb::dict d, nb::dict d,
std::optional<nb::dict> m); std::optional<nb::dict> m);
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s); mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s);
void mlx_save_gguf_helper( void mlx_save_gguf_helper(
nb::object file, nb::object file,
@ -40,8 +42,8 @@ LoadOutputTypes mlx_load_helper(
nb::object file, nb::object file,
std::optional<std::string> format, std::optional<std::string> format,
bool return_metadata, bool return_metadata,
StreamOrDevice s); mx::StreamOrDevice s);
void mlx_save_helper(nb::object file, array a); void mlx_save_helper(nb::object file, mx::array a);
void mlx_savez_helper( void mlx_savez_helper(
nb::object file, nb::object file,
nb::args args, nb::args args,

View File

@ -8,22 +8,21 @@
#include <nanobind/stl/variant.h> #include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h> #include <nanobind/stl/vector.h>
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
void init_metal(nb::module_& m) { void init_metal(nb::module_& m) {
nb::module_ metal = m.def_submodule("metal", "mlx.metal"); nb::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def( metal.def(
"is_available", "is_available",
&metal::is_available, &mx::metal::is_available,
R"pbdoc( R"pbdoc(
Check if the Metal back-end is available. Check if the Metal back-end is available.
)pbdoc"); )pbdoc");
metal.def( metal.def(
"get_active_memory", "get_active_memory",
&metal::get_active_memory, &mx::metal::get_active_memory,
R"pbdoc( R"pbdoc(
Get the actively used memory in bytes. Get the actively used memory in bytes.
@ -32,7 +31,7 @@ void init_metal(nb::module_& m) {
)pbdoc"); )pbdoc");
metal.def( metal.def(
"get_peak_memory", "get_peak_memory",
&metal::get_peak_memory, &mx::metal::get_peak_memory,
R"pbdoc( R"pbdoc(
Get the peak amount of used memory in bytes. Get the peak amount of used memory in bytes.
@ -41,13 +40,13 @@ void init_metal(nb::module_& m) {
)pbdoc"); )pbdoc");
metal.def( metal.def(
"reset_peak_memory", "reset_peak_memory",
&metal::reset_peak_memory, &mx::metal::reset_peak_memory,
R"pbdoc( R"pbdoc(
Reset the peak memory to zero. Reset the peak memory to zero.
)pbdoc"); )pbdoc");
metal.def( metal.def(
"get_cache_memory", "get_cache_memory",
&metal::get_cache_memory, &mx::metal::get_cache_memory,
R"pbdoc( R"pbdoc(
Get the cache size in bytes. Get the cache size in bytes.
@ -56,7 +55,7 @@ void init_metal(nb::module_& m) {
)pbdoc"); )pbdoc");
metal.def( metal.def(
"set_memory_limit", "set_memory_limit",
&metal::set_memory_limit, &mx::metal::set_memory_limit,
"limit"_a, "limit"_a,
nb::kw_only(), nb::kw_only(),
"relaxed"_a = true, "relaxed"_a = true,
@ -81,7 +80,7 @@ void init_metal(nb::module_& m) {
)pbdoc"); )pbdoc");
metal.def( metal.def(
"set_cache_limit", "set_cache_limit",
&metal::set_cache_limit, &mx::metal::set_cache_limit,
"limit"_a, "limit"_a,
R"pbdoc( R"pbdoc(
Set the free cache limit. Set the free cache limit.
@ -101,7 +100,7 @@ void init_metal(nb::module_& m) {
)pbdoc"); )pbdoc");
metal.def( metal.def(
"set_wired_limit", "set_wired_limit",
&metal::set_wired_limit, &mx::metal::set_wired_limit,
"limit"_a, "limit"_a,
R"pbdoc( R"pbdoc(
Set the wired size limit. Set the wired size limit.
@ -133,7 +132,7 @@ void init_metal(nb::module_& m) {
)pbdoc"); )pbdoc");
metal.def( metal.def(
"clear_cache", "clear_cache",
&metal::clear_cache, &mx::metal::clear_cache,
R"pbdoc( R"pbdoc(
Clear the memory cache. Clear the memory cache.
@ -142,7 +141,7 @@ void init_metal(nb::module_& m) {
metal.def( metal.def(
"start_capture", "start_capture",
&metal::start_capture, &mx::metal::start_capture,
"path"_a, "path"_a,
R"pbdoc( R"pbdoc(
Start a Metal capture. Start a Metal capture.
@ -153,13 +152,13 @@ void init_metal(nb::module_& m) {
)pbdoc"); )pbdoc");
metal.def( metal.def(
"stop_capture", "stop_capture",
&metal::stop_capture, &mx::metal::stop_capture,
R"pbdoc( R"pbdoc(
Stop a Metal capture. Stop a Metal capture.
)pbdoc"); )pbdoc");
metal.def( metal.def(
"device_info", "device_info",
&metal::device_info, &mx::metal::device_info,
R"pbdoc( R"pbdoc(
Get information about the GPU device and system settings. Get information about the GPU device and system settings.

File diff suppressed because it is too large Load Diff

View File

@ -12,23 +12,22 @@
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/random.h" #include "mlx/random.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::random;
class PyKeySequence { class PyKeySequence {
public: public:
explicit PyKeySequence(uint64_t seed) { explicit PyKeySequence(uint64_t seed) {
state_.append(key(seed)); state_.append(mx::random::key(seed));
} }
void seed(uint64_t seed) { void seed(uint64_t seed) {
state_[0] = key(seed); state_[0] = mx::random::key(seed);
} }
array next() { mx::array next() {
auto out = split(nb::cast<array>(state_[0])); auto out = mx::random::split(nb::cast<mx::array>(state_[0]));
state_[0] = out.first; state_[0] = out.first;
return out.second; return out.second;
} }
@ -75,7 +74,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"key", "key",
&key, &mx::random::key,
"seed"_a, "seed"_a,
R"pbdoc( R"pbdoc(
Get a PRNG key from a seed. Get a PRNG key from a seed.
@ -88,7 +87,8 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"split", "split",
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split), nb::overload_cast<const mx::array&, int, mx::StreamOrDevice>(
&mx::random::split),
"key"_a, "key"_a,
"num"_a = 2, "num"_a = 2,
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -109,22 +109,22 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<mx::Dtype> type,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return uniform( return mx::random::uniform(
to_array(low), to_array(low),
to_array(high), to_array(high),
shape, shape,
type.value_or(float32), type.value_or(mx::float32),
key, key,
s); s);
}, },
"low"_a = 0, "low"_a = 0,
"high"_a = 1, "high"_a = 1,
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a.none() = float32, "dtype"_a.none() = mx::float32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
@ -151,16 +151,17 @@ void init_random(nb::module_& parent_module) {
m.def( m.def(
"normal", "normal",
[](const std::vector<int>& shape, [](const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<mx::Dtype> type,
float loc, float loc,
float scale, float scale,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return normal(shape, type.value_or(float32), loc, scale, key, s); return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a.none() = float32, "dtype"_a.none() = mx::float32,
"loc"_a = 0.0, "loc"_a = 0.0,
"scale"_a = 1.0, "scale"_a = 1.0,
"key"_a = nb::none(), "key"_a = nb::none(),
@ -182,20 +183,20 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"multivariate_normal", "multivariate_normal",
[](const array& mean, [](const mx::array& mean,
const array& cov, const mx::array& cov,
const std::vector<int>& shape, const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<mx::Dtype> type,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return multivariate_normal( return mx::random::multivariate_normal(
mean, cov, shape, type.value_or(float32), key, s); mean, cov, shape, type.value_or(mx::float32), key, s);
}, },
"mean"_a, "mean"_a,
"cov"_a, "cov"_a,
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a.none() = float32, "dtype"_a.none() = mx::float32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
@ -227,17 +228,22 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<mx::Dtype> type,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return randint( return mx::random::randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s); to_array(low),
to_array(high),
shape,
type.value_or(mx::int32),
key,
s);
}, },
"low"_a, "low"_a,
"high"_a, "high"_a,
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a.none() = int32, "dtype"_a.none() = mx::int32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
@ -263,14 +269,14 @@ void init_random(nb::module_& parent_module) {
"bernoulli", "bernoulli",
[](const ScalarOrArray& p_, [](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape, const std::optional<std::vector<int>> shape,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
auto p = to_array(p_); auto p = to_array(p_);
if (shape.has_value()) { if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s); return mx::random::bernoulli(p, shape.value(), key, s);
} else { } else {
return bernoulli(p, key, s); return mx::random::bernoulli(p, key, s);
} }
}, },
"p"_a = 0.5, "p"_a = 0.5,
@ -301,23 +307,24 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& lower_, [](const ScalarOrArray& lower_,
const ScalarOrArray& upper_, const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_, const std::optional<std::vector<int>> shape_,
std::optional<Dtype> type, std::optional<mx::Dtype> type,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
auto lower = to_array(lower_); auto lower = to_array(lower_);
auto upper = to_array(upper_); auto upper = to_array(upper_);
auto t = type.value_or(float32); auto t = type.value_or(mx::float32);
if (shape_.has_value()) { if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), t, key, s); return mx::random::truncated_normal(
lower, upper, shape_.value(), t, key, s);
} else { } else {
return truncated_normal(lower, upper, t, key, s); return mx::random::truncated_normal(lower, upper, t, key, s);
} }
}, },
"lower"_a, "lower"_a,
"upper"_a, "upper"_a,
"shape"_a = nb::none(), "shape"_a = nb::none(),
"dtype"_a.none() = float32, "dtype"_a.none() = mx::float32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
@ -344,14 +351,14 @@ void init_random(nb::module_& parent_module) {
m.def( m.def(
"gumbel", "gumbel",
[](const std::vector<int>& shape, [](const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<mx::Dtype> type,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return gumbel(shape, type.value_or(float32), key, s); return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a.none() = float32, "dtype"_a.none() = mx::float32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
@ -375,22 +382,23 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"categorical", "categorical",
[](const array& logits, [](const mx::array& logits,
int axis, int axis,
const std::optional<std::vector<int>> shape, const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples, const std::optional<int> num_samples,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
if (shape.has_value() && num_samples.has_value()) { if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified."); "[categorical] At most one of shape or num_samples can be specified.");
} else if (shape.has_value()) { } else if (shape.has_value()) {
return categorical(logits, axis, shape.value(), key, s); return mx::random::categorical(logits, axis, shape.value(), key, s);
} else if (num_samples.has_value()) { } else if (num_samples.has_value()) {
return categorical(logits, axis, num_samples.value(), key, s); return mx::random::categorical(
logits, axis, num_samples.value(), key, s);
} else { } else {
return categorical(logits, axis, key, s); return mx::random::categorical(logits, axis, key, s);
} }
}, },
"logits"_a, "logits"_a,
@ -427,16 +435,17 @@ void init_random(nb::module_& parent_module) {
m.def( m.def(
"laplace", "laplace",
[](const std::vector<int>& shape, [](const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<mx::Dtype> type,
float loc, float loc,
float scale, float scale,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return laplace(shape, type.value_or(float32), loc, scale, key, s); return mx::random::laplace(
shape, type.value_or(mx::float32), loc, scale, key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a.none() = float32, "dtype"_a.none() = mx::float32,
"loc"_a = 0.0, "loc"_a = 0.0,
"scale"_a = 1.0, "scale"_a = 1.0,
"key"_a = nb::none(), "key"_a = nb::none(),
@ -459,15 +468,15 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"permuation", "permuation",
[](const std::variant<nb::int_, array>& x, [](const std::variant<nb::int_, mx::array>& x,
int axis, int axis,
const std::optional<array>& key_, const std::optional<mx::array>& key_,
StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
if (auto pv = std::get_if<nb::int_>(&x); pv) { if (auto pv = std::get_if<nb::int_>(&x); pv) {
return permutation(nb::cast<int>(*pv), key, s); return mx::random::permutation(nb::cast<int>(*pv), key, s);
} else { } else {
return permutation(std::get<array>(x), axis, key, s); return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
} }
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},

View File

@ -10,14 +10,14 @@
#include "mlx/stream.h" #include "mlx/stream.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
// Create the StreamContext on enter and delete on exit. // Create the StreamContext on enter and delete on exit.
class PyStreamContext { class PyStreamContext {
public: public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) { PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) { if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error( throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device."); "[StreamContext] Invalid argument, please specify a stream or device.");
@ -26,7 +26,7 @@ class PyStreamContext {
} }
void enter() { void enter() {
_inner = new StreamContext(_s); _inner = new mx::StreamContext(_s);
} }
void exit() { void exit() {
@ -37,39 +37,40 @@ class PyStreamContext {
} }
private: private:
StreamOrDevice _s; mx::StreamOrDevice _s;
StreamContext* _inner; mx::StreamContext* _inner;
}; };
void init_stream(nb::module_& m) { void init_stream(nb::module_& m) {
nb::class_<Stream>( nb::class_<mx::Stream>(
m, m,
"Stream", "Stream",
R"pbdoc( R"pbdoc(
A stream for running operations on a given device. A stream for running operations on a given device.
)pbdoc") )pbdoc")
.def_ro("device", &Stream::device) .def_ro("device", &mx::Stream::device)
.def( .def(
"__repr__", "__repr__",
[](const Stream& s) { [](const mx::Stream& s) {
std::ostringstream os; std::ostringstream os;
os << s; os << s;
return os.str(); return os.str();
}) })
.def("__eq__", [](const Stream& s, const nb::object& other) { .def("__eq__", [](const mx::Stream& s, const nb::object& other) {
return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other); return nb::isinstance<mx::Stream>(other) &&
s == nb::cast<mx::Stream>(other);
}); });
nb::implicitly_convertible<Device::DeviceType, Device>(); nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
m.def( m.def(
"default_stream", "default_stream",
&default_stream, &mx::default_stream,
"device"_a, "device"_a,
R"pbdoc(Get the device's default stream.)pbdoc"); R"pbdoc(Get the device's default stream.)pbdoc");
m.def( m.def(
"set_default_stream", "set_default_stream",
&set_default_stream, &mx::set_default_stream,
"stream"_a, "stream"_a,
R"pbdoc( R"pbdoc(
Set the default stream. Set the default stream.
@ -82,7 +83,7 @@ void init_stream(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"new_stream", "new_stream",
&new_stream, &mx::new_stream,
"device"_a, "device"_a,
R"pbdoc(Make a new stream on the given device.)pbdoc"); R"pbdoc(Make a new stream on the given device.)pbdoc");
@ -94,7 +95,7 @@ void init_stream(nb::module_& m) {
Args: Args:
s: The stream or device to set as the default. s: The stream or device to set as the default.
)pbdoc") )pbdoc")
.def(nb::init<StreamOrDevice>(), "s"_a) .def(nb::init<mx::StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) .def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def( .def(
"__exit__", "__exit__",
@ -107,7 +108,7 @@ void init_stream(nb::module_& m) {
"traceback"_a = nb::none()); "traceback"_a = nb::none());
m.def( m.def(
"stream", "stream",
[](StreamOrDevice s) { return PyStreamContext(s); }, [](mx::StreamOrDevice s) { return PyStreamContext(s); },
"s"_a, "s"_a,
R"pbdoc( R"pbdoc(
Create a context manager to set the default device and stream. Create a context manager to set the default device and stream.
@ -131,8 +132,8 @@ void init_stream(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"synchronize", "synchronize",
[](const std::optional<Stream>& s) { [](const std::optional<mx::Stream>& s) {
s ? synchronize(s.value()) : synchronize(); s ? mx::synchronize(s.value()) : mx::synchronize();
}, },
"stream"_a = nb::none(), "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(

View File

@ -20,9 +20,12 @@
#include "mlx/utils.h" #include "mlx/utils.h"
#include "python/src/trees.h" #include "python/src/trees.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
using namespace mlx::core;
// Needed for printing shapes and strides.
using mx::operator<<;
using IntOrVec = std::variant<int, std::vector<int>>; using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>; using StrOrVec = std::variant<std::string, std::vector<std::string>>;
@ -108,7 +111,7 @@ auto py_value_and_grad(
} }
// Collect the arrays // Collect the arrays
std::vector<array> arrays; std::vector<mx::array> arrays;
std::vector<int> counts(1, 0); std::vector<int> counts(1, 0);
for (auto i : argnums) { for (auto i : argnums) {
auto argsi = tree_flatten(args[i]); auto argsi = tree_flatten(args[i]);
@ -127,7 +130,7 @@ auto py_value_and_grad(
// value_out will hold the output of the python function in order to be // value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values // able to reconstruct the python tree of extra return values
nb::object py_value_out; nb::object py_value_out;
auto value_and_grads = value_and_grad( auto value_and_grads = mx::value_and_grad(
[&fun, [&fun,
&args, &args,
&kwargs, &kwargs,
@ -136,7 +139,7 @@ auto py_value_and_grad(
&counts, &counts,
&py_value_out, &py_value_out,
&error_msg_tag, &error_msg_tag,
scalar_func_only](const std::vector<array>& a) { scalar_func_only](const std::vector<mx::array>& a) {
// Copy the arguments // Copy the arguments
nb::list args_cpy; nb::list args_cpy;
nb::kwargs kwargs_cpy = nb::kwargs(); nb::kwargs kwargs_cpy = nb::kwargs();
@ -165,7 +168,7 @@ auto py_value_and_grad(
py_value_out = fun(*args_cpy, **kwargs_cpy); py_value_out = fun(*args_cpy, **kwargs_cpy);
// Validate the return value of the python function // Validate the return value of the python function
if (!nb::isinstance<array>(py_value_out)) { if (!nb::isinstance<mx::array>(py_value_out)) {
if (scalar_func_only) { if (scalar_func_only) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
@ -193,7 +196,7 @@ auto py_value_and_grad(
<< "we got an empty tuple."; << "we got an empty tuple.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (!nb::isinstance<array>(ret[0])) { if (!nb::isinstance<mx::array>(ret[0])) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a " << "whose gradient we want to compute should be either a "
@ -275,12 +278,12 @@ auto py_vmap(
{tree, axes}, {tree, axes},
[&flat_axes, &encountered_tuple, output_axes]( [&flat_axes, &encountered_tuple, output_axes](
const std::vector<nb::object>& inputs) { const std::vector<nb::object>& inputs) {
if (nb::isinstance<array>(inputs[0])) { if (nb::isinstance<mx::array>(inputs[0])) {
if (inputs[1].is_none()) { if (inputs[1].is_none()) {
flat_axes.push_back(-1); flat_axes.push_back(-1);
} else if (nb::isinstance<nb::int_>(inputs[1])) { } else if (nb::isinstance<nb::int_>(inputs[1])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1])); int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
const array& x = nb::cast<array>(inputs[0]); const mx::array& x = nb::cast<mx::array>(inputs[0]);
if (axis < 0) { if (axis < 0) {
axis += x.ndim() + output_axes; axis += x.ndim() + output_axes;
} }
@ -297,7 +300,7 @@ auto py_vmap(
auto l = nb::cast<nb::tuple>(inputs[1]); auto l = nb::cast<nb::tuple>(inputs[1]);
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) { if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0])); int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
const array& x = nb::cast<array>(inputs[0]); const mx::array& x = nb::cast<mx::array>(inputs[0]);
if (axis < 0) { if (axis < 0) {
axis += x.ndim() + output_axes; axis += x.ndim() + output_axes;
} }
@ -323,7 +326,7 @@ auto py_vmap(
"[vmap] The arguments should contain only arrays"); "[vmap] The arguments should contain only arrays");
} }
}); });
if (encountered_tuple && !nb::isinstance<array>(tree)) { if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {
throw std::invalid_argument("[vmap] axis must be int or None."); throw std::invalid_argument("[vmap] axis must be int or None.");
} }
return flat_axes; return flat_axes;
@ -339,7 +342,7 @@ auto py_vmap(
nb::object py_outputs; nb::object py_outputs;
auto vmap_fn = auto vmap_fn =
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) { [&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
// Call the python function // Call the python function
py_outputs = fun(*tree_unflatten(args, a)); py_outputs = fun(*tree_unflatten(args, a));
@ -348,12 +351,12 @@ auto py_vmap(
}; };
auto [trace_inputs, trace_outputs] = auto [trace_inputs, trace_outputs] =
detail::vmap_trace(vmap_fn, inputs, flat_in_axes); mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true); auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
// Perform the vmap // Perform the vmap
auto outputs = detail::vmap_replace( auto outputs = mx::detail::vmap_replace(
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes); inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
// Put the outputs back in the container // Put the outputs back in the container
@ -401,7 +404,7 @@ struct PyCompiledFun {
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
// Flat array inputs // Flat array inputs
std::vector<array> inputs; std::vector<mx::array> inputs;
// Compilation constants which includes the tree structure of the arguments // Compilation constants which includes the tree structure of the arguments
std::vector<uint64_t> constants; std::vector<uint64_t> constants;
@ -437,8 +440,8 @@ struct PyCompiledFun {
constants.push_back(nb::cast<int64_t>(r)); constants.push_back(nb::cast<int64_t>(r));
recurse(item.second); recurse(item.second);
} }
} else if (nb::isinstance<array>(obj)) { } else if (nb::isinstance<mx::array>(obj)) {
inputs.push_back(nb::cast<array>(obj)); inputs.push_back(nb::cast<mx::array>(obj));
constants.push_back(array_identifier); constants.push_back(array_identifier);
} else if (nb::isinstance<nb::str>(obj)) { } else if (nb::isinstance<nb::str>(obj)) {
auto r = obj.attr("__hash__")(); auto r = obj.attr("__hash__")();
@ -461,10 +464,10 @@ struct PyCompiledFun {
int num_args = inputs.size(); int num_args = inputs.size();
recurse(kwargs); recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args]( auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) { const std::vector<mx::array>& a) {
// Put tracers into captured inputs // Put tracers into captured inputs
std::vector<array> flat_in_captures; std::vector<mx::array> flat_in_captures;
std::vector<array> trace_captures; std::vector<mx::array> trace_captures;
if (!captured_inputs.is_none()) { if (!captured_inputs.is_none()) {
flat_in_captures = tree_flatten(captured_inputs, false); flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert( trace_captures.insert(
@ -505,9 +508,9 @@ struct PyCompiledFun {
// Compile and call // Compile and call
auto outputs = auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!captured_outputs.is_none()) { if (!captured_outputs.is_none()) {
std::vector<array> captures( std::vector<mx::array> captures(
std::make_move_iterator(outputs.begin() + num_outputs), std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end())); std::make_move_iterator(outputs.end()));
tree_fill(captured_outputs, captures); tree_fill(captured_outputs, captures);
@ -526,7 +529,7 @@ struct PyCompiledFun {
nb::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
tree_cache().erase(fun_id); tree_cache().erase(fun_id);
detail::compile_erase(fun_id); mx::detail::compile_erase(fun_id);
fun.release().dec_ref(); fun.release().dec_ref();
captured_inputs.release().dec_ref(); captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref(); captured_outputs.release().dec_ref();
@ -561,7 +564,7 @@ class PyCheckpointedFun {
args_structure_.release().dec_ref(); args_structure_.release().dec_ref();
} }
std::vector<array> operator()(const std::vector<array>& inputs) { std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
auto args = nb::cast<nb::tuple>( auto args = nb::cast<nb::tuple>(
tree_unflatten_from_structure(args_structure_, inputs)); tree_unflatten_from_structure(args_structure_, inputs));
auto [outputs, output_structure] = auto [outputs, output_structure] =
@ -579,7 +582,7 @@ class PyCheckpointedFun {
auto [inputs, args_structure] = auto [inputs, args_structure] =
tree_flatten_with_structure(full_args, false); tree_flatten_with_structure(full_args, false);
auto outputs = checkpoint( auto outputs = mx::checkpoint(
InnerFunction(fun_, args_structure, output_structure))(inputs); InnerFunction(fun_, args_structure, output_structure))(inputs);
return tree_unflatten_from_structure(*output_structure, outputs); return tree_unflatten_from_structure(*output_structure, outputs);
@ -660,12 +663,12 @@ class PyCustomFunction {
} }
} }
std::vector<array> operator()(const std::vector<array>& inputs) { std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
nb::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
auto new_inputs = nb::cast<nb::tuple>( auto new_inputs = nb::cast<nb::tuple>(
tree_unflatten_from_structure(input_structure_, inputs)); tree_unflatten_from_structure(input_structure_, inputs));
std::vector<array> outputs; std::vector<mx::array> outputs;
std::tie(outputs, *output_structure_) = std::tie(outputs, *output_structure_) =
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1])); tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
return outputs; return outputs;
@ -694,10 +697,10 @@ class PyCustomFunction {
} }
} }
std::vector<array> operator()( std::vector<mx::array> operator()(
const std::vector<array>& primals, const std::vector<mx::array>& primals,
const std::vector<array>& cotangents, const std::vector<mx::array>& cotangents,
const std::vector<array>& outputs) { const std::vector<mx::array>& outputs) {
nb::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
auto new_inputs = nb::cast<nb::tuple>( auto new_inputs = nb::cast<nb::tuple>(
@ -734,9 +737,9 @@ class PyCustomFunction {
input_structure_.release().dec_ref(); input_structure_.release().dec_ref();
} }
std::vector<array> operator()( std::vector<mx::array> operator()(
const std::vector<array>& primals, const std::vector<mx::array>& primals,
const std::vector<array>& tangents, const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
nb::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
@ -759,7 +762,7 @@ class PyCustomFunction {
int tangent_index = 0; int tangent_index = 0;
auto new_tangents = auto new_tangents =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) { nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
if (nb::isinstance<array>(element) && if (nb::isinstance<mx::array>(element) &&
have_tangents[array_index++]) { have_tangents[array_index++]) {
return nb::cast(tangents[tangent_index++]); return nb::cast(tangents[tangent_index++]);
} else { } else {
@ -789,8 +792,8 @@ class PyCustomFunction {
input_structure_.release().dec_ref(); input_structure_.release().dec_ref();
} }
std::pair<std::vector<array>, std::vector<int>> operator()( std::pair<std::vector<mx::array>, std::vector<int>> operator()(
const std::vector<array>& inputs, const std::vector<mx::array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
nb::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
@ -807,7 +810,7 @@ class PyCustomFunction {
auto new_axes = auto new_axes =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) { nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
int axis = axes[arr_index++]; int axis = axes[arr_index++];
if (nb::isinstance<array>(element) && axis >= 0) { if (nb::isinstance<mx::array>(element) && axis >= 0) {
return nb::cast(axis); return nb::cast(axis);
} else { } else {
return nb::none(); return nb::none();
@ -831,11 +834,11 @@ class PyCustomFunction {
"[custom vmap] Vmap function should return a tuple with 2 items."); "[custom vmap] Vmap function should return a tuple with 2 items.");
} }
std::vector<array> outputs; std::vector<mx::array> outputs;
std::vector<int> output_axes; std::vector<int> output_axes;
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) { tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
if (nb::isinstance<array>(objects[0])) { if (nb::isinstance<mx::array>(objects[0])) {
outputs.push_back(nb::cast<array>(objects[0])); outputs.push_back(nb::cast<mx::array>(objects[0]));
output_axes.push_back( output_axes.push_back(
objects[1].is_none() ? -1 : nb::cast<int>(objects[1])); objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
} }
@ -852,7 +855,7 @@ class PyCustomFunction {
} }
// Extract the inputs and their structure in capturable vars // Extract the inputs and their structure in capturable vars
std::vector<array> input_arrays; std::vector<mx::array> input_arrays;
nb::object input_structure; nb::object input_structure;
auto full_args = nb::make_tuple(args, kwargs); auto full_args = nb::make_tuple(args, kwargs);
std::tie(input_arrays, input_structure) = std::tie(input_arrays, input_structure) =
@ -864,7 +867,7 @@ class PyCustomFunction {
// Make a function that calls fun_ in the forward pass and vjp_ in the // Make a function that calls fun_ in the forward pass and vjp_ in the
// backward pass. Then call it immediately and return the results. // backward pass. Then call it immediately and return the results.
auto f = custom_function( auto f = mx::custom_function(
InnerFunction(fun_, input_structure, output_structure), InnerFunction(fun_, input_structure, output_structure),
make_vjp_function(input_structure, output_structure), make_vjp_function(input_structure, output_structure),
make_jvp_function(input_structure), make_jvp_function(input_structure),
@ -1044,7 +1047,7 @@ void init_transforms(nb::module_& m) {
m.def( m.def(
"eval", "eval",
[](const nb::args& args) { [](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false); std::vector<mx::array> arrays = tree_flatten(args, false);
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
eval(arrays); eval(arrays);
@ -1064,7 +1067,7 @@ void init_transforms(nb::module_& m) {
m.def( m.def(
"async_eval", "async_eval",
[](const nb::args& args) { [](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false); std::vector<mx::array> arrays = tree_flatten(args, false);
{ {
nb::gil_scoped_release nogil; nb::gil_scoped_release nogil;
async_eval(arrays); async_eval(arrays);
@ -1100,14 +1103,14 @@ void init_transforms(nb::module_& m) {
m.def( m.def(
"jvp", "jvp",
[](const nb::callable& fun, [](const nb::callable& fun,
const std::vector<array>& primals, const std::vector<mx::array>& primals,
const std::vector<array>& tangents) { const std::vector<mx::array>& tangents) {
auto vfun = [&fun](const std::vector<array>& primals) { auto vfun = [&fun](const std::vector<mx::array>& primals) {
auto out = fun(*nb::cast(primals)); auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) { if (nb::isinstance<mx::array>(out)) {
return std::vector<array>{nb::cast<array>(out)}; return std::vector<mx::array>{nb::cast<mx::array>(out)};
} else { } else {
return nb::cast<std::vector<array>>(out); return nb::cast<std::vector<mx::array>>(out);
} }
}; };
return jvp(vfun, primals, tangents); return jvp(vfun, primals, tangents);
@ -1139,14 +1142,14 @@ void init_transforms(nb::module_& m) {
m.def( m.def(
"vjp", "vjp",
[](const nb::callable& fun, [](const nb::callable& fun,
const std::vector<array>& primals, const std::vector<mx::array>& primals,
const std::vector<array>& cotangents) { const std::vector<mx::array>& cotangents) {
auto vfun = [&fun](const std::vector<array>& primals) { auto vfun = [&fun](const std::vector<mx::array>& primals) {
auto out = fun(*nb::cast(primals)); auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) { if (nb::isinstance<mx::array>(out)) {
return std::vector<array>{nb::cast<array>(out)}; return std::vector<mx::array>{nb::cast<mx::array>(out)};
} else { } else {
return nb::cast<std::vector<array>>(out); return nb::cast<std::vector<mx::array>>(out);
} }
}; };
return vjp(vfun, primals, cotangents); return vjp(vfun, primals, cotangents);
@ -1312,7 +1315,7 @@ void init_transforms(nb::module_& m) {
m.def( m.def(
"export_to_dot", "export_to_dot",
[](nb::object file, const nb::args& args) { [](nb::object file, const nb::args& args) {
std::vector<array> arrays = tree_flatten(args); std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) { if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file)); std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays); export_to_dot(out, arrays);
@ -1399,14 +1402,14 @@ void init_transforms(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"disable_compile", "disable_compile",
&disable_compile, &mx::disable_compile,
R"pbdoc( R"pbdoc(
Globally disable compilation. Setting the environment variable Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation. ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc"); )pbdoc");
m.def( m.def(
"enable_compile", "enable_compile",
&enable_compile, &mx::enable_compile,
R"pbdoc( R"pbdoc(
Globally enable compilation. This will override the environment Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set. variable ``MLX_DISABLE_COMPILE`` if set.
@ -1420,6 +1423,6 @@ void init_transforms(nb::module_& m) {
auto atexit = nb::module_::import_("atexit"); auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { atexit.attr("register")(nb::cpp_function([]() {
tree_cache().clear(); tree_cache().clear();
detail::compile_clear_cache(); mx::detail::compile_clear_cache();
})); }));
} }

View File

@ -188,7 +188,7 @@ void tree_visit_update(
d[item.first] = recurse(item.second); d[item.first] = recurse(item.second);
} }
return nb::cast<nb::object>(d); return nb::cast<nb::object>(d);
} else if (nb::isinstance<array>(subtree)) { } else if (nb::isinstance<mx::array>(subtree)) {
return visitor(subtree); return visitor(subtree);
} else { } else {
return nb::cast<nb::object>(subtree); return nb::cast<nb::object>(subtree);
@ -200,7 +200,7 @@ void tree_visit_update(
// Fill a pytree (recursive dict or list of dict or list) // Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays // in place with the given arrays
// Non dict or list nodes are ignored // Non dict or list nodes are ignored
void tree_fill(nb::object& tree, const std::vector<array>& values) { void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
size_t index = 0; size_t index = 0;
tree_visit_update( tree_visit_update(
tree, [&](nb::handle node) { return nb::cast(values[index++]); }); tree, [&](nb::handle node) { return nb::cast(values[index++]); });
@ -209,14 +209,14 @@ void tree_fill(nb::object& tree, const std::vector<array>& values) {
// Replace all the arrays from the src values with the dst values in the tree // Replace all the arrays from the src values with the dst values in the tree
void tree_replace( void tree_replace(
nb::object& tree, nb::object& tree,
const std::vector<array>& src, const std::vector<mx::array>& src,
const std::vector<array>& dst) { const std::vector<mx::array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst; std::unordered_map<uintptr_t, mx::array> src_to_dst;
for (int i = 0; i < src.size(); ++i) { for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]}); src_to_dst.insert({src[i].id(), dst[i]});
} }
tree_visit_update(tree, [&](nb::handle node) { tree_visit_update(tree, [&](nb::handle node) {
auto arr = nb::cast<array>(node); auto arr = nb::cast<mx::array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return nb::cast(it->second); return nb::cast(it->second);
} }
@ -224,12 +224,12 @@ void tree_replace(
}); });
} }
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) { std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<array> flat_tree; std::vector<mx::array> flat_tree;
tree_visit(tree, [&](nb::handle obj) { tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) { if (nb::isinstance<mx::array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj)); flat_tree.push_back(nb::cast<mx::array>(obj));
} else if (strict) { } else if (strict) {
throw std::invalid_argument( throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays"); "[tree_flatten] The argument should contain only arrays");
@ -241,10 +241,10 @@ std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
nb::object tree_unflatten( nb::object tree_unflatten(
nb::object tree, nb::object tree,
const std::vector<array>& values, const std::vector<mx::array>& values,
int index /* = 0 */) { int index /* = 0 */) {
return tree_map(tree, [&](nb::handle obj) { return tree_map(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) { if (nb::isinstance<mx::array>(obj)) {
return nb::cast(values[index++]); return nb::cast(values[index++]);
} else { } else {
return nb::cast<nb::object>(obj); return nb::cast<nb::object>(obj);
@ -265,16 +265,16 @@ nb::object structure_sentinel() {
return sentinel; return sentinel;
} }
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure( std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
nb::object tree, nb::object tree,
bool strict /* = true */) { bool strict /* = true */) {
auto sentinel = structure_sentinel(); auto sentinel = structure_sentinel();
std::vector<array> flat_tree; std::vector<mx::array> flat_tree;
auto structure = tree_map( auto structure = tree_map(
tree, tree,
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) { [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
if (nb::isinstance<array>(obj)) { if (nb::isinstance<mx::array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj)); flat_tree.push_back(nb::cast<mx::array>(obj));
return sentinel; return sentinel;
} else if (!strict) { } else if (!strict) {
return nb::cast<nb::object>(obj); return nb::cast<nb::object>(obj);
@ -289,7 +289,7 @@ std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
nb::object tree_unflatten_from_structure( nb::object tree_unflatten_from_structure(
nb::object structure, nb::object structure,
const std::vector<array>& values, const std::vector<mx::array>& values,
int index /* = 0 */) { int index /* = 0 */) {
auto sentinel = structure_sentinel(); auto sentinel = structure_sentinel();
return tree_map(structure, [&](nb::handle obj) { return tree_map(structure, [&](nb::handle obj) {

View File

@ -4,8 +4,8 @@
#include "mlx/array.h" #include "mlx/array.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace mlx::core;
void tree_visit( void tree_visit(
const std::vector<nb::object>& trees, const std::vector<nb::object>& trees,
@ -27,7 +27,7 @@ void tree_visit_update(
/** /**
* Fill a pytree (recursive dict or list of dict or list) in place with the * Fill a pytree (recursive dict or list of dict or list) in place with the
* given arrays. */ * given arrays. */
void tree_fill(nb::object& tree, const std::vector<array>& values); void tree_fill(nb::object& tree, const std::vector<mx::array>& values);
/** /**
* Replace all the arrays from the src values with the dst values in the * Replace all the arrays from the src values with the dst values in the
@ -35,28 +35,28 @@ void tree_fill(nb::object& tree, const std::vector<array>& values);
*/ */
void tree_replace( void tree_replace(
nb::object& tree, nb::object& tree,
const std::vector<array>& src, const std::vector<mx::array>& src,
const std::vector<array>& dst); const std::vector<mx::array>& dst);
/** /**
* Flatten a tree into a vector of arrays. If strict is true, then the * Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array. * function will throw if the tree contains a leaf which is not an array.
*/ */
std::vector<array> tree_flatten(nb::object tree, bool strict = true); std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
/** /**
* Unflatten a tree from a vector of arrays. * Unflatten a tree from a vector of arrays.
*/ */
nb::object tree_unflatten( nb::object tree_unflatten(
nb::object tree, nb::object tree,
const std::vector<array>& values, const std::vector<mx::array>& values,
int index = 0); int index = 0);
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure( std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
nb::object tree, nb::object tree,
bool strict = true); bool strict = true);
nb::object tree_unflatten_from_structure( nb::object tree_unflatten_from_structure(
nb::object structure, nb::object structure,
const std::vector<array>& values, const std::vector<mx::array>& values,
int index = 0); int index = 0);

View File

@ -4,22 +4,24 @@
#include "mlx/ops.h" #include "mlx/ops.h"
#include "python/src/convert.h" #include "python/src/convert.h"
array to_array( mx::array to_array(
const ScalarOrArray& v, const ScalarOrArray& v,
std::optional<Dtype> dtype /* = std::nullopt */) { std::optional<mx::Dtype> dtype /* = std::nullopt */) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) { if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), dtype.value_or(bool_)); return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) { } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(int32); auto out_t = dtype.value_or(mx::int32);
// bool_ is an exception and is always promoted // bool_ is an exception and is always promoted
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t); return mx::array(
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) { } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32); auto out_t = dtype.value_or(mx::float32);
return array( return mx::array(
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32); nb::cast<float>(*pv),
mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64); return mx::array(static_cast<mx::complex64_t>(*pv), mx::complex64);
} else if (auto pv = std::get_if<array>(&v); pv) { } else if (auto pv = std::get_if<mx::array>(&v); pv) {
return *pv; return *pv;
} else if (auto pv = std::get_if< } else if (auto pv = std::get_if<
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v); nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
@ -30,7 +32,7 @@ array to_array(
} }
} }
std::pair<array, array> to_arrays( std::pair<mx::array, mx::array> to_arrays(
const ScalarOrArray& a, const ScalarOrArray& a,
const ScalarOrArray& b) { const ScalarOrArray& b) {
// Four cases: // Four cases:
@ -39,15 +41,15 @@ std::pair<array, array> to_arrays(
// - If b is an array but a is not, treat a as a weak python type // - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone // - If neither is an array convert to arrays but leave their types alone
auto is_mlx_array = [](const ScalarOrArray& x) { auto is_mlx_array = [](const ScalarOrArray& x) {
return std::holds_alternative<array>(x) || return std::holds_alternative<mx::array>(x) ||
std::holds_alternative<nb::object>(x) && std::holds_alternative<nb::object>(x) &&
nb::hasattr(std::get<nb::object>(x), "__mlx_array__"); nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
}; };
auto get_mlx_array = [](const ScalarOrArray& x) { auto get_mlx_array = [](const ScalarOrArray& x) {
if (auto px = std::get_if<array>(&x); px) { if (auto px = std::get_if<mx::array>(&x); px) {
return *px; return *px;
} else { } else {
return nb::cast<array>(std::get<nb::object>(x).attr("__mlx_array__")); return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
} }
}; };
@ -66,11 +68,11 @@ std::pair<array, array> to_arrays(
} }
} }
array to_array_with_accessor(nb::object obj) { mx::array to_array_with_accessor(nb::object obj) {
if (nb::isinstance<array>(obj)) { if (nb::isinstance<mx::array>(obj)) {
return nb::cast<array>(obj); return nb::cast<mx::array>(obj);
} else if (nb::hasattr(obj, "__mlx_array__")) { } else if (nb::hasattr(obj, "__mlx_array__")) {
return nb::cast<array>(obj.attr("__mlx_array__")()); return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str() msg << "Invalid type " << nb::type_name(obj.type()).c_str()

View File

@ -12,17 +12,16 @@
#include "mlx/array.h" #include "mlx/array.h"
namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>; using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std::variant< using ScalarOrArray = std::variant<
nb::bool_, nb::bool_,
nb::int_, nb::int_,
nb::float_, nb::float_,
// Must be above ndarray // Must be above ndarray
array, mx::array,
// Must be above complex // Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>, nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>, std::complex<float>,
@ -45,7 +44,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
// Checks if the value can be compared to an array (or is already an // Checks if the value can be compared to an array (or is already an
// mlx array) // mlx array)
if (auto pv = std::get_if<nb::object>(&v); pv) { if (auto pv = std::get_if<nb::object>(&v); pv) {
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__"); return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
} else { } else {
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.) // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
// and can be compared to an array // and can be compared to an array
@ -66,12 +65,12 @@ inline void throw_invalid_operation(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
array to_array( mx::array to_array(
const ScalarOrArray& v, const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt); std::optional<mx::Dtype> dtype = std::nullopt);
std::pair<array, array> to_arrays( std::pair<mx::array, mx::array> to_arrays(
const ScalarOrArray& a, const ScalarOrArray& a,
const ScalarOrArray& b); const ScalarOrArray& b);
array to_array_with_accessor(nb::object obj); mx::array to_array_with_accessor(nb::object obj);