Remove "using namespace mlx::core" in python/src (#1689)

This commit is contained in:
Cheng 2024-12-12 08:45:39 +09:00 committed by GitHub
parent f3dfa36a3a
commit 0bf19037ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1423 additions and 1302 deletions

File diff suppressed because it is too large Load Diff

View File

@ -14,37 +14,37 @@
#define Py_bf_releasebuffer 2
#endif
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace mlx::core;
std::string buffer_format(const array& a) {
std::string buffer_format(const mx::array& a) {
// https://docs.python.org/3.10/library/struct.html#format-characters
switch (a.dtype()) {
case bool_:
case mx::bool_:
return "?";
case uint8:
case mx::uint8:
return "B";
case uint16:
case mx::uint16:
return "H";
case uint32:
case mx::uint32:
return "I";
case uint64:
case mx::uint64:
return "Q";
case int8:
case mx::int8:
return "b";
case int16:
case mx::int16:
return "h";
case int32:
case mx::int32:
return "i";
case int64:
case mx::int64:
return "q";
case float16:
case mx::float16:
return "e";
case float32:
case mx::float32:
return "f";
case bfloat16:
case mx::bfloat16:
return "B";
case complex64:
case mx::complex64:
return "Zf\0";
default: {
std::ostringstream os;
@ -84,7 +84,7 @@ struct buffer_info {
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
std::memset(view, 0, sizeof(Py_buffer));
auto a = nb::cast<array>(nb::handle(obj));
auto a = nb::cast<mx::array>(nb::handle(obj));
{
nb::gil_scoped_release nogil;

View File

@ -16,7 +16,7 @@ enum PyScalarT {
namespace nanobind {
template <>
struct ndarray_traits<float16_t> {
struct ndarray_traits<mx::float16_t> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
@ -36,21 +36,21 @@ int check_shape_dim(int64_t dim) {
}
// Builds an MLX array from the data buffer of a read-only, C-contiguous, CPU
// nanobind ndarray.
//
// T is the C++ element type the buffer pointer is cast to before it is handed
// to the array constructor; `shape` and `dtype` are forwarded unchanged.
template <typename T>
array nd_array_to_mlx_contiguous(
mx::array nd_array_to_mlx_contiguous(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
const Shape& shape,
Dtype dtype) {
const mx::Shape& shape,
mx::Dtype dtype) {
// Make a copy of the numpy buffer
// Get the buffer pointer and pass it to the array constructor
auto data_ptr = nd_array.data();
return array(static_cast<const T*>(data_ptr), shape, dtype);
return mx::array(static_cast<const T*>(data_ptr), shape, dtype);
}
array nd_array_to_mlx(
mx::array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype) {
std::optional<mx::Dtype> dtype) {
// Compute the shape and size
Shape shape;
mx::Shape shape;
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
@ -59,49 +59,49 @@ array nd_array_to_mlx(
// Copy data and make array
if (type == nb::dtype<bool>()) {
return nd_array_to_mlx_contiguous<bool>(
nd_array, shape, dtype.value_or(bool_));
nd_array, shape, dtype.value_or(mx::bool_));
} else if (type == nb::dtype<uint8_t>()) {
return nd_array_to_mlx_contiguous<uint8_t>(
nd_array, shape, dtype.value_or(uint8));
nd_array, shape, dtype.value_or(mx::uint8));
} else if (type == nb::dtype<uint16_t>()) {
return nd_array_to_mlx_contiguous<uint16_t>(
nd_array, shape, dtype.value_or(uint16));
nd_array, shape, dtype.value_or(mx::uint16));
} else if (type == nb::dtype<uint32_t>()) {
return nd_array_to_mlx_contiguous<uint32_t>(
nd_array, shape, dtype.value_or(uint32));
nd_array, shape, dtype.value_or(mx::uint32));
} else if (type == nb::dtype<uint64_t>()) {
return nd_array_to_mlx_contiguous<uint64_t>(
nd_array, shape, dtype.value_or(uint64));
nd_array, shape, dtype.value_or(mx::uint64));
} else if (type == nb::dtype<int8_t>()) {
return nd_array_to_mlx_contiguous<int8_t>(
nd_array, shape, dtype.value_or(int8));
nd_array, shape, dtype.value_or(mx::int8));
} else if (type == nb::dtype<int16_t>()) {
return nd_array_to_mlx_contiguous<int16_t>(
nd_array, shape, dtype.value_or(int16));
nd_array, shape, dtype.value_or(mx::int16));
} else if (type == nb::dtype<int32_t>()) {
return nd_array_to_mlx_contiguous<int32_t>(
nd_array, shape, dtype.value_or(int32));
nd_array, shape, dtype.value_or(mx::int32));
} else if (type == nb::dtype<int64_t>()) {
return nd_array_to_mlx_contiguous<int64_t>(
nd_array, shape, dtype.value_or(int64));
} else if (type == nb::dtype<float16_t>()) {
return nd_array_to_mlx_contiguous<float16_t>(
nd_array, shape, dtype.value_or(float16));
nd_array, shape, dtype.value_or(mx::int64));
} else if (type == nb::dtype<mx::float16_t>()) {
return nd_array_to_mlx_contiguous<mx::float16_t>(
nd_array, shape, dtype.value_or(mx::float16));
} else if (type == nb::bfloat16) {
return nd_array_to_mlx_contiguous<bfloat16_t>(
nd_array, shape, dtype.value_or(bfloat16));
return nd_array_to_mlx_contiguous<mx::bfloat16_t>(
nd_array, shape, dtype.value_or(mx::bfloat16));
} else if (type == nb::dtype<float>()) {
return nd_array_to_mlx_contiguous<float>(
nd_array, shape, dtype.value_or(float32));
nd_array, shape, dtype.value_or(mx::float32));
} else if (type == nb::dtype<double>()) {
return nd_array_to_mlx_contiguous<double>(
nd_array, shape, dtype.value_or(float32));
nd_array, shape, dtype.value_or(mx::float32));
} else if (type == nb::dtype<std::complex<float>>()) {
return nd_array_to_mlx_contiguous<complex64_t>(
nd_array, shape, dtype.value_or(complex64));
return nd_array_to_mlx_contiguous<mx::complex64_t>(
nd_array, shape, dtype.value_or(mx::complex64));
} else if (type == nb::dtype<std::complex<double>>()) {
return nd_array_to_mlx_contiguous<complex128_t>(
nd_array, shape, dtype.value_or(complex64));
return nd_array_to_mlx_contiguous<mx::complex128_t>(
nd_array, shape, dtype.value_or(mx::complex64));
} else {
throw std::invalid_argument("Cannot convert numpy array to mlx array.");
}
@ -109,7 +109,7 @@ array nd_array_to_mlx(
template <typename T, typename... NDParams>
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
array a,
mx::array a,
std::optional<nb::dlpack::dtype> t = {}) {
{
nb::gil_scoped_release nogil;
@ -126,48 +126,48 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
}
// Converts an MLX array to an nb::ndarray<NDParams...> by switching on the
// array's dtype and calling mlx_to_nd_array_impl with the matching C++ element
// type.
//
// Throws nb::type_error for bfloat16 arrays and for any dtype that has no case
// below.
template <typename... NDParams>
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
switch (a.dtype()) {
case bool_:
case mx::bool_:
return mlx_to_nd_array_impl<bool, NDParams...>(a);
case uint8:
case mx::uint8:
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
case uint16:
case mx::uint16:
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
case uint32:
case mx::uint32:
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
case uint64:
case mx::uint64:
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
case int8:
case mx::int8:
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
case int16:
case mx::int16:
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
case int32:
case mx::int32:
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
case int64:
case mx::int64:
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
case float16:
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
case bfloat16:
case mx::float16:
return mlx_to_nd_array_impl<mx::float16_t, NDParams...>(a);
case mx::bfloat16:
// bfloat16 is rejected explicitly with its own error message rather than
// falling through to the generic default case.
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
case float32:
case mx::float32:
return mlx_to_nd_array_impl<float, NDParams...>(a);
case complex64:
case mx::complex64:
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
default:
throw nb::type_error("type cannot be converted to NumPy.");
}
}
// Converts an MLX array to a NumPy-flavoured ndarray by instantiating
// mlx_to_nd_array with the nb::numpy framework tag.
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a) {
return mlx_to_nd_array<nb::numpy>(a);
}
// Converts an MLX array to an nb::ndarray<> with no framework or layout
// parameters, by instantiating mlx_to_nd_array with an empty parameter pack.
nb::ndarray<> mlx_to_dlpack(const array& a) {
nb::ndarray<> mlx_to_dlpack(const mx::array& a) {
return mlx_to_nd_array<>(a);
}
nb::object to_scalar(array& a) {
nb::object to_scalar(mx::array& a) {
if (a.size() != 1) {
throw std::invalid_argument(
"[convert] Only length-1 arrays can be converted to Python scalars.");
@ -177,31 +177,31 @@ nb::object to_scalar(array& a) {
a.eval();
}
switch (a.dtype()) {
case bool_:
case mx::bool_:
return nb::cast(a.item<bool>());
case uint8:
case mx::uint8:
return nb::cast(a.item<uint8_t>());
case uint16:
case mx::uint16:
return nb::cast(a.item<uint16_t>());
case uint32:
case mx::uint32:
return nb::cast(a.item<uint32_t>());
case uint64:
case mx::uint64:
return nb::cast(a.item<uint64_t>());
case int8:
case mx::int8:
return nb::cast(a.item<int8_t>());
case int16:
case mx::int16:
return nb::cast(a.item<int16_t>());
case int32:
case mx::int32:
return nb::cast(a.item<int32_t>());
case int64:
case mx::int64:
return nb::cast(a.item<int64_t>());
case float16:
return nb::cast(static_cast<float>(a.item<float16_t>()));
case float32:
case mx::float16:
return nb::cast(static_cast<float>(a.item<mx::float16_t>()));
case mx::float32:
return nb::cast(a.item<float>());
case bfloat16:
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
case complex64:
case mx::bfloat16:
return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
case mx::complex64:
return nb::cast(a.item<std::complex<float>>());
default:
throw nb::type_error("type cannot be converted to Python scalar.");
@ -209,7 +209,7 @@ nb::object to_scalar(array& a) {
}
template <typename T, typename U = T>
nb::list to_list(array& a, size_t index, int dim) {
nb::list to_list(mx::array& a, size_t index, int dim) {
nb::list pl;
auto stride = a.strides()[dim];
for (int i = 0; i < a.shape(dim); ++i) {
@ -223,7 +223,7 @@ nb::list to_list(array& a, size_t index, int dim) {
return pl;
}
nb::object tolist(array& a) {
nb::object tolist(mx::array& a) {
if (a.ndim() == 0) {
return to_scalar(a);
}
@ -232,31 +232,31 @@ nb::object tolist(array& a) {
a.eval();
}
switch (a.dtype()) {
case bool_:
case mx::bool_:
return to_list<bool>(a, 0, 0);
case uint8:
case mx::uint8:
return to_list<uint8_t>(a, 0, 0);
case uint16:
case mx::uint16:
return to_list<uint16_t>(a, 0, 0);
case uint32:
case mx::uint32:
return to_list<uint32_t>(a, 0, 0);
case uint64:
case mx::uint64:
return to_list<uint64_t>(a, 0, 0);
case int8:
case mx::int8:
return to_list<int8_t>(a, 0, 0);
case int16:
case mx::int16:
return to_list<int16_t>(a, 0, 0);
case int32:
case mx::int32:
return to_list<int32_t>(a, 0, 0);
case int64:
case mx::int64:
return to_list<int64_t>(a, 0, 0);
case float16:
return to_list<float16_t, float>(a, 0, 0);
case float32:
case mx::float16:
return to_list<mx::float16_t, float>(a, 0, 0);
case mx::float32:
return to_list<float>(a, 0, 0);
case bfloat16:
return to_list<bfloat16_t, float>(a, 0, 0);
case complex64:
case mx::bfloat16:
return to_list<mx::bfloat16_t, float>(a, 0, 0);
case mx::complex64:
return to_list<std::complex<float>>(a, 0, 0);
default:
throw nb::type_error("data type cannot be converted to Python list.");
@ -279,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
template <typename T>
PyScalarT validate_shape(
T list,
const Shape& shape,
const mx::Shape& shape,
int idx,
bool& all_python_primitive_elements) {
if (idx >= shape.size()) {
@ -307,9 +307,9 @@ PyScalarT validate_shape(
shape,
idx + 1,
all_python_primitive_elements);
} else if (nb::isinstance<array>(l)) {
} else if (nb::isinstance<mx::array>(l)) {
all_python_primitive_elements = false;
auto arr = nb::cast<array>(l);
auto arr = nb::cast<mx::array>(l);
if (arr.ndim() + idx + 1 == shape.size() &&
std::equal(
arr.shape().cbegin(),
@ -347,7 +347,7 @@ PyScalarT validate_shape(
}
template <typename T>
void get_shape(T list, Shape& shape) {
void get_shape(T list, mx::Shape& shape) {
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
@ -355,8 +355,8 @@ void get_shape(T list, Shape& shape) {
return get_shape(nb::cast<nb::list>(*l), shape);
} else if (nb::isinstance<nb::tuple>(*l)) {
return get_shape(nb::cast<nb::tuple>(*l), shape);
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
} else if (nb::isinstance<mx::array>(*l)) {
auto arr = nb::cast<mx::array>(*l);
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(arr.shape(i));
}
@ -366,54 +366,55 @@ void get_shape(T list, Shape& shape) {
}
template <typename T>
array array_from_list_impl(
mx::array array_from_list_impl(
T pl,
const PyScalarT& inferred_type,
std::optional<Dtype> specified_type,
const Shape& shape) {
std::optional<mx::Dtype> specified_type,
const mx::Shape& shape) {
// Make the array
switch (inferred_type) {
case pybool: {
std::vector<bool> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(bool_));
return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_));
}
case pyint: {
auto dtype = specified_type.value_or(int32);
if (dtype == int64) {
auto dtype = specified_type.value_or(mx::int32);
if (dtype == mx::int64) {
std::vector<int64_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (dtype == uint64) {
return mx::array(vals.begin(), shape, dtype);
} else if (dtype == mx::uint64) {
std::vector<uint64_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (dtype == uint32) {
return mx::array(vals.begin(), shape, dtype);
} else if (dtype == mx::uint32) {
std::vector<uint32_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (issubdtype(dtype, inexact)) {
return mx::array(vals.begin(), shape, dtype);
} else if (mx::issubdtype(dtype, mx::inexact)) {
std::vector<float> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
return mx::array(vals.begin(), shape, dtype);
} else {
std::vector<int> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
return mx::array(vals.begin(), shape, dtype);
}
}
case pyfloat: {
std::vector<float> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(float32));
return mx::array(
vals.begin(), shape, specified_type.value_or(mx::float32));
}
case pycomplex: {
std::vector<std::complex<float>> vals;
fill_vector(pl, vals);
return array(
reinterpret_cast<complex64_t*>(vals.data()),
return mx::array(
reinterpret_cast<mx::complex64_t*>(vals.data()),
shape,
specified_type.value_or(complex64));
specified_type.value_or(mx::complex64));
}
default: {
std::ostringstream msg;
@ -425,9 +426,9 @@ array array_from_list_impl(
}
template <typename T>
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {
// Compute the shape
Shape shape;
mx::Shape shape;
get_shape(pl, shape);
// Validate the shape and type
@ -440,30 +441,31 @@ array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
}
// `pl` contains mlx arrays
std::vector<array> arrays;
std::vector<mx::array> arrays;
for (auto l : pl) {
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
}
return stack(arrays);
return mx::stack(arrays);
}
// Builds an MLX array from a Python list by forwarding to the templated
// array_from_list_impl; `dtype` is passed through unchanged.
array array_from_list(nb::list pl, std::optional<Dtype> dtype) {
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype) {
return array_from_list_impl(pl, dtype);
}
// Builds an MLX array from a Python tuple by forwarding to the templated
// array_from_list_impl; `dtype` is passed through unchanged.
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype) {
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {
return array_from_list_impl(pl, dtype);
}
array create_array(ArrayInitType v, std::optional<Dtype> t) {
mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), t.value_or(bool_));
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
return array(nb::cast<int>(*pv), t.value_or(int32));
return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
return array(nb::cast<float>(*pv), t.value_or(float32));
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
return mx::array(
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
@ -472,10 +474,10 @@ array create_array(ArrayInitType v, std::optional<Dtype> t) {
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
pv) {
return nd_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<array>(&v); pv) {
return astype(*pv, t.value_or((*pv).dtype()));
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
return mx::astype(*pv, t.value_or((*pv).dtype()));
} else {
auto arr = to_array_with_accessor(std::get<nb::object>(v));
return astype(arr, t.value_or(arr.dtype()));
return mx::astype(arr, t.value_or(arr.dtype()));
}
}

View File

@ -9,15 +9,15 @@
#include "mlx/array.h"
#include "mlx/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace mlx::core;
using ArrayInitType = std::variant<
nb::bool_,
nb::int_,
nb::float_,
// Must be above ndarray
array,
mx::array,
// Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>,
@ -25,17 +25,17 @@ using ArrayInitType = std::variant<
nb::tuple,
nb::object>;
// Converts a read-only, C-contiguous, CPU ndarray into an MLX array. `dtype`
// is optional.
array nd_array_to_mlx(
mx::array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype);
std::optional<mx::Dtype> dtype);
// Convert an MLX array to a NumPy-flavoured ndarray (mlx_to_np_array) or to a
// plain nb::ndarray<> (mlx_to_dlpack).
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
nb::ndarray<> mlx_to_dlpack(const array& a);
nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);
nb::ndarray<> mlx_to_dlpack(const mx::array& a);
// Converts an MLX array to a Python scalar object.
nb::object to_scalar(array& a);
nb::object to_scalar(mx::array& a);
// Converts an MLX array to a Python object (see the definition for how
// 0-dimensional arrays are handled).
nb::object tolist(array& a);
nb::object tolist(mx::array& a);
// Build an MLX array from an ArrayInitType value, or directly from a Python
// list or tuple, with an optional target dtype.
array create_array(ArrayInitType v, std::optional<Dtype> t);
array array_from_list(nb::list pl, std::optional<Dtype> dtype);
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype);
mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);

View File

@ -8,51 +8,54 @@
#include "mlx/device.h"
#include "mlx/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
// Registers the device bindings on module `m`: the `Device` class, the
// `DeviceType` enum (cpu / gpu), an implicit DeviceType -> Device conversion,
// and the `default_device` / `set_default_device` module functions.
void init_device(nb::module_& m) {
auto device_class = nb::class_<Device>(
auto device_class = nb::class_<mx::Device>(
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
// The enum is bound before the Device methods below are attached to
// `device_class`.
nb::enum_<Device::DeviceType>(m, "DeviceType")
.value("cpu", Device::DeviceType::cpu)
.value("gpu", Device::DeviceType::gpu)
nb::enum_<mx::Device::DeviceType>(m, "DeviceType")
.value("cpu", mx::Device::DeviceType::cpu)
.value("gpu", mx::Device::DeviceType::gpu)
.export_values()
.def("__eq__", [](const Device::DeviceType& d, const nb::object& other) {
if (!nb::isinstance<Device>(other) &&
!nb::isinstance<Device::DeviceType>(other)) {
.def(
"__eq__",
[](const mx::Device::DeviceType& d, const nb::object& other) {
if (!nb::isinstance<mx::Device>(other) &&
!nb::isinstance<mx::Device::DeviceType>(other)) {
// Anything that is neither a Device nor a DeviceType compares unequal.
return false;
}
// Otherwise `other` is cast to a Device and compared with the enum value.
return d == nb::cast<Device>(other);
return d == nb::cast<mx::Device>(other);
});
device_class.def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_ro("type", &Device::type)
device_class
.def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_ro("type", &mx::Device::type)
.def(
"__repr__",
[](const Device& d) {
[](const mx::Device& d) {
// The repr is whatever the Device stream operator writes.
std::ostringstream os;
os << d;
return os.str();
})
.def("__eq__", [](const Device& d, const nb::object& other) {
if (!nb::isinstance<Device>(other) &&
!nb::isinstance<Device::DeviceType>(other)) {
.def("__eq__", [](const mx::Device& d, const nb::object& other) {
if (!nb::isinstance<mx::Device>(other) &&
!nb::isinstance<mx::Device::DeviceType>(other)) {
// Same rule as the enum's __eq__: unrelated types compare unequal.
return false;
}
return d == nb::cast<Device>(other);
return d == nb::cast<mx::Device>(other);
});
// Lets a DeviceType be passed where a Device is expected.
nb::implicitly_convertible<Device::DeviceType, Device>();
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
m.def(
"default_device",
&default_device,
&mx::default_device,
R"pbdoc(Get the default device.)pbdoc");
m.def(
"set_default_device",
&set_default_device,
&mx::set_default_device,
"device"_a,
R"pbdoc(Set the default device.)pbdoc");
}

View File

@ -9,26 +9,27 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_distributed(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"distributed", "mlx.core.distributed: Communication operations");
nb::class_<distributed::Group>(
nb::class_<mx::distributed::Group>(
m,
"Group",
R"pbcopy(
An :class:`mlx.core.distributed.Group` represents a group of independent mlx
processes that can communicate.
)pbcopy")
.def("rank", &distributed::Group::rank, "Get the rank of this process")
.def("size", &distributed::Group::size, "Get the size of the group")
.def(
"rank", &mx::distributed::Group::rank, "Get the rank of this process")
.def("size", &mx::distributed::Group::size, "Get the size of the group")
.def(
"split",
&distributed::Group::split,
&mx::distributed::Group::split,
"color"_a,
"key"_a = -1,
nb::sig("def split(self, color: int, key: int = -1) -> Group"),
@ -48,14 +49,14 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"is_available",
&distributed::is_available,
&mx::distributed::is_available,
R"pbdoc(
Check if a communication backend is available.
)pbdoc");
m.def(
"init",
&distributed::init,
&mx::distributed::init,
"strict"_a = false,
nb::sig("def init(strict: bool = False) -> Group"),
R"pbdoc(
@ -72,7 +73,7 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_sum",
&distributed::all_sum,
&mx::distributed::all_sum,
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
@ -98,7 +99,7 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_gather",
&distributed::all_gather,
&mx::distributed::all_gather,
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
@ -125,7 +126,7 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"send",
&distributed::send,
&mx::distributed::send,
"x"_a,
"dst"_a,
nb::kw_only(),
@ -152,7 +153,7 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"recv",
&distributed::recv,
&mx::distributed::recv,
"shape"_a,
"dtype"_a,
"src"_a,
@ -181,7 +182,7 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"recv_like",
&distributed::recv_like,
&mx::distributed::recv_like,
"x"_a,
"src"_a,
nb::kw_only(),

View File

@ -13,9 +13,9 @@
#include "mlx/fast.h"
#include "mlx/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_fast(nb::module_& parent_module) {
auto m =
@ -23,7 +23,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"rms_norm",
&fast::rms_norm,
&mx::fast::rms_norm,
"x"_a,
"weight"_a,
"eps"_a,
@ -49,7 +49,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"layer_norm",
&fast::layer_norm,
&mx::fast::layer_norm,
"x"_a,
"weight"_a.none(),
"bias"_a.none(),
@ -79,7 +79,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"rope",
&fast::rope,
&mx::fast::rope,
"a"_a,
"dims"_a,
nb::kw_only(),
@ -114,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"scaled_dot_product_attention",
&fast::scaled_dot_product_attention,
&mx::fast::scaled_dot_product_attention,
"q"_a,
"k"_a,
"v"_a,
@ -170,7 +170,7 @@ void init_fast(nb::module_& parent_module) {
const std::string& header,
bool ensure_row_contiguous,
bool atomic_outputs) {
auto kernel = fast::metal_kernel(
auto kernel = mx::fast::metal_kernel(
name,
input_names,
output_names,
@ -182,7 +182,7 @@ void init_fast(nb::module_& parent_module) {
[kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<
@ -190,12 +190,12 @@ void init_fast(nb::module_& parent_module) {
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s = {}) {
std::vector<array> inputs;
mx::StreamOrDevice s = {}) {
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, fast::TemplateArg>>
std::vector<std::pair<std::string, mx::fast::TemplateArg>>
template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
@ -206,8 +206,8 @@ void init_fast(nb::module_& parent_module) {
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
} else if (nb::isinstance<mx::Dtype>(value)) {
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
throw std::invalid_argument(

View File

@ -9,24 +9,23 @@
#include "mlx/fft.h"
#include "mlx/ops.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_fft(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"fft", "mlx.core.fft: Fast Fourier Transforms.");
m.def(
"fft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::fft(a, n.value(), axis, s);
return mx::fft::fft(a, n.value(), axis, s);
} else {
return fft::fft(a, axis, s);
return mx::fft::fft(a, axis, s);
}
},
"a"_a,
@ -49,14 +48,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"ifft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::ifft(a, n.value(), axis, s);
return mx::fft::ifft(a, n.value(), axis, s);
} else {
return fft::ifft(a, axis, s);
return mx::fft::ifft(a, axis, s);
}
},
"a"_a,
@ -79,19 +78,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"fft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s);
return mx::fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
return mx::fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[fft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::fftn(a, s);
return mx::fft::fftn(a, s);
}
},
"a"_a,
@ -115,19 +114,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"ifft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s);
return mx::fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
return mx::fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::ifftn(a, s);
return mx::fft::ifftn(a, s);
}
},
"a"_a,
@ -151,19 +150,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"fftn",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::fftn(a, n.value(), axes.value(), s);
return mx::fft::fftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::fftn(a, axes.value(), s);
return mx::fft::fftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[fftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::fftn(a, s);
return mx::fft::fftn(a, s);
}
},
"a"_a,
@ -188,19 +187,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"ifftn",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::ifftn(a, n.value(), axes.value(), s);
return mx::fft::ifftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::ifftn(a, axes.value(), s);
return mx::fft::ifftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::ifftn(a, s);
return mx::fft::ifftn(a, s);
}
},
"a"_a,
@ -225,14 +224,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"rfft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::rfft(a, n.value(), axis, s);
return mx::fft::rfft(a, n.value(), axis, s);
} else {
return fft::rfft(a, axis, s);
return mx::fft::rfft(a, axis, s);
}
},
"a"_a,
@ -260,14 +259,14 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"irfft",
[](const array& a,
[](const mx::array& a,
const std::optional<int>& n,
int axis,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (n.has_value()) {
return fft::irfft(a, n.value(), axis, s);
return mx::fft::irfft(a, n.value(), axis, s);
} else {
return fft::irfft(a, axis, s);
return mx::fft::irfft(a, axis, s);
}
},
"a"_a,
@ -294,19 +293,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"rfft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s);
return mx::fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
return mx::fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::rfftn(a, s);
return mx::fft::rfftn(a, s);
}
},
"a"_a,
@ -336,19 +335,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"irfft2",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::irfftn(a, n.value(), axes.value(), s);
return mx::fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::irfftn(a, axes.value(), s);
return mx::fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::irfftn(a, s);
return mx::fft::irfftn(a, s);
}
},
"a"_a,
@ -378,19 +377,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"rfftn",
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
return fft::rfftn(a, n.value(), axes.value(), s);
return mx::fft::rfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
return fft::rfftn(a, axes.value(), s);
return mx::fft::rfftn(a, axes.value(), s);
} else if (n.has_value()) {
throw std::invalid_argument(
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
} else {
return fft::rfftn(a, s);
return mx::fft::rfftn(a, s);
}
},
"a"_a,
@ -420,19 +419,19 @@ void init_fft(nb::module_& parent_module) {
)pbdoc");
m.def(
"irfftn",
// Binding body for "irfftn": dispatches to the irfftn overload that matches
// which of `n` (sizes) and `axes` were supplied.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
[](const array& a,
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
// Both sizes and axes given.
return fft::irfftn(a, n.value(), axes.value(), s);
return mx::fft::irfftn(a, n.value(), axes.value(), s);
} else if (axes.has_value()) {
// Axes only.
return fft::irfftn(a, axes.value(), s);
return mx::fft::irfftn(a, axes.value(), s);
} else if (n.has_value()) {
// Sizes without axes are rejected; the message refers to the sizes
// argument as `s`.
throw std::invalid_argument(
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
} else {
// Neither given: call the overload taking only the array and stream.
return fft::irfftn(a, s);
return mx::fft::irfftn(a, s);
}
},
"a"_a,

View File

@ -43,20 +43,20 @@ void get_slice_params(
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
}
// Convert a Python integer index into an array holding that index as uint32.
// A negative index is shifted by `axis_size` before the array is built; no
// range check is performed here.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
array get_int_index(nb::object idx, int axis_size) {
mx::array get_int_index(nb::object idx, int axis_size) {
int idx_ = nb::cast<int>(idx);
// Wrap negative indices so that -1 maps to axis_size - 1.
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
return array(idx_, uint32);
return mx::array(idx_, mx::uint32);
}
// Return true when `obj` is one of the Python objects accepted as an index
// component: a slice, an int, an mlx array, None, Ellipsis, or a list.
// NOTE(review): the tail of the boolean expression appears twice below, first
// with the unqualified `array` and then with `mx::array` (re-wrapped).
bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) ||
nb::isinstance<nb::list>(obj);
nb::isinstance<mx::array>(obj) || obj.is_none() ||
nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
}
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
@ -77,14 +77,14 @@ array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
return slice(src, starts, ends, strides);
}
array mlx_get_item_array(const array& src, const array& indices) {
mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
}
if (indices.dtype() == bool_) {
if (indices.dtype() == mx::bool_) {
throw std::invalid_argument("boolean indices are not yet supported");
}
@ -93,7 +93,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
return take(src, indices, 0);
}
array mlx_get_item_int(const array& src, const nb::int_& idx) {
mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
@ -105,13 +105,13 @@ array mlx_get_item_int(const array& src, const nb::int_& idx) {
return take(src, get_int_index(idx, src.shape(0)), 0);
}
array mlx_gather_nd(
array src,
mx::array mlx_gather_nd(
mx::array src,
const std::vector<nb::object>& indices,
bool gather_first,
int& max_dims) {
max_dims = 0;
std::vector<array> gather_indices;
std::vector<mx::array> gather_indices;
std::vector<bool> is_slice(indices.size(), false);
int num_slices = 0;
// gather all the arrays
@ -127,13 +127,13 @@ array mlx_gather_nd(
start = (start < 0) ? start + src.shape(i) : start;
end = (end < 0) ? end + src.shape(i) : end;
gather_indices.push_back(arange(start, end, stride, uint32));
gather_indices.push_back(arange(start, end, stride, mx::uint32));
num_slices++;
is_slice[i] = true;
} else if (nb::isinstance<nb::int_>(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (nb::isinstance<array>(idx)) {
auto arr = nb::cast<array>(idx);
} else if (nb::isinstance<mx::array>(idx)) {
auto arr = nb::cast<mx::array>(idx);
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
gather_indices.push_back(arr);
}
@ -144,7 +144,7 @@ array mlx_gather_nd(
int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) {
if (is_slice[i]) {
Shape index_shape(max_dims + num_slices, 1);
mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
slice_index++;
@ -158,7 +158,7 @@ array mlx_gather_nd(
// reshape them so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) {
if (i < num_slices) {
Shape index_shape(max_dims + num_slices, 1);
mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
}
@ -241,7 +241,7 @@ auto mlx_expand_ellipsis(
return std::make_pair(non_none_indices, indices);
}
array mlx_get_item_nd(array src, const nb::tuple& entries) {
mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// No indices make this a noop
if (entries.size() == 0) {
return src;
@ -281,7 +281,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
bool have_non_array = false;
bool gather_first = false;
for (auto& idx : indices) {
if (nb::isinstance<array>(idx) || (nb::isinstance<nb::int_>(idx))) {
if (nb::isinstance<mx::array>(idx) || (nb::isinstance<nb::int_>(idx))) {
if (have_array && have_non_array) {
gather_first = true;
break;
@ -294,7 +294,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
int n_arr = 0;
for (auto& idx : indices) {
n_arr += nb::isinstance<array>(idx);
n_arr += nb::isinstance<mx::array>(idx);
}
have_array &= n_arr > 0;
@ -304,7 +304,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
// Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array];
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
}
}
@ -340,7 +340,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
} else {
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
} else if (idx.is_none()) {
remaining_indices.push_back(idx);
@ -426,11 +426,11 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
return src;
}
array mlx_get_item(const array& src, const nb::object& obj) {
mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
if (nb::isinstance<nb::slice>(obj)) {
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
} else if (nb::isinstance<array>(obj)) {
return mlx_get_item_array(src, nb::cast<array>(obj));
} else if (nb::isinstance<mx::array>(obj)) {
return mlx_get_item_array(src, nb::cast<mx::array>(obj));
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
} else if (nb::isinstance<nb::tuple>(obj)) {
@ -448,10 +448,11 @@ array mlx_get_item(const array& src, const nb::object& obj) {
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
const array& src,
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_int(
const mx::array& src,
const nb::int_& idx,
const array& update) {
const mx::array& update) {
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
@ -473,10 +474,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
{0}};
}
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
const array& src,
const array& indices,
const array& update) {
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_array(
const mx::array& src,
const mx::array& indices,
const mx::array& update) {
if (src.ndim() == 0) {
throw std::invalid_argument(
"too many indices for array: array is 0-dimensional");
@ -500,10 +502,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
return {{indices}, up, {0}};
}
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
const array& src,
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_slice(
const mx::array& src,
const nb::slice& in_slice,
const array& update) {
const mx::array& update) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
@ -539,7 +542,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
auto up = reshape(update, up_shape);
// Build array to mark start of slice
auto idx = array({start}, {1}, uint32);
auto idx = mx::array({start}, {1}, mx::uint32);
// Get slice size
int slice_size = (end - start);
@ -551,20 +554,21 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
up = broadcast_to(up, up_shape_broadcast);
auto indices = std::vector<array>{idx};
auto indices = std::vector<mx::array>{idx};
auto axes = std::vector<int>{0};
return {indices, up, axes};
}
return mlx_scatter_args_array(
src, arange(start, end, stride, uint32), update);
src, arange(start, end, stride, mx::uint32), update);
}
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
const array& src,
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_nd(
const mx::array& src,
const nb::tuple& entries,
const array& update) {
const mx::array& update) {
// Expand ellipses into a series of ':' slices
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
@ -623,12 +627,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
num_simple_slices_post++;
}
} else if (nb::isinstance<array>(idx)) {
} else if (nb::isinstance<mx::array>(idx)) {
have_array = true;
if (have_array && have_non_array) {
arrays_first = true;
}
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
max_dim = std::max(nb::cast<mx::array>(idx).ndim(), max_dim);
num_arrays++;
num_simple_slices_post = 0;
}
@ -643,7 +647,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
// Go over each index type and translate to the needed scatter args
std::vector<array> arr_indices;
std::vector<mx::array> arr_indices;
int slice_num = 0;
int array_num = 0;
int ax = 0;
@ -668,7 +672,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
// If it's a simple slice, we only need to add the start index
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
auto idx = array({start}, idx_shape, uint32);
auto idx = mx::array({start}, idx_shape, mx::uint32);
slice_shapes.push_back(end - start);
arr_indices.push_back(idx);
@ -677,7 +681,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
}
// Otherwise we expand the slice into indices using arange
else {
auto idx = arange(start, end, stride, uint32);
auto idx = arange(start, end, stride, mx::uint32);
auto loc = slice_num + (arrays_first ? max_dim : 0);
idx_shape[loc] = idx.size();
arr_indices.push_back(reshape(idx, idx_shape));
@ -696,9 +700,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
} else if (pyidx.is_none()) {
// We only use the None's for bookeeping dimensions
slice_num++;
} else if (nb::isinstance<array>(pyidx)) {
} else if (nb::isinstance<mx::array>(pyidx)) {
ax++;
auto idx = nb::cast<array>(pyidx);
auto idx = nb::cast<mx::array>(pyidx);
std::vector<int> idx_shape(idx_ndim, 1);
// Place the arrays in the correct dimension
@ -748,16 +752,16 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
return {arr_indices, up, axes};
}
std::tuple<std::vector<array>, array, std::vector<int>>
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_compute_scatter_args(
const array& src,
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype());
if (nb::isinstance<nb::slice>(obj)) {
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
} else if (nb::isinstance<array>(obj)) {
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
} else if (nb::isinstance<mx::array>(obj)) {
return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
} else if (nb::isinstance<nb::tuple>(obj)) {
@ -773,7 +777,7 @@ mlx_compute_scatter_args(
}
auto mlx_slice_update(
const array& src,
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
// Can't route to slice update if not slice or tuple
@ -784,7 +788,7 @@ auto mlx_slice_update(
if (nb::isinstance<nb::tuple>(obj)) {
// Can't route to slice update if any arrays are present
for (auto idx : nb::cast<nb::tuple>(obj)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) {
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::list>(idx)) {
return std::make_pair(false, src);
}
}
@ -881,7 +885,10 @@ auto mlx_slice_update(
return std::make_pair(true, out);
}
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
void mlx_set_item(
mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [success, out] = mlx_slice_update(src, obj, v);
if (success) {
src.overwrite_descriptor(out);
@ -897,8 +904,8 @@ void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
}
}
array mlx_add_item(
const array& src,
mx::array mlx_add_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -909,8 +916,8 @@ array mlx_add_item(
}
}
array mlx_subtract_item(
const array& src,
mx::array mlx_subtract_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -921,8 +928,8 @@ array mlx_subtract_item(
}
}
array mlx_multiply_item(
const array& src,
mx::array mlx_multiply_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -933,8 +940,8 @@ array mlx_multiply_item(
}
}
array mlx_divide_item(
const array& src,
mx::array mlx_divide_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -945,8 +952,8 @@ array mlx_divide_item(
}
}
array mlx_maximum_item(
const array& src,
mx::array mlx_maximum_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
@ -957,8 +964,8 @@ array mlx_maximum_item(
}
}
array mlx_minimum_item(
const array& src,
mx::array mlx_minimum_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);

View File

@ -7,32 +7,35 @@
#include "mlx/array.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace mlx::core;
array mlx_get_item(const array& src, const nb::object& obj);
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v);
array mlx_add_item(
const array& src,
mx::array mlx_get_item(const mx::array& src, const nb::object& obj);
void mlx_set_item(
mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_subtract_item(
const array& src,
mx::array mlx_add_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_multiply_item(
const array& src,
mx::array mlx_subtract_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_divide_item(
const array& src,
mx::array mlx_multiply_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_maximum_item(
const array& src,
mx::array mlx_divide_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_minimum_item(
const array& src,
mx::array mlx_maximum_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);
mx::array mlx_minimum_item(
const mx::array& src,
const nb::object& obj,
const ScalarOrArray& v);

View File

@ -10,15 +10,13 @@
#include "mlx/linalg.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
namespace {
// Run the linalg SVD on `a` and return elements 0, 1 and 2 of its result as a
// Python tuple. `.at()` is used, so a result with fewer than three elements
// would throw rather than read out of bounds.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
const auto result = svd(a, s);
nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
const auto result = mx::linalg::svd(a, s);
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
}
} // namespace
@ -29,11 +27,11 @@ void init_linalg(nb::module_& parent_module) {
m.def(
"norm",
[](const array& a,
[](const mx::array& a,
const std::variant<std::monostate, int, double, std::string>& ord_,
const std::variant<std::monostate, int, std::vector<int>>& axis_,
const bool keepdims,
const StreamOrDevice stream) {
const mx::StreamOrDevice stream) {
std::optional<std::vector<int>> axis = std::nullopt;
if (auto pv = std::get_if<int>(&axis_); pv) {
axis = std::vector<int>{*pv};
@ -42,10 +40,10 @@ void init_linalg(nb::module_& parent_module) {
}
if (std::holds_alternative<std::monostate>(ord_)) {
return norm(a, axis, keepdims, stream);
return mx::linalg::norm(a, axis, keepdims, stream);
} else {
if (auto pv = std::get_if<std::string>(&ord_); pv) {
return norm(a, *pv, axis, keepdims, stream);
return mx::linalg::norm(a, *pv, axis, keepdims, stream);
}
double ord;
if (auto pv = std::get_if<int>(&ord_); pv) {
@ -53,7 +51,7 @@ void init_linalg(nb::module_& parent_module) {
} else {
ord = std::get<double>(ord_);
}
return norm(a, ord, axis, keepdims, stream);
return mx::linalg::norm(a, ord, axis, keepdims, stream);
}
},
nb::arg(),
@ -182,7 +180,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"qr",
&qr,
&mx::linalg::qr,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
@ -239,7 +237,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"inv",
&inv,
&mx::linalg::inv,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
@ -262,7 +260,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"tri_inv",
&tri_inv,
&mx::linalg::tri_inv,
"a"_a,
"upper"_a,
nb::kw_only(),
@ -287,7 +285,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"cholesky",
&cholesky,
&mx::linalg::cholesky,
"a"_a,
"upper"_a = false,
nb::kw_only(),
@ -317,7 +315,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"cholesky_inv",
&cholesky_inv,
&mx::linalg::cholesky_inv,
"a"_a,
"upper"_a = false,
nb::kw_only(),
@ -355,7 +353,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"pinv",
&pinv,
&mx::linalg::pinv,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
@ -379,7 +377,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"cross",
&cross,
&mx::linalg::cross,
"a"_a,
"b"_a,
"axis"_a = -1,
@ -407,7 +405,7 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"eigvalsh",
&eigvalsh,
&mx::linalg::eigvalsh,
"a"_a,
"UPLO"_a = "L",
nb::kw_only(),
@ -442,9 +440,9 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"eigh",
// Binding body for "eigh": calls the linalg eigh routine and repackages its
// pair-like result (`first`, `second`) as a two-element Python tuple.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
[](const array& a, const std::string UPLO, StreamOrDevice s) {
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
// TODO avoid cast?
auto result = eigh(a, UPLO, s);
auto result = mx::linalg::eigh(a, UPLO, s);
return nb::make_tuple(result.first, result.second);
},
"a"_a,

View File

@ -14,9 +14,9 @@
#include "python/src/load.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
///////////////////////////////////////////////////////////////////////////////
// Helpers
@ -86,7 +86,7 @@ class ZipFileWrapper {
// Loading
///////////////////////////////////////////////////////////////////////////////
class PyFileReader : public io::Reader {
class PyFileReader : public mx::io::Reader {
public:
PyFileReader(nb::object file)
: pyistream_(file),
@ -168,14 +168,14 @@ class PyFileReader : public io::Reader {
};
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, mx::array>,
std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s);
return mx::load_safetensors(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
{
nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) {
@ -189,17 +189,17 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
"[load_safetensors] Input must be a file-like object, or string");
}
// Load a GGUF file given its path as a Python str.
// Only string paths are accepted; any other Python object raises
// std::invalid_argument.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s);
return mx::load_gguf(nb::cast<std::string>(file), s);
}
throw std::invalid_argument("[load_gguf] Input must be a string");
}
std::unordered_map<std::string, array> mlx_load_npz_helper(
std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
nb::object file,
StreamOrDevice s) {
mx::StreamOrDevice s) {
bool own_file = nb::isinstance<nb::str>(file);
nb::module_ zipfile = nb::module_::import_("zipfile");
@ -209,7 +209,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
"opened with zipfile.ZipFile");
}
// Output dictionary filename in zip -> loaded array
std::unordered_map<std::string, array> array_dict;
std::unordered_map<std::string, mx::array> array_dict;
// Create python ZipFile object
ZipFileWrapper zipfile_object(zipfile, file);
@ -218,7 +218,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
nb::object sub_file = zipfile_object.open(st);
// Create array from python file stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
auto arr = mx::load(std::make_shared<PyFileReader>(sub_file), s);
// Remove .npy from file if it is there
auto key = st;
@ -240,12 +240,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
return array_dict;
}
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
return load(nb::cast<std::string>(file), s);
return mx::load(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
{
nb::gil_scoped_release gil;
arr.eval();
@ -260,7 +260,7 @@ LoadOutputTypes mlx_load_helper(
nb::object file,
std::optional<std::string> format,
bool return_metadata,
StreamOrDevice s) {
mx::StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
if (nb::isinstance<nb::str>(file)) {
@ -309,7 +309,7 @@ LoadOutputTypes mlx_load_helper(
// Saving
///////////////////////////////////////////////////////////////////////////////
class PyFileWriter : public io::Writer {
class PyFileWriter : public mx::io::Writer {
public:
PyFileWriter(nb::object file)
: pyostream_(file),
@ -382,15 +382,15 @@ class PyFileWriter : public io::Writer {
nb::object tell_func_;
};
void mlx_save_helper(nb::object file, array a) {
void mlx_save_helper(nb::object file, mx::array a) {
if (nb::isinstance<nb::str>(file)) {
save(nb::cast<std::string>(file), a);
mx::save(nb::cast<std::string>(file), a);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
nb::gil_scoped_release gil;
save(writer, a);
mx::save(writer, a);
}
return;
@ -419,8 +419,9 @@ void mlx_savez_helper(
}
// Collect args and kwargs
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
auto arrays_list = nb::cast<std::vector<array>>(args);
auto arrays_dict =
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i);
@ -447,7 +448,7 @@ void mlx_savez_helper(
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
nb::gil_scoped_release nogil;
save(writer, a);
mx::save(writer, a);
}
}
@ -470,17 +471,18 @@ void mlx_save_safetensor_helper(
} else {
metadata_map = std::unordered_map<std::string, std::string>();
}
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
if (nb::isinstance<nb::str>(file)) {
{
nb::gil_scoped_release nogil;
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
mx::save_safetensors(
nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
nb::gil_scoped_release nogil;
save_safetensors(writer, arrays_map, metadata_map);
mx::save_safetensors(writer, arrays_map, metadata_map);
}
} else {
throw std::invalid_argument(
@ -492,19 +494,20 @@ void mlx_save_gguf_helper(
nb::object file,
nb::dict a,
std::optional<nb::dict> m) {
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
if (nb::isinstance<nb::str>(file)) {
if (m) {
auto metadata_map =
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
m.value());
{
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else {
{
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map);
mx::save_gguf(nb::cast<std::string>(file), arrays_map);
}
}
} else {

View File

@ -14,22 +14,24 @@
#include <variant>
#include "mlx/io.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace mlx::core;
// Variant of the result types a load can produce: a single array, a
// name -> array map, a safetensors load result, or a GGUF load result.
// It is the return type of mlx_load_helper.
// NOTE(review): the alternative list appears twice below, first unqualified
// and then with the mx:: prefix.
using LoadOutputTypes = std::variant<
array,
std::unordered_map<std::string, array>,
SafetensorsLoad,
GGUFLoad>;
mx::array,
std::unordered_map<std::string, mx::array>,
mx::SafetensorsLoad,
mx::GGUFLoad>;
SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s);
mx::SafetensorsLoad mlx_load_safetensor_helper(
nb::object file,
mx::StreamOrDevice s);
void mlx_save_safetensor_helper(
nb::object file,
nb::dict d,
std::optional<nb::dict> m);
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s);
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s);
void mlx_save_gguf_helper(
nb::object file,
@ -40,8 +42,8 @@ LoadOutputTypes mlx_load_helper(
nb::object file,
std::optional<std::string> format,
bool return_metadata,
StreamOrDevice s);
void mlx_save_helper(nb::object file, array a);
mx::StreamOrDevice s);
void mlx_save_helper(nb::object file, mx::array a);
void mlx_savez_helper(
nb::object file,
nb::args args,

View File

@ -8,22 +8,21 @@
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_metal(nb::module_& m) {
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def(
"is_available",
&metal::is_available,
&mx::metal::is_available,
R"pbdoc(
Check if the Metal back-end is available.
)pbdoc");
metal.def(
"get_active_memory",
&metal::get_active_memory,
&mx::metal::get_active_memory,
R"pbdoc(
Get the actively used memory in bytes.
@ -32,7 +31,7 @@ void init_metal(nb::module_& m) {
)pbdoc");
metal.def(
"get_peak_memory",
&metal::get_peak_memory,
&mx::metal::get_peak_memory,
R"pbdoc(
Get the peak amount of used memory in bytes.
@ -41,13 +40,13 @@ void init_metal(nb::module_& m) {
)pbdoc");
metal.def(
"reset_peak_memory",
&metal::reset_peak_memory,
&mx::metal::reset_peak_memory,
R"pbdoc(
Reset the peak memory to zero.
)pbdoc");
metal.def(
"get_cache_memory",
&metal::get_cache_memory,
&mx::metal::get_cache_memory,
R"pbdoc(
Get the cache size in bytes.
@ -56,7 +55,7 @@ void init_metal(nb::module_& m) {
)pbdoc");
metal.def(
"set_memory_limit",
&metal::set_memory_limit,
&mx::metal::set_memory_limit,
"limit"_a,
nb::kw_only(),
"relaxed"_a = true,
@ -81,7 +80,7 @@ void init_metal(nb::module_& m) {
)pbdoc");
metal.def(
"set_cache_limit",
&metal::set_cache_limit,
&mx::metal::set_cache_limit,
"limit"_a,
R"pbdoc(
Set the free cache limit.
@ -101,7 +100,7 @@ void init_metal(nb::module_& m) {
)pbdoc");
metal.def(
"set_wired_limit",
&metal::set_wired_limit,
&mx::metal::set_wired_limit,
"limit"_a,
R"pbdoc(
Set the wired size limit.
@ -133,7 +132,7 @@ void init_metal(nb::module_& m) {
)pbdoc");
metal.def(
"clear_cache",
&metal::clear_cache,
&mx::metal::clear_cache,
R"pbdoc(
Clear the memory cache.
@ -142,7 +141,7 @@ void init_metal(nb::module_& m) {
metal.def(
"start_capture",
&metal::start_capture,
&mx::metal::start_capture,
"path"_a,
R"pbdoc(
Start a Metal capture.
@ -153,13 +152,13 @@ void init_metal(nb::module_& m) {
)pbdoc");
metal.def(
"stop_capture",
&metal::stop_capture,
&mx::metal::stop_capture,
R"pbdoc(
Stop a Metal capture.
)pbdoc");
metal.def(
"device_info",
&metal::device_info,
&mx::metal::device_info,
R"pbdoc(
Get information about the GPU device and system settings.

File diff suppressed because it is too large Load Diff

View File

@ -12,23 +12,22 @@
#include "mlx/ops.h"
#include "mlx/random.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::random;
class PyKeySequence {
public:
// Construct the sequence by appending the PRNG key derived from `seed` to
// state_. (The append call appears twice: unqualified, then mx::random::.)
// NOTE(review): state_ looks like a Python list holding one key; its
// declaration is outside this view, so confirm.
explicit PyKeySequence(uint64_t seed) {
state_.append(key(seed));
state_.append(mx::random::key(seed));
}
// Re-seed: overwrite the stored key (slot 0 of state_) with the key derived
// from `seed`. (The assignment appears twice: unqualified, then
// mx::random::.)
void seed(uint64_t seed) {
state_[0] = key(seed);
state_[0] = mx::random::key(seed);
}
// Advance the sequence: split the stored key, keep the first half as the new
// stored key, and return the second half to the caller.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
array next() {
mx::array next() {
auto out = split(nb::cast<array>(state_[0]));
auto out = mx::random::split(nb::cast<mx::array>(state_[0]));
state_[0] = out.first;
return out.second;
}
@ -75,7 +74,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"key",
&key,
&mx::random::key,
"seed"_a,
R"pbdoc(
Get a PRNG key from a seed.
@ -88,7 +87,8 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"split",
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
nb::overload_cast<const mx::array&, int, mx::StreamOrDevice>(
&mx::random::split),
"key"_a,
"num"_a = 2,
"stream"_a = nb::none(),
@ -109,22 +109,22 @@ void init_random(nb::module_& parent_module) {
// Binding body forwarding to the random `uniform` routine.
// `low` / `high` are converted with to_array; dtype falls back to float32
// when `type` is empty.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
// Use the caller's key if given, otherwise take the next key from the
// default key sequence.
auto key = key_ ? key_.value() : default_key().next();
return uniform(
return mx::random::uniform(
to_array(low),
to_array(high),
shape,
type.value_or(float32),
type.value_or(mx::float32),
key,
s);
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@ -151,16 +151,17 @@ void init_random(nb::module_& parent_module) {
m.def(
"normal",
// Binding body forwarding to the random `normal` routine with the given
// shape, loc and scale; dtype falls back to float32 when `type` is empty.
// NOTE(review): in this listing each changed line appears twice, first with
// the unqualified name and then with its mx::-qualified replacement.
[](const std::vector<int>& shape,
std::optional<Dtype> type,
std::optional<mx::Dtype> type,
float loc,
float scale,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
// Use the caller's key if given, otherwise take the next key from the
// default key sequence.
auto key = key_ ? key_.value() : default_key().next();
return normal(shape, type.value_or(float32), loc, scale, key, s);
return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
"key"_a = nb::none(),
@ -182,20 +183,20 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"multivariate_normal",
[](const array& mean,
const array& cov,
[](const mx::array& mean,
const mx::array& cov,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return multivariate_normal(
mean, cov, shape, type.value_or(float32), key, s);
return mx::random::multivariate_normal(
mean, cov, shape, type.value_or(mx::float32), key, s);
},
"mean"_a,
"cov"_a,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@ -227,17 +228,22 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
return mx::random::randint(
to_array(low),
to_array(high),
shape,
type.value_or(mx::int32),
key,
s);
},
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"dtype"_a.none() = int32,
"dtype"_a.none() = mx::int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@ -263,14 +269,14 @@ void init_random(nb::module_& parent_module) {
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto p = to_array(p_);
if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s);
return mx::random::bernoulli(p, shape.value(), key, s);
} else {
return bernoulli(p, key, s);
return mx::random::bernoulli(p, key, s);
}
},
"p"_a = 0.5,
@ -301,23 +307,24 @@ void init_random(nb::module_& parent_module) {
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto lower = to_array(lower_);
auto upper = to_array(upper_);
auto t = type.value_or(float32);
auto t = type.value_or(mx::float32);
if (shape_.has_value()) {
return truncated_normal(lower, upper, shape_.value(), t, key, s);
return mx::random::truncated_normal(
lower, upper, shape_.value(), t, key, s);
} else {
return truncated_normal(lower, upper, t, key, s);
return mx::random::truncated_normal(lower, upper, t, key, s);
}
},
"lower"_a,
"upper"_a,
"shape"_a = nb::none(),
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@ -344,14 +351,14 @@ void init_random(nb::module_& parent_module) {
m.def(
"gumbel",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key_,
StreamOrDevice s) {
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return gumbel(shape, type.value_or(float32), key, s);
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
@ -375,22 +382,23 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"categorical",
[](const array& logits,
[](const mx::array& logits,
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified.");
} else if (shape.has_value()) {
return categorical(logits, axis, shape.value(), key, s);
return mx::random::categorical(logits, axis, shape.value(), key, s);
} else if (num_samples.has_value()) {
return categorical(logits, axis, num_samples.value(), key, s);
return mx::random::categorical(
logits, axis, num_samples.value(), key, s);
} else {
return categorical(logits, axis, key, s);
return mx::random::categorical(logits, axis, key, s);
}
},
"logits"_a,
@ -427,16 +435,17 @@ void init_random(nb::module_& parent_module) {
m.def(
"laplace",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
std::optional<mx::Dtype> type,
float loc,
float scale,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return laplace(shape, type.value_or(float32), loc, scale, key, s);
return mx::random::laplace(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
"key"_a = nb::none(),
@ -459,15 +468,15 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"permuation",
[](const std::variant<nb::int_, array>& x,
[](const std::variant<nb::int_, mx::array>& x,
int axis,
const std::optional<array>& key_,
StreamOrDevice s) {
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (auto pv = std::get_if<nb::int_>(&x); pv) {
return permutation(nb::cast<int>(*pv), key, s);
return mx::random::permutation(nb::cast<int>(*pv), key, s);
} else {
return permutation(std::get<array>(x), axis, key, s);
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
}
},
"shape"_a = std::vector<int>{},

View File

@ -10,14 +10,14 @@
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
// Create the StreamContext on enter and delete on exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
@ -26,7 +26,7 @@ class PyStreamContext {
}
void enter() {
_inner = new StreamContext(_s);
_inner = new mx::StreamContext(_s);
}
void exit() {
@ -37,39 +37,40 @@ class PyStreamContext {
}
private:
StreamOrDevice _s;
StreamContext* _inner;
mx::StreamOrDevice _s;
mx::StreamContext* _inner;
};
void init_stream(nb::module_& m) {
nb::class_<Stream>(
nb::class_<mx::Stream>(
m,
"Stream",
R"pbdoc(
A stream for running operations on a given device.
)pbdoc")
.def_ro("device", &Stream::device)
.def_ro("device", &mx::Stream::device)
.def(
"__repr__",
[](const Stream& s) {
[](const mx::Stream& s) {
std::ostringstream os;
os << s;
return os.str();
})
.def("__eq__", [](const Stream& s, const nb::object& other) {
return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other);
.def("__eq__", [](const mx::Stream& s, const nb::object& other) {
return nb::isinstance<mx::Stream>(other) &&
s == nb::cast<mx::Stream>(other);
});
nb::implicitly_convertible<Device::DeviceType, Device>();
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
m.def(
"default_stream",
&default_stream,
&mx::default_stream,
"device"_a,
R"pbdoc(Get the device's default stream.)pbdoc");
m.def(
"set_default_stream",
&set_default_stream,
&mx::set_default_stream,
"stream"_a,
R"pbdoc(
Set the default stream.
@ -82,7 +83,7 @@ void init_stream(nb::module_& m) {
)pbdoc");
m.def(
"new_stream",
&new_stream,
&mx::new_stream,
"device"_a,
R"pbdoc(Make a new stream on the given device.)pbdoc");
@ -94,7 +95,7 @@ void init_stream(nb::module_& m) {
Args:
s: The stream or device to set as the default.
)pbdoc")
.def(nb::init<StreamOrDevice>(), "s"_a)
.def(nb::init<mx::StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def(
"__exit__",
@ -107,7 +108,7 @@ void init_stream(nb::module_& m) {
"traceback"_a = nb::none());
m.def(
"stream",
[](StreamOrDevice s) { return PyStreamContext(s); },
[](mx::StreamOrDevice s) { return PyStreamContext(s); },
"s"_a,
R"pbdoc(
Create a context manager to set the default device and stream.
@ -131,8 +132,8 @@ void init_stream(nb::module_& m) {
)pbdoc");
m.def(
"synchronize",
[](const std::optional<Stream>& s) {
s ? synchronize(s.value()) : synchronize();
[](const std::optional<mx::Stream>& s) {
s ? mx::synchronize(s.value()) : mx::synchronize();
},
"stream"_a = nb::none(),
R"pbdoc(

View File

@ -20,9 +20,12 @@
#include "mlx/utils.h"
#include "python/src/trees.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
// Needed for printing shapes and strides.
using mx::operator<<;
using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
@ -108,7 +111,7 @@ auto py_value_and_grad(
}
// Collect the arrays
std::vector<array> arrays;
std::vector<mx::array> arrays;
std::vector<int> counts(1, 0);
for (auto i : argnums) {
auto argsi = tree_flatten(args[i]);
@ -127,7 +130,7 @@ auto py_value_and_grad(
// value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
nb::object py_value_out;
auto value_and_grads = value_and_grad(
auto value_and_grads = mx::value_and_grad(
[&fun,
&args,
&kwargs,
@ -136,7 +139,7 @@ auto py_value_and_grad(
&counts,
&py_value_out,
&error_msg_tag,
scalar_func_only](const std::vector<array>& a) {
scalar_func_only](const std::vector<mx::array>& a) {
// Copy the arguments
nb::list args_cpy;
nb::kwargs kwargs_cpy = nb::kwargs();
@ -165,7 +168,7 @@ auto py_value_and_grad(
py_value_out = fun(*args_cpy, **kwargs_cpy);
// Validate the return value of the python function
if (!nb::isinstance<array>(py_value_out)) {
if (!nb::isinstance<mx::array>(py_value_out)) {
if (scalar_func_only) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
@ -193,7 +196,7 @@ auto py_value_and_grad(
<< "we got an empty tuple.";
throw std::invalid_argument(msg.str());
}
if (!nb::isinstance<array>(ret[0])) {
if (!nb::isinstance<mx::array>(ret[0])) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
@ -275,12 +278,12 @@ auto py_vmap(
{tree, axes},
[&flat_axes, &encountered_tuple, output_axes](
const std::vector<nb::object>& inputs) {
if (nb::isinstance<array>(inputs[0])) {
if (nb::isinstance<mx::array>(inputs[0])) {
if (inputs[1].is_none()) {
flat_axes.push_back(-1);
} else if (nb::isinstance<nb::int_>(inputs[1])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
const array& x = nb::cast<array>(inputs[0]);
const mx::array& x = nb::cast<mx::array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
@ -297,7 +300,7 @@ auto py_vmap(
auto l = nb::cast<nb::tuple>(inputs[1]);
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
const array& x = nb::cast<array>(inputs[0]);
const mx::array& x = nb::cast<mx::array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
@ -323,7 +326,7 @@ auto py_vmap(
"[vmap] The arguments should contain only arrays");
}
});
if (encountered_tuple && !nb::isinstance<array>(tree)) {
if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
return flat_axes;
@ -339,7 +342,7 @@ auto py_vmap(
nb::object py_outputs;
auto vmap_fn =
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
// Call the python function
py_outputs = fun(*tree_unflatten(args, a));
@ -348,12 +351,12 @@ auto py_vmap(
};
auto [trace_inputs, trace_outputs] =
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
// Perform the vmap
auto outputs = detail::vmap_replace(
auto outputs = mx::detail::vmap_replace(
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
// Put the outputs back in the container
@ -401,7 +404,7 @@ struct PyCompiledFun {
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
// Flat array inputs
std::vector<array> inputs;
std::vector<mx::array> inputs;
// Compilation constants which includes the tree structure of the arguments
std::vector<uint64_t> constants;
@ -437,8 +440,8 @@ struct PyCompiledFun {
constants.push_back(nb::cast<int64_t>(r));
recurse(item.second);
}
} else if (nb::isinstance<array>(obj)) {
inputs.push_back(nb::cast<array>(obj));
} else if (nb::isinstance<mx::array>(obj)) {
inputs.push_back(nb::cast<mx::array>(obj));
constants.push_back(array_identifier);
} else if (nb::isinstance<nb::str>(obj)) {
auto r = obj.attr("__hash__")();
@ -461,10 +464,10 @@ struct PyCompiledFun {
int num_args = inputs.size();
recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) {
const std::vector<mx::array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
std::vector<array> trace_captures;
std::vector<mx::array> flat_in_captures;
std::vector<mx::array> trace_captures;
if (!captured_inputs.is_none()) {
flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert(
@ -505,9 +508,9 @@ struct PyCompiledFun {
// Compile and call
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!captured_outputs.is_none()) {
std::vector<array> captures(
std::vector<mx::array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end()));
tree_fill(captured_outputs, captures);
@ -526,7 +529,7 @@ struct PyCompiledFun {
nb::gil_scoped_acquire gil;
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
mx::detail::compile_erase(fun_id);
fun.release().dec_ref();
captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref();
@ -561,7 +564,7 @@ class PyCheckpointedFun {
args_structure_.release().dec_ref();
}
std::vector<array> operator()(const std::vector<array>& inputs) {
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
auto args = nb::cast<nb::tuple>(
tree_unflatten_from_structure(args_structure_, inputs));
auto [outputs, output_structure] =
@ -579,7 +582,7 @@ class PyCheckpointedFun {
auto [inputs, args_structure] =
tree_flatten_with_structure(full_args, false);
auto outputs = checkpoint(
auto outputs = mx::checkpoint(
InnerFunction(fun_, args_structure, output_structure))(inputs);
return tree_unflatten_from_structure(*output_structure, outputs);
@ -660,12 +663,12 @@ class PyCustomFunction {
}
}
std::vector<array> operator()(const std::vector<array>& inputs) {
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
nb::gil_scoped_acquire gil;
auto new_inputs = nb::cast<nb::tuple>(
tree_unflatten_from_structure(input_structure_, inputs));
std::vector<array> outputs;
std::vector<mx::array> outputs;
std::tie(outputs, *output_structure_) =
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
return outputs;
@ -694,10 +697,10 @@ class PyCustomFunction {
}
}
std::vector<array> operator()(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<array>& outputs) {
std::vector<mx::array> operator()(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<mx::array>& outputs) {
nb::gil_scoped_acquire gil;
auto new_inputs = nb::cast<nb::tuple>(
@ -734,9 +737,9 @@ class PyCustomFunction {
input_structure_.release().dec_ref();
}
std::vector<array> operator()(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> operator()(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) {
nb::gil_scoped_acquire gil;
@ -759,7 +762,7 @@ class PyCustomFunction {
int tangent_index = 0;
auto new_tangents =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
if (nb::isinstance<array>(element) &&
if (nb::isinstance<mx::array>(element) &&
have_tangents[array_index++]) {
return nb::cast(tangents[tangent_index++]);
} else {
@ -789,8 +792,8 @@ class PyCustomFunction {
input_structure_.release().dec_ref();
}
std::pair<std::vector<array>, std::vector<int>> operator()(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) {
nb::gil_scoped_acquire gil;
@ -807,7 +810,7 @@ class PyCustomFunction {
auto new_axes =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
int axis = axes[arr_index++];
if (nb::isinstance<array>(element) && axis >= 0) {
if (nb::isinstance<mx::array>(element) && axis >= 0) {
return nb::cast(axis);
} else {
return nb::none();
@ -831,11 +834,11 @@ class PyCustomFunction {
"[custom vmap] Vmap function should return a tuple with 2 items.");
}
std::vector<array> outputs;
std::vector<mx::array> outputs;
std::vector<int> output_axes;
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
if (nb::isinstance<array>(objects[0])) {
outputs.push_back(nb::cast<array>(objects[0]));
if (nb::isinstance<mx::array>(objects[0])) {
outputs.push_back(nb::cast<mx::array>(objects[0]));
output_axes.push_back(
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
}
@ -852,7 +855,7 @@ class PyCustomFunction {
}
// Extract the inputs and their structure in capturable vars
std::vector<array> input_arrays;
std::vector<mx::array> input_arrays;
nb::object input_structure;
auto full_args = nb::make_tuple(args, kwargs);
std::tie(input_arrays, input_structure) =
@ -864,7 +867,7 @@ class PyCustomFunction {
// Make a function that calls fun_ in the forward pass and vjp_ in the
// backward pass. Then call it immediately and return the results.
auto f = custom_function(
auto f = mx::custom_function(
InnerFunction(fun_, input_structure, output_structure),
make_vjp_function(input_structure, output_structure),
make_jvp_function(input_structure),
@ -1044,7 +1047,7 @@ void init_transforms(nb::module_& m) {
m.def(
"eval",
[](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false);
std::vector<mx::array> arrays = tree_flatten(args, false);
{
nb::gil_scoped_release nogil;
eval(arrays);
@ -1064,7 +1067,7 @@ void init_transforms(nb::module_& m) {
m.def(
"async_eval",
[](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false);
std::vector<mx::array> arrays = tree_flatten(args, false);
{
nb::gil_scoped_release nogil;
async_eval(arrays);
@ -1100,14 +1103,14 @@ void init_transforms(nb::module_& m) {
m.def(
"jvp",
[](const nb::callable& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents) {
auto vfun = [&fun](const std::vector<mx::array>& primals) {
auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) {
return std::vector<array>{nb::cast<array>(out)};
if (nb::isinstance<mx::array>(out)) {
return std::vector<mx::array>{nb::cast<mx::array>(out)};
} else {
return nb::cast<std::vector<array>>(out);
return nb::cast<std::vector<mx::array>>(out);
}
};
return jvp(vfun, primals, tangents);
@ -1139,14 +1142,14 @@ void init_transforms(nb::module_& m) {
m.def(
"vjp",
[](const nb::callable& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents) {
auto vfun = [&fun](const std::vector<mx::array>& primals) {
auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) {
return std::vector<array>{nb::cast<array>(out)};
if (nb::isinstance<mx::array>(out)) {
return std::vector<mx::array>{nb::cast<mx::array>(out)};
} else {
return nb::cast<std::vector<array>>(out);
return nb::cast<std::vector<mx::array>>(out);
}
};
return vjp(vfun, primals, cotangents);
@ -1312,7 +1315,7 @@ void init_transforms(nb::module_& m) {
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<array> arrays = tree_flatten(args);
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
@ -1399,14 +1402,14 @@ void init_transforms(nb::module_& m) {
)pbdoc");
m.def(
"disable_compile",
&disable_compile,
&mx::disable_compile,
R"pbdoc(
Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
m.def(
"enable_compile",
&enable_compile,
&mx::enable_compile,
R"pbdoc(
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
@ -1420,6 +1423,6 @@ void init_transforms(nb::module_& m) {
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() {
tree_cache().clear();
detail::compile_clear_cache();
mx::detail::compile_clear_cache();
}));
}

View File

@ -188,7 +188,7 @@ void tree_visit_update(
d[item.first] = recurse(item.second);
}
return nb::cast<nb::object>(d);
} else if (nb::isinstance<array>(subtree)) {
} else if (nb::isinstance<mx::array>(subtree)) {
return visitor(subtree);
} else {
return nb::cast<nb::object>(subtree);
@ -200,7 +200,7 @@ void tree_visit_update(
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(nb::object& tree, const std::vector<array>& values) {
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
@ -209,14 +209,14 @@ void tree_fill(nb::object& tree, const std::vector<array>& values) {
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
nb::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
const std::vector<mx::array>& src,
const std::vector<mx::array>& dst) {
std::unordered_map<uintptr_t, mx::array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](nb::handle node) {
auto arr = nb::cast<array>(node);
auto arr = nb::cast<mx::array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return nb::cast(it->second);
}
@ -224,12 +224,12 @@ void tree_replace(
});
}
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<array> flat_tree;
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<mx::array> flat_tree;
tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
if (nb::isinstance<mx::array>(obj)) {
flat_tree.push_back(nb::cast<mx::array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
@ -241,10 +241,10 @@ std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
nb::object tree_unflatten(
nb::object tree,
const std::vector<array>& values,
const std::vector<mx::array>& values,
int index /* = 0 */) {
return tree_map(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
if (nb::isinstance<mx::array>(obj)) {
return nb::cast(values[index++]);
} else {
return nb::cast<nb::object>(obj);
@ -265,16 +265,16 @@ nb::object structure_sentinel() {
return sentinel;
}
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
nb::object tree,
bool strict /* = true */) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
std::vector<mx::array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
if (nb::isinstance<mx::array>(obj)) {
flat_tree.push_back(nb::cast<mx::array>(obj));
return sentinel;
} else if (!strict) {
return nb::cast<nb::object>(obj);
@ -289,7 +289,7 @@ std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
nb::object tree_unflatten_from_structure(
nb::object structure,
const std::vector<array>& values,
const std::vector<mx::array>& values,
int index /* = 0 */) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](nb::handle obj) {

View File

@ -4,8 +4,8 @@
#include "mlx/array.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace mlx::core;
void tree_visit(
const std::vector<nb::object>& trees,
@ -27,7 +27,7 @@ void tree_visit_update(
/**
* Fill a pytree (recursive dict or list of dict or list) in place with the
* given arrays. */
void tree_fill(nb::object& tree, const std::vector<array>& values);
void tree_fill(nb::object& tree, const std::vector<mx::array>& values);
/**
* Replace all the arrays from the src values with the dst values in the
@ -35,28 +35,28 @@ void tree_fill(nb::object& tree, const std::vector<array>& values);
*/
void tree_replace(
nb::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst);
const std::vector<mx::array>& src,
const std::vector<mx::array>& dst);
/**
* Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array.
*/
std::vector<array> tree_flatten(nb::object tree, bool strict = true);
std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
/**
* Unflatten a tree from a vector of arrays.
*/
nb::object tree_unflatten(
nb::object tree,
const std::vector<array>& values,
const std::vector<mx::array>& values,
int index = 0);
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
nb::object tree,
bool strict = true);
nb::object tree_unflatten_from_structure(
nb::object structure,
const std::vector<array>& values,
const std::vector<mx::array>& values,
int index = 0);

View File

@ -4,22 +4,24 @@
#include "mlx/ops.h"
#include "python/src/convert.h"
array to_array(
mx::array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype /* = std::nullopt */) {
std::optional<mx::Dtype> dtype /* = std::nullopt */) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(int32);
auto out_t = dtype.value_or(mx::int32);
// bool_ is an exception and is always promoted
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
return mx::array(
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32);
return array(
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
auto out_t = dtype.value_or(mx::float32);
return mx::array(
nb::cast<float>(*pv),
mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else if (auto pv = std::get_if<array>(&v); pv) {
return mx::array(static_cast<mx::complex64_t>(*pv), mx::complex64);
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
return *pv;
} else if (auto pv = std::get_if<
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
@ -30,7 +32,7 @@ array to_array(
}
}
std::pair<array, array> to_arrays(
std::pair<mx::array, mx::array> to_arrays(
const ScalarOrArray& a,
const ScalarOrArray& b) {
// Four cases:
@ -39,15 +41,15 @@ std::pair<array, array> to_arrays(
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
auto is_mlx_array = [](const ScalarOrArray& x) {
return std::holds_alternative<array>(x) ||
return std::holds_alternative<mx::array>(x) ||
std::holds_alternative<nb::object>(x) &&
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
};
auto get_mlx_array = [](const ScalarOrArray& x) {
if (auto px = std::get_if<array>(&x); px) {
if (auto px = std::get_if<mx::array>(&x); px) {
return *px;
} else {
return nb::cast<array>(std::get<nb::object>(x).attr("__mlx_array__"));
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
}
};
@ -66,11 +68,11 @@ std::pair<array, array> to_arrays(
}
}
array to_array_with_accessor(nb::object obj) {
if (nb::isinstance<array>(obj)) {
return nb::cast<array>(obj);
mx::array to_array_with_accessor(nb::object obj) {
if (nb::isinstance<mx::array>(obj)) {
return nb::cast<mx::array>(obj);
} else if (nb::hasattr(obj, "__mlx_array__")) {
return nb::cast<array>(obj.attr("__mlx_array__")());
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()

View File

@ -12,17 +12,16 @@
#include "mlx/array.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std::variant<
nb::bool_,
nb::int_,
nb::float_,
// Must be above ndarray
array,
mx::array,
// Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>,
@ -45,7 +44,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
// Checks if the value can be compared to an array (or is already an
// mlx array)
if (auto pv = std::get_if<nb::object>(&v); pv) {
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
} else {
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
// and can be compared to an array
@ -66,12 +65,12 @@ inline void throw_invalid_operation(
throw std::invalid_argument(msg.str());
}
array to_array(
mx::array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt);
std::optional<mx::Dtype> dtype = std::nullopt);
std::pair<array, array> to_arrays(
std::pair<mx::array, mx::array> to_arrays(
const ScalarOrArray& a,
const ScalarOrArray& b);
array to_array_with_accessor(nb::object obj);
mx::array to_array_with_accessor(nb::object obj);