mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 03:06:39 +08:00
Remove "using namespace mlx::core" in python/src (#1689)
This commit is contained in:
parent
f3dfa36a3a
commit
0bf19037ca
File diff suppressed because it is too large
Load Diff
@ -14,37 +14,37 @@
|
||||
#define Py_bf_releasebuffer 2
|
||||
#endif
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
std::string buffer_format(const array& a) {
|
||||
std::string buffer_format(const mx::array& a) {
|
||||
// https://docs.python.org/3.10/library/struct.html#format-characters
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
case mx::bool_:
|
||||
return "?";
|
||||
case uint8:
|
||||
case mx::uint8:
|
||||
return "B";
|
||||
case uint16:
|
||||
case mx::uint16:
|
||||
return "H";
|
||||
case uint32:
|
||||
case mx::uint32:
|
||||
return "I";
|
||||
case uint64:
|
||||
case mx::uint64:
|
||||
return "Q";
|
||||
case int8:
|
||||
case mx::int8:
|
||||
return "b";
|
||||
case int16:
|
||||
case mx::int16:
|
||||
return "h";
|
||||
case int32:
|
||||
case mx::int32:
|
||||
return "i";
|
||||
case int64:
|
||||
case mx::int64:
|
||||
return "q";
|
||||
case float16:
|
||||
case mx::float16:
|
||||
return "e";
|
||||
case float32:
|
||||
case mx::float32:
|
||||
return "f";
|
||||
case bfloat16:
|
||||
case mx::bfloat16:
|
||||
return "B";
|
||||
case complex64:
|
||||
case mx::complex64:
|
||||
return "Zf\0";
|
||||
default: {
|
||||
std::ostringstream os;
|
||||
@ -84,7 +84,7 @@ struct buffer_info {
|
||||
|
||||
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
|
||||
std::memset(view, 0, sizeof(Py_buffer));
|
||||
auto a = nb::cast<array>(nb::handle(obj));
|
||||
auto a = nb::cast<mx::array>(nb::handle(obj));
|
||||
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
|
@ -16,7 +16,7 @@ enum PyScalarT {
|
||||
|
||||
namespace nanobind {
|
||||
template <>
|
||||
struct ndarray_traits<float16_t> {
|
||||
struct ndarray_traits<mx::float16_t> {
|
||||
static constexpr bool is_complex = false;
|
||||
static constexpr bool is_float = true;
|
||||
static constexpr bool is_bool = false;
|
||||
@ -36,21 +36,21 @@ int check_shape_dim(int64_t dim) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
array nd_array_to_mlx_contiguous(
|
||||
mx::array nd_array_to_mlx_contiguous(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
const Shape& shape,
|
||||
Dtype dtype) {
|
||||
const mx::Shape& shape,
|
||||
mx::Dtype dtype) {
|
||||
// Make a copy of the numpy buffer
|
||||
// Get buffer ptr pass to array constructor
|
||||
auto data_ptr = nd_array.data();
|
||||
return array(static_cast<const T*>(data_ptr), shape, dtype);
|
||||
return mx::array(static_cast<const T*>(data_ptr), shape, dtype);
|
||||
}
|
||||
|
||||
array nd_array_to_mlx(
|
||||
mx::array nd_array_to_mlx(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
std::optional<Dtype> dtype) {
|
||||
std::optional<mx::Dtype> dtype) {
|
||||
// Compute the shape and size
|
||||
Shape shape;
|
||||
mx::Shape shape;
|
||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||
}
|
||||
@ -59,49 +59,49 @@ array nd_array_to_mlx(
|
||||
// Copy data and make array
|
||||
if (type == nb::dtype<bool>()) {
|
||||
return nd_array_to_mlx_contiguous<bool>(
|
||||
nd_array, shape, dtype.value_or(bool_));
|
||||
nd_array, shape, dtype.value_or(mx::bool_));
|
||||
} else if (type == nb::dtype<uint8_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint8_t>(
|
||||
nd_array, shape, dtype.value_or(uint8));
|
||||
nd_array, shape, dtype.value_or(mx::uint8));
|
||||
} else if (type == nb::dtype<uint16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint16_t>(
|
||||
nd_array, shape, dtype.value_or(uint16));
|
||||
nd_array, shape, dtype.value_or(mx::uint16));
|
||||
} else if (type == nb::dtype<uint32_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint32_t>(
|
||||
nd_array, shape, dtype.value_or(uint32));
|
||||
nd_array, shape, dtype.value_or(mx::uint32));
|
||||
} else if (type == nb::dtype<uint64_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint64_t>(
|
||||
nd_array, shape, dtype.value_or(uint64));
|
||||
nd_array, shape, dtype.value_or(mx::uint64));
|
||||
} else if (type == nb::dtype<int8_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int8_t>(
|
||||
nd_array, shape, dtype.value_or(int8));
|
||||
nd_array, shape, dtype.value_or(mx::int8));
|
||||
} else if (type == nb::dtype<int16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int16_t>(
|
||||
nd_array, shape, dtype.value_or(int16));
|
||||
nd_array, shape, dtype.value_or(mx::int16));
|
||||
} else if (type == nb::dtype<int32_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int32_t>(
|
||||
nd_array, shape, dtype.value_or(int32));
|
||||
nd_array, shape, dtype.value_or(mx::int32));
|
||||
} else if (type == nb::dtype<int64_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int64_t>(
|
||||
nd_array, shape, dtype.value_or(int64));
|
||||
} else if (type == nb::dtype<float16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<float16_t>(
|
||||
nd_array, shape, dtype.value_or(float16));
|
||||
nd_array, shape, dtype.value_or(mx::int64));
|
||||
} else if (type == nb::dtype<mx::float16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<mx::float16_t>(
|
||||
nd_array, shape, dtype.value_or(mx::float16));
|
||||
} else if (type == nb::bfloat16) {
|
||||
return nd_array_to_mlx_contiguous<bfloat16_t>(
|
||||
nd_array, shape, dtype.value_or(bfloat16));
|
||||
return nd_array_to_mlx_contiguous<mx::bfloat16_t>(
|
||||
nd_array, shape, dtype.value_or(mx::bfloat16));
|
||||
} else if (type == nb::dtype<float>()) {
|
||||
return nd_array_to_mlx_contiguous<float>(
|
||||
nd_array, shape, dtype.value_or(float32));
|
||||
nd_array, shape, dtype.value_or(mx::float32));
|
||||
} else if (type == nb::dtype<double>()) {
|
||||
return nd_array_to_mlx_contiguous<double>(
|
||||
nd_array, shape, dtype.value_or(float32));
|
||||
nd_array, shape, dtype.value_or(mx::float32));
|
||||
} else if (type == nb::dtype<std::complex<float>>()) {
|
||||
return nd_array_to_mlx_contiguous<complex64_t>(
|
||||
nd_array, shape, dtype.value_or(complex64));
|
||||
return nd_array_to_mlx_contiguous<mx::complex64_t>(
|
||||
nd_array, shape, dtype.value_or(mx::complex64));
|
||||
} else if (type == nb::dtype<std::complex<double>>()) {
|
||||
return nd_array_to_mlx_contiguous<complex128_t>(
|
||||
nd_array, shape, dtype.value_or(complex64));
|
||||
return nd_array_to_mlx_contiguous<mx::complex128_t>(
|
||||
nd_array, shape, dtype.value_or(mx::complex64));
|
||||
} else {
|
||||
throw std::invalid_argument("Cannot convert numpy array to mlx array.");
|
||||
}
|
||||
@ -109,7 +109,7 @@ array nd_array_to_mlx(
|
||||
|
||||
template <typename T, typename... NDParams>
|
||||
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
||||
array a,
|
||||
mx::array a,
|
||||
std::optional<nb::dlpack::dtype> t = {}) {
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
@ -126,48 +126,48 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
||||
}
|
||||
|
||||
template <typename... NDParams>
|
||||
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
||||
nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
case mx::bool_:
|
||||
return mlx_to_nd_array_impl<bool, NDParams...>(a);
|
||||
case uint8:
|
||||
case mx::uint8:
|
||||
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
|
||||
case uint16:
|
||||
case mx::uint16:
|
||||
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
|
||||
case uint32:
|
||||
case mx::uint32:
|
||||
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
|
||||
case uint64:
|
||||
case mx::uint64:
|
||||
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
|
||||
case int8:
|
||||
case mx::int8:
|
||||
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
|
||||
case int16:
|
||||
case mx::int16:
|
||||
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
|
||||
case int32:
|
||||
case mx::int32:
|
||||
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
|
||||
case int64:
|
||||
case mx::int64:
|
||||
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
|
||||
case float16:
|
||||
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
||||
case bfloat16:
|
||||
case mx::float16:
|
||||
return mlx_to_nd_array_impl<mx::float16_t, NDParams...>(a);
|
||||
case mx::bfloat16:
|
||||
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
||||
case float32:
|
||||
case mx::float32:
|
||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||
case complex64:
|
||||
case mx::complex64:
|
||||
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||
default:
|
||||
throw nb::type_error("type cannot be converted to NumPy.");
|
||||
}
|
||||
}
|
||||
|
||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
|
||||
nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a) {
|
||||
return mlx_to_nd_array<nb::numpy>(a);
|
||||
}
|
||||
|
||||
nb::ndarray<> mlx_to_dlpack(const array& a) {
|
||||
nb::ndarray<> mlx_to_dlpack(const mx::array& a) {
|
||||
return mlx_to_nd_array<>(a);
|
||||
}
|
||||
|
||||
nb::object to_scalar(array& a) {
|
||||
nb::object to_scalar(mx::array& a) {
|
||||
if (a.size() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"[convert] Only length-1 arrays can be converted to Python scalars.");
|
||||
@ -177,31 +177,31 @@ nb::object to_scalar(array& a) {
|
||||
a.eval();
|
||||
}
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
case mx::bool_:
|
||||
return nb::cast(a.item<bool>());
|
||||
case uint8:
|
||||
case mx::uint8:
|
||||
return nb::cast(a.item<uint8_t>());
|
||||
case uint16:
|
||||
case mx::uint16:
|
||||
return nb::cast(a.item<uint16_t>());
|
||||
case uint32:
|
||||
case mx::uint32:
|
||||
return nb::cast(a.item<uint32_t>());
|
||||
case uint64:
|
||||
case mx::uint64:
|
||||
return nb::cast(a.item<uint64_t>());
|
||||
case int8:
|
||||
case mx::int8:
|
||||
return nb::cast(a.item<int8_t>());
|
||||
case int16:
|
||||
case mx::int16:
|
||||
return nb::cast(a.item<int16_t>());
|
||||
case int32:
|
||||
case mx::int32:
|
||||
return nb::cast(a.item<int32_t>());
|
||||
case int64:
|
||||
case mx::int64:
|
||||
return nb::cast(a.item<int64_t>());
|
||||
case float16:
|
||||
return nb::cast(static_cast<float>(a.item<float16_t>()));
|
||||
case float32:
|
||||
case mx::float16:
|
||||
return nb::cast(static_cast<float>(a.item<mx::float16_t>()));
|
||||
case mx::float32:
|
||||
return nb::cast(a.item<float>());
|
||||
case bfloat16:
|
||||
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||
case complex64:
|
||||
case mx::bfloat16:
|
||||
return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
|
||||
case mx::complex64:
|
||||
return nb::cast(a.item<std::complex<float>>());
|
||||
default:
|
||||
throw nb::type_error("type cannot be converted to Python scalar.");
|
||||
@ -209,7 +209,7 @@ nb::object to_scalar(array& a) {
|
||||
}
|
||||
|
||||
template <typename T, typename U = T>
|
||||
nb::list to_list(array& a, size_t index, int dim) {
|
||||
nb::list to_list(mx::array& a, size_t index, int dim) {
|
||||
nb::list pl;
|
||||
auto stride = a.strides()[dim];
|
||||
for (int i = 0; i < a.shape(dim); ++i) {
|
||||
@ -223,7 +223,7 @@ nb::list to_list(array& a, size_t index, int dim) {
|
||||
return pl;
|
||||
}
|
||||
|
||||
nb::object tolist(array& a) {
|
||||
nb::object tolist(mx::array& a) {
|
||||
if (a.ndim() == 0) {
|
||||
return to_scalar(a);
|
||||
}
|
||||
@ -232,31 +232,31 @@ nb::object tolist(array& a) {
|
||||
a.eval();
|
||||
}
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
case mx::bool_:
|
||||
return to_list<bool>(a, 0, 0);
|
||||
case uint8:
|
||||
case mx::uint8:
|
||||
return to_list<uint8_t>(a, 0, 0);
|
||||
case uint16:
|
||||
case mx::uint16:
|
||||
return to_list<uint16_t>(a, 0, 0);
|
||||
case uint32:
|
||||
case mx::uint32:
|
||||
return to_list<uint32_t>(a, 0, 0);
|
||||
case uint64:
|
||||
case mx::uint64:
|
||||
return to_list<uint64_t>(a, 0, 0);
|
||||
case int8:
|
||||
case mx::int8:
|
||||
return to_list<int8_t>(a, 0, 0);
|
||||
case int16:
|
||||
case mx::int16:
|
||||
return to_list<int16_t>(a, 0, 0);
|
||||
case int32:
|
||||
case mx::int32:
|
||||
return to_list<int32_t>(a, 0, 0);
|
||||
case int64:
|
||||
case mx::int64:
|
||||
return to_list<int64_t>(a, 0, 0);
|
||||
case float16:
|
||||
return to_list<float16_t, float>(a, 0, 0);
|
||||
case float32:
|
||||
case mx::float16:
|
||||
return to_list<mx::float16_t, float>(a, 0, 0);
|
||||
case mx::float32:
|
||||
return to_list<float>(a, 0, 0);
|
||||
case bfloat16:
|
||||
return to_list<bfloat16_t, float>(a, 0, 0);
|
||||
case complex64:
|
||||
case mx::bfloat16:
|
||||
return to_list<mx::bfloat16_t, float>(a, 0, 0);
|
||||
case mx::complex64:
|
||||
return to_list<std::complex<float>>(a, 0, 0);
|
||||
default:
|
||||
throw nb::type_error("data type cannot be converted to Python list.");
|
||||
@ -279,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
|
||||
template <typename T>
|
||||
PyScalarT validate_shape(
|
||||
T list,
|
||||
const Shape& shape,
|
||||
const mx::Shape& shape,
|
||||
int idx,
|
||||
bool& all_python_primitive_elements) {
|
||||
if (idx >= shape.size()) {
|
||||
@ -307,9 +307,9 @@ PyScalarT validate_shape(
|
||||
shape,
|
||||
idx + 1,
|
||||
all_python_primitive_elements);
|
||||
} else if (nb::isinstance<array>(l)) {
|
||||
} else if (nb::isinstance<mx::array>(l)) {
|
||||
all_python_primitive_elements = false;
|
||||
auto arr = nb::cast<array>(l);
|
||||
auto arr = nb::cast<mx::array>(l);
|
||||
if (arr.ndim() + idx + 1 == shape.size() &&
|
||||
std::equal(
|
||||
arr.shape().cbegin(),
|
||||
@ -347,7 +347,7 @@ PyScalarT validate_shape(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void get_shape(T list, Shape& shape) {
|
||||
void get_shape(T list, mx::Shape& shape) {
|
||||
shape.push_back(check_shape_dim(nb::len(list)));
|
||||
if (shape.back() > 0) {
|
||||
auto l = list.begin();
|
||||
@ -355,8 +355,8 @@ void get_shape(T list, Shape& shape) {
|
||||
return get_shape(nb::cast<nb::list>(*l), shape);
|
||||
} else if (nb::isinstance<nb::tuple>(*l)) {
|
||||
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
||||
} else if (nb::isinstance<array>(*l)) {
|
||||
auto arr = nb::cast<array>(*l);
|
||||
} else if (nb::isinstance<mx::array>(*l)) {
|
||||
auto arr = nb::cast<mx::array>(*l);
|
||||
for (int i = 0; i < arr.ndim(); i++) {
|
||||
shape.push_back(arr.shape(i));
|
||||
}
|
||||
@ -366,54 +366,55 @@ void get_shape(T list, Shape& shape) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
array array_from_list_impl(
|
||||
mx::array array_from_list_impl(
|
||||
T pl,
|
||||
const PyScalarT& inferred_type,
|
||||
std::optional<Dtype> specified_type,
|
||||
const Shape& shape) {
|
||||
std::optional<mx::Dtype> specified_type,
|
||||
const mx::Shape& shape) {
|
||||
// Make the array
|
||||
switch (inferred_type) {
|
||||
case pybool: {
|
||||
std::vector<bool> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
||||
return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_));
|
||||
}
|
||||
case pyint: {
|
||||
auto dtype = specified_type.value_or(int32);
|
||||
if (dtype == int64) {
|
||||
auto dtype = specified_type.value_or(mx::int32);
|
||||
if (dtype == mx::int64) {
|
||||
std::vector<int64_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (dtype == uint64) {
|
||||
return mx::array(vals.begin(), shape, dtype);
|
||||
} else if (dtype == mx::uint64) {
|
||||
std::vector<uint64_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (dtype == uint32) {
|
||||
return mx::array(vals.begin(), shape, dtype);
|
||||
} else if (dtype == mx::uint32) {
|
||||
std::vector<uint32_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (issubdtype(dtype, inexact)) {
|
||||
return mx::array(vals.begin(), shape, dtype);
|
||||
} else if (mx::issubdtype(dtype, mx::inexact)) {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
return mx::array(vals.begin(), shape, dtype);
|
||||
} else {
|
||||
std::vector<int> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
return mx::array(vals.begin(), shape, dtype);
|
||||
}
|
||||
}
|
||||
case pyfloat: {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, specified_type.value_or(float32));
|
||||
return mx::array(
|
||||
vals.begin(), shape, specified_type.value_or(mx::float32));
|
||||
}
|
||||
case pycomplex: {
|
||||
std::vector<std::complex<float>> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(
|
||||
reinterpret_cast<complex64_t*>(vals.data()),
|
||||
return mx::array(
|
||||
reinterpret_cast<mx::complex64_t*>(vals.data()),
|
||||
shape,
|
||||
specified_type.value_or(complex64));
|
||||
specified_type.value_or(mx::complex64));
|
||||
}
|
||||
default: {
|
||||
std::ostringstream msg;
|
||||
@ -425,9 +426,9 @@ array array_from_list_impl(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
|
||||
mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {
|
||||
// Compute the shape
|
||||
Shape shape;
|
||||
mx::Shape shape;
|
||||
get_shape(pl, shape);
|
||||
|
||||
// Validate the shape and type
|
||||
@ -440,30 +441,31 @@ array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
|
||||
}
|
||||
|
||||
// `pl` contains mlx arrays
|
||||
std::vector<array> arrays;
|
||||
std::vector<mx::array> arrays;
|
||||
for (auto l : pl) {
|
||||
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
|
||||
}
|
||||
return stack(arrays);
|
||||
return mx::stack(arrays);
|
||||
}
|
||||
|
||||
array array_from_list(nb::list pl, std::optional<Dtype> dtype) {
|
||||
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype) {
|
||||
return array_from_list_impl(pl, dtype);
|
||||
}
|
||||
|
||||
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype) {
|
||||
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {
|
||||
return array_from_list_impl(pl, dtype);
|
||||
}
|
||||
|
||||
array create_array(ArrayInitType v, std::optional<Dtype> t) {
|
||||
mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||
return array(nb::cast<bool>(*pv), t.value_or(bool_));
|
||||
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
|
||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||
return array(nb::cast<int>(*pv), t.value_or(int32));
|
||||
return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
return array(nb::cast<float>(*pv), t.value_or(float32));
|
||||
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||
return mx::array(
|
||||
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
|
||||
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
|
||||
@ -472,10 +474,10 @@ array create_array(ArrayInitType v, std::optional<Dtype> t) {
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
||||
pv) {
|
||||
return nd_array_to_mlx(*pv, t);
|
||||
} else if (auto pv = std::get_if<array>(&v); pv) {
|
||||
return astype(*pv, t.value_or((*pv).dtype()));
|
||||
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
|
||||
return mx::astype(*pv, t.value_or((*pv).dtype()));
|
||||
} else {
|
||||
auto arr = to_array_with_accessor(std::get<nb::object>(v));
|
||||
return astype(arr, t.value_or(arr.dtype()));
|
||||
return mx::astype(arr, t.value_or(arr.dtype()));
|
||||
}
|
||||
}
|
||||
|
@ -9,15 +9,15 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
using ArrayInitType = std::variant<
|
||||
nb::bool_,
|
||||
nb::int_,
|
||||
nb::float_,
|
||||
// Must be above ndarray
|
||||
array,
|
||||
mx::array,
|
||||
// Must be above complex
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||
std::complex<float>,
|
||||
@ -25,17 +25,17 @@ using ArrayInitType = std::variant<
|
||||
nb::tuple,
|
||||
nb::object>;
|
||||
|
||||
array nd_array_to_mlx(
|
||||
mx::array nd_array_to_mlx(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
std::optional<Dtype> dtype);
|
||||
std::optional<mx::Dtype> dtype);
|
||||
|
||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
|
||||
nb::ndarray<> mlx_to_dlpack(const array& a);
|
||||
nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);
|
||||
nb::ndarray<> mlx_to_dlpack(const mx::array& a);
|
||||
|
||||
nb::object to_scalar(array& a);
|
||||
nb::object to_scalar(mx::array& a);
|
||||
|
||||
nb::object tolist(array& a);
|
||||
nb::object tolist(mx::array& a);
|
||||
|
||||
array create_array(ArrayInitType v, std::optional<Dtype> t);
|
||||
array array_from_list(nb::list pl, std::optional<Dtype> dtype);
|
||||
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype);
|
||||
mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
|
||||
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
|
||||
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);
|
||||
|
@ -8,51 +8,54 @@
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_device(nb::module_& m) {
|
||||
auto device_class = nb::class_<Device>(
|
||||
auto device_class = nb::class_<mx::Device>(
|
||||
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
|
||||
nb::enum_<Device::DeviceType>(m, "DeviceType")
|
||||
.value("cpu", Device::DeviceType::cpu)
|
||||
.value("gpu", Device::DeviceType::gpu)
|
||||
nb::enum_<mx::Device::DeviceType>(m, "DeviceType")
|
||||
.value("cpu", mx::Device::DeviceType::cpu)
|
||||
.value("gpu", mx::Device::DeviceType::gpu)
|
||||
.export_values()
|
||||
.def("__eq__", [](const Device::DeviceType& d, const nb::object& other) {
|
||||
if (!nb::isinstance<Device>(other) &&
|
||||
!nb::isinstance<Device::DeviceType>(other)) {
|
||||
return false;
|
||||
}
|
||||
return d == nb::cast<Device>(other);
|
||||
});
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const mx::Device::DeviceType& d, const nb::object& other) {
|
||||
if (!nb::isinstance<mx::Device>(other) &&
|
||||
!nb::isinstance<mx::Device::DeviceType>(other)) {
|
||||
return false;
|
||||
}
|
||||
return d == nb::cast<mx::Device>(other);
|
||||
});
|
||||
|
||||
device_class.def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||
.def_ro("type", &Device::type)
|
||||
device_class
|
||||
.def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||
.def_ro("type", &mx::Device::type)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const Device& d) {
|
||||
[](const mx::Device& d) {
|
||||
std::ostringstream os;
|
||||
os << d;
|
||||
return os.str();
|
||||
})
|
||||
.def("__eq__", [](const Device& d, const nb::object& other) {
|
||||
if (!nb::isinstance<Device>(other) &&
|
||||
!nb::isinstance<Device::DeviceType>(other)) {
|
||||
.def("__eq__", [](const mx::Device& d, const nb::object& other) {
|
||||
if (!nb::isinstance<mx::Device>(other) &&
|
||||
!nb::isinstance<mx::Device::DeviceType>(other)) {
|
||||
return false;
|
||||
}
|
||||
return d == nb::cast<Device>(other);
|
||||
return d == nb::cast<mx::Device>(other);
|
||||
});
|
||||
|
||||
nb::implicitly_convertible<Device::DeviceType, Device>();
|
||||
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
|
||||
|
||||
m.def(
|
||||
"default_device",
|
||||
&default_device,
|
||||
&mx::default_device,
|
||||
R"pbdoc(Get the default device.)pbdoc");
|
||||
m.def(
|
||||
"set_default_device",
|
||||
&set_default_device,
|
||||
&mx::set_default_device,
|
||||
"device"_a,
|
||||
R"pbdoc(Set the default device.)pbdoc");
|
||||
}
|
||||
|
@ -9,26 +9,27 @@
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_distributed(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"distributed", "mlx.core.distributed: Communication operations");
|
||||
|
||||
nb::class_<distributed::Group>(
|
||||
nb::class_<mx::distributed::Group>(
|
||||
m,
|
||||
"Group",
|
||||
R"pbcopy(
|
||||
An :class:`mlx.core.distributed.Group` represents a group of independent mlx
|
||||
processes that can communicate.
|
||||
)pbcopy")
|
||||
.def("rank", &distributed::Group::rank, "Get the rank of this process")
|
||||
.def("size", &distributed::Group::size, "Get the size of the group")
|
||||
.def(
|
||||
"rank", &mx::distributed::Group::rank, "Get the rank of this process")
|
||||
.def("size", &mx::distributed::Group::size, "Get the size of the group")
|
||||
.def(
|
||||
"split",
|
||||
&distributed::Group::split,
|
||||
&mx::distributed::Group::split,
|
||||
"color"_a,
|
||||
"key"_a = -1,
|
||||
nb::sig("def split(self, color: int, key: int = -1) -> Group"),
|
||||
@ -48,14 +49,14 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"is_available",
|
||||
&distributed::is_available,
|
||||
&mx::distributed::is_available,
|
||||
R"pbdoc(
|
||||
Check if a communication backend is available.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"init",
|
||||
&distributed::init,
|
||||
&mx::distributed::init,
|
||||
"strict"_a = false,
|
||||
nb::sig("def init(strict: bool = False) -> Group"),
|
||||
R"pbdoc(
|
||||
@ -72,7 +73,7 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_sum",
|
||||
&distributed::all_sum,
|
||||
&mx::distributed::all_sum,
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
@ -98,7 +99,7 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_gather",
|
||||
&distributed::all_gather,
|
||||
&mx::distributed::all_gather,
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
@ -125,7 +126,7 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"send",
|
||||
&distributed::send,
|
||||
&mx::distributed::send,
|
||||
"x"_a,
|
||||
"dst"_a,
|
||||
nb::kw_only(),
|
||||
@ -152,7 +153,7 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"recv",
|
||||
&distributed::recv,
|
||||
&mx::distributed::recv,
|
||||
"shape"_a,
|
||||
"dtype"_a,
|
||||
"src"_a,
|
||||
@ -181,7 +182,7 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"recv_like",
|
||||
&distributed::recv_like,
|
||||
&mx::distributed::recv_like,
|
||||
"x"_a,
|
||||
"src"_a,
|
||||
nb::kw_only(),
|
||||
|
@ -13,9 +13,9 @@
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_fast(nb::module_& parent_module) {
|
||||
auto m =
|
||||
@ -23,7 +23,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"rms_norm",
|
||||
&fast::rms_norm,
|
||||
&mx::fast::rms_norm,
|
||||
"x"_a,
|
||||
"weight"_a,
|
||||
"eps"_a,
|
||||
@ -49,7 +49,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"layer_norm",
|
||||
&fast::layer_norm,
|
||||
&mx::fast::layer_norm,
|
||||
"x"_a,
|
||||
"weight"_a.none(),
|
||||
"bias"_a.none(),
|
||||
@ -79,7 +79,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"rope",
|
||||
&fast::rope,
|
||||
&mx::fast::rope,
|
||||
"a"_a,
|
||||
"dims"_a,
|
||||
nb::kw_only(),
|
||||
@ -114,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"scaled_dot_product_attention",
|
||||
&fast::scaled_dot_product_attention,
|
||||
&mx::fast::scaled_dot_product_attention,
|
||||
"q"_a,
|
||||
"k"_a,
|
||||
"v"_a,
|
||||
@ -170,7 +170,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
const std::string& header,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
auto kernel = fast::metal_kernel(
|
||||
auto kernel = mx::fast::metal_kernel(
|
||||
name,
|
||||
input_names,
|
||||
output_names,
|
||||
@ -182,7 +182,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
[kernel = std::move(kernel)](
|
||||
const std::vector<ScalarOrArray>& inputs_,
|
||||
const std::vector<std::vector<int>>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<mx::Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::optional<
|
||||
@ -190,12 +190,12 @@ void init_fast(nb::module_& parent_module) {
|
||||
template_args_ = std::nullopt,
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s = {}) {
|
||||
std::vector<array> inputs;
|
||||
mx::StreamOrDevice s = {}) {
|
||||
std::vector<mx::array> inputs;
|
||||
for (const auto& value : inputs_) {
|
||||
inputs.push_back(to_array(value, std::nullopt));
|
||||
}
|
||||
std::vector<std::pair<std::string, fast::TemplateArg>>
|
||||
std::vector<std::pair<std::string, mx::fast::TemplateArg>>
|
||||
template_args;
|
||||
if (template_args_) {
|
||||
for (const auto& [name, value] : template_args_.value()) {
|
||||
@ -206,8 +206,8 @@ void init_fast(nb::module_& parent_module) {
|
||||
} else if (nb::isinstance<int>(value)) {
|
||||
int int_val = nb::cast<int>(value);
|
||||
template_args.emplace_back(name, int_val);
|
||||
} else if (nb::isinstance<Dtype>(value)) {
|
||||
Dtype dtype = nb::cast<Dtype>(value);
|
||||
} else if (nb::isinstance<mx::Dtype>(value)) {
|
||||
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
|
||||
template_args.emplace_back(name, dtype);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
|
@ -9,24 +9,23 @@
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_fft(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
||||
m.def(
|
||||
"fft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::fft(a, n.value(), axis, s);
|
||||
return mx::fft::fft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::fft(a, axis, s);
|
||||
return mx::fft::fft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -49,14 +48,14 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::ifft(a, n.value(), axis, s);
|
||||
return mx::fft::ifft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::ifft(a, axis, s);
|
||||
return mx::fft::ifft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -79,19 +78,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
return mx::fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
return mx::fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -115,19 +114,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
return mx::fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
return mx::fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -151,19 +150,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"fftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::fftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::fftn(a, axes.value(), s);
|
||||
return mx::fft::fftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::fftn(a, s);
|
||||
return mx::fft::fftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -188,19 +187,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ifftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::ifftn(a, axes.value(), s);
|
||||
return mx::fft::ifftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::ifftn(a, s);
|
||||
return mx::fft::ifftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -225,14 +224,14 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::rfft(a, n.value(), axis, s);
|
||||
return mx::fft::rfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::rfft(a, axis, s);
|
||||
return mx::fft::rfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -260,14 +259,14 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<int>& n,
|
||||
int axis,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (n.has_value()) {
|
||||
return fft::irfft(a, n.value(), axis, s);
|
||||
return mx::fft::irfft(a, n.value(), axis, s);
|
||||
} else {
|
||||
return fft::irfft(a, axis, s);
|
||||
return mx::fft::irfft(a, axis, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -294,19 +293,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
return mx::fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
return mx::fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -336,19 +335,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfft2",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
return mx::fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
return mx::fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -378,19 +377,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"rfftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::rfftn(a, axes.value(), s);
|
||||
return mx::fft::rfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::rfftn(a, s);
|
||||
return mx::fft::rfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
@ -420,19 +419,19 @@ void init_fft(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"irfftn",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
||||
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
||||
} else if (axes.has_value()) {
|
||||
return fft::irfftn(a, axes.value(), s);
|
||||
return mx::fft::irfftn(a, axes.value(), s);
|
||||
} else if (n.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||
} else {
|
||||
return fft::irfftn(a, s);
|
||||
return mx::fft::irfftn(a, s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
|
@ -43,20 +43,20 @@ void get_slice_params(
|
||||
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
}
|
||||
|
||||
array get_int_index(nb::object idx, int axis_size) {
|
||||
mx::array get_int_index(nb::object idx, int axis_size) {
|
||||
int idx_ = nb::cast<int>(idx);
|
||||
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
|
||||
|
||||
return array(idx_, uint32);
|
||||
return mx::array(idx_, mx::uint32);
|
||||
}
|
||||
|
||||
bool is_valid_index_type(const nb::object& obj) {
|
||||
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
||||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) ||
|
||||
nb::isinstance<nb::list>(obj);
|
||||
nb::isinstance<mx::array>(obj) || obj.is_none() ||
|
||||
nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
|
||||
}
|
||||
|
||||
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
||||
mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@ -77,14 +77,14 @@ array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
||||
return slice(src, starts, ends, strides);
|
||||
}
|
||||
|
||||
array mlx_get_item_array(const array& src, const array& indices) {
|
||||
mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
if (indices.dtype() == bool_) {
|
||||
if (indices.dtype() == mx::bool_) {
|
||||
throw std::invalid_argument("boolean indices are not yet supported");
|
||||
}
|
||||
|
||||
@ -93,7 +93,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
|
||||
return take(src, indices, 0);
|
||||
}
|
||||
|
||||
array mlx_get_item_int(const array& src, const nb::int_& idx) {
|
||||
mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@ -105,13 +105,13 @@ array mlx_get_item_int(const array& src, const nb::int_& idx) {
|
||||
return take(src, get_int_index(idx, src.shape(0)), 0);
|
||||
}
|
||||
|
||||
array mlx_gather_nd(
|
||||
array src,
|
||||
mx::array mlx_gather_nd(
|
||||
mx::array src,
|
||||
const std::vector<nb::object>& indices,
|
||||
bool gather_first,
|
||||
int& max_dims) {
|
||||
max_dims = 0;
|
||||
std::vector<array> gather_indices;
|
||||
std::vector<mx::array> gather_indices;
|
||||
std::vector<bool> is_slice(indices.size(), false);
|
||||
int num_slices = 0;
|
||||
// gather all the arrays
|
||||
@ -127,13 +127,13 @@ array mlx_gather_nd(
|
||||
start = (start < 0) ? start + src.shape(i) : start;
|
||||
end = (end < 0) ? end + src.shape(i) : end;
|
||||
|
||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
||||
gather_indices.push_back(arange(start, end, stride, mx::uint32));
|
||||
num_slices++;
|
||||
is_slice[i] = true;
|
||||
} else if (nb::isinstance<nb::int_>(idx)) {
|
||||
gather_indices.push_back(get_int_index(idx, src.shape(i)));
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
auto arr = nb::cast<array>(idx);
|
||||
} else if (nb::isinstance<mx::array>(idx)) {
|
||||
auto arr = nb::cast<mx::array>(idx);
|
||||
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
|
||||
gather_indices.push_back(arr);
|
||||
}
|
||||
@ -144,7 +144,7 @@ array mlx_gather_nd(
|
||||
int slice_index = 0;
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (is_slice[i]) {
|
||||
Shape index_shape(max_dims + num_slices, 1);
|
||||
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||
slice_index++;
|
||||
@ -158,7 +158,7 @@ array mlx_gather_nd(
|
||||
// reshape them so that the int/array indices are last
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
if (i < num_slices) {
|
||||
Shape index_shape(max_dims + num_slices, 1);
|
||||
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||
index_shape[i] = gather_indices[i].shape(0);
|
||||
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||
}
|
||||
@ -241,7 +241,7 @@ auto mlx_expand_ellipsis(
|
||||
return std::make_pair(non_none_indices, indices);
|
||||
}
|
||||
|
||||
array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
||||
// No indices make this a noop
|
||||
if (entries.size() == 0) {
|
||||
return src;
|
||||
@ -281,7 +281,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
bool have_non_array = false;
|
||||
bool gather_first = false;
|
||||
for (auto& idx : indices) {
|
||||
if (nb::isinstance<array>(idx) || (nb::isinstance<nb::int_>(idx))) {
|
||||
if (nb::isinstance<mx::array>(idx) || (nb::isinstance<nb::int_>(idx))) {
|
||||
if (have_array && have_non_array) {
|
||||
gather_first = true;
|
||||
break;
|
||||
@ -294,7 +294,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
|
||||
int n_arr = 0;
|
||||
for (auto& idx : indices) {
|
||||
n_arr += nb::isinstance<array>(idx);
|
||||
n_arr += nb::isinstance<mx::array>(idx);
|
||||
}
|
||||
|
||||
have_array &= n_arr > 0;
|
||||
@ -304,7 +304,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
// Then find the last array
|
||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
||||
auto& idx = indices[last_array];
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -340,7 +340,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
} else {
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
} else if (idx.is_none()) {
|
||||
remaining_indices.push_back(idx);
|
||||
@ -426,11 +426,11 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item(const array& src, const nb::object& obj) {
|
||||
mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
return mlx_get_item_array(src, nb::cast<array>(obj));
|
||||
} else if (nb::isinstance<mx::array>(obj)) {
|
||||
return mlx_get_item_array(src, nb::cast<mx::array>(obj));
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
@ -448,10 +448,11 @@ array mlx_get_item(const array& src, const nb::object& obj) {
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
||||
const array& src,
|
||||
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||
mlx_scatter_args_int(
|
||||
const mx::array& src,
|
||||
const nb::int_& idx,
|
||||
const array& update) {
|
||||
const mx::array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
@ -473,10 +474,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
||||
{0}};
|
||||
}
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
||||
const array& src,
|
||||
const array& indices,
|
||||
const array& update) {
|
||||
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||
mlx_scatter_args_array(
|
||||
const mx::array& src,
|
||||
const mx::array& indices,
|
||||
const mx::array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
@ -500,10 +502,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
||||
return {{indices}, up, {0}};
|
||||
}
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
const array& src,
|
||||
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||
mlx_scatter_args_slice(
|
||||
const mx::array& src,
|
||||
const nb::slice& in_slice,
|
||||
const array& update) {
|
||||
const mx::array& update) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@ -539,7 +542,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
auto up = reshape(update, up_shape);
|
||||
|
||||
// Build array to mark start of slice
|
||||
auto idx = array({start}, {1}, uint32);
|
||||
auto idx = mx::array({start}, {1}, mx::uint32);
|
||||
|
||||
// Get slice size
|
||||
int slice_size = (end - start);
|
||||
@ -551,20 +554,21 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
|
||||
up = broadcast_to(up, up_shape_broadcast);
|
||||
|
||||
auto indices = std::vector<array>{idx};
|
||||
auto indices = std::vector<mx::array>{idx};
|
||||
auto axes = std::vector<int>{0};
|
||||
|
||||
return {indices, up, axes};
|
||||
}
|
||||
|
||||
return mlx_scatter_args_array(
|
||||
src, arange(start, end, stride, uint32), update);
|
||||
src, arange(start, end, stride, mx::uint32), update);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
const array& src,
|
||||
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||
mlx_scatter_args_nd(
|
||||
const mx::array& src,
|
||||
const nb::tuple& entries,
|
||||
const array& update) {
|
||||
const mx::array& update) {
|
||||
// Expand ellipses into a series of ':' slices
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
|
||||
@ -623,12 +627,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
num_simple_slices_post++;
|
||||
}
|
||||
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
} else if (nb::isinstance<mx::array>(idx)) {
|
||||
have_array = true;
|
||||
if (have_array && have_non_array) {
|
||||
arrays_first = true;
|
||||
}
|
||||
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
|
||||
max_dim = std::max(nb::cast<mx::array>(idx).ndim(), max_dim);
|
||||
num_arrays++;
|
||||
num_simple_slices_post = 0;
|
||||
}
|
||||
@ -643,7 +647,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
|
||||
|
||||
// Go over each index type and translate to the needed scatter args
|
||||
std::vector<array> arr_indices;
|
||||
std::vector<mx::array> arr_indices;
|
||||
int slice_num = 0;
|
||||
int array_num = 0;
|
||||
int ax = 0;
|
||||
@ -668,7 +672,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
|
||||
// If it's a simple slice, we only need to add the start index
|
||||
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
|
||||
auto idx = array({start}, idx_shape, uint32);
|
||||
auto idx = mx::array({start}, idx_shape, mx::uint32);
|
||||
slice_shapes.push_back(end - start);
|
||||
arr_indices.push_back(idx);
|
||||
|
||||
@ -677,7 +681,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
}
|
||||
// Otherwise we expand the slice into indices using arange
|
||||
else {
|
||||
auto idx = arange(start, end, stride, uint32);
|
||||
auto idx = arange(start, end, stride, mx::uint32);
|
||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||
idx_shape[loc] = idx.size();
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
@ -696,9 +700,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
} else if (pyidx.is_none()) {
|
||||
// We only use the None's for bookkeeping dimensions
|
||||
slice_num++;
|
||||
} else if (nb::isinstance<array>(pyidx)) {
|
||||
} else if (nb::isinstance<mx::array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = nb::cast<array>(pyidx);
|
||||
auto idx = nb::cast<mx::array>(pyidx);
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
|
||||
// Place the arrays in the correct dimension
|
||||
@ -748,16 +752,16 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
return {arr_indices, up, axes};
|
||||
}
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>>
|
||||
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||
mlx_compute_scatter_args(
|
||||
const array& src,
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto vals = to_array(v, src.dtype());
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
|
||||
} else if (nb::isinstance<mx::array>(obj)) {
|
||||
return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
@ -773,7 +777,7 @@ mlx_compute_scatter_args(
|
||||
}
|
||||
|
||||
auto mlx_slice_update(
|
||||
const array& src,
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
// Can't route to slice update if not slice or tuple
|
||||
@ -784,7 +788,7 @@ auto mlx_slice_update(
|
||||
if (nb::isinstance<nb::tuple>(obj)) {
|
||||
// Can't route to slice update if any arrays are present
|
||||
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) {
|
||||
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::list>(idx)) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
}
|
||||
@ -881,7 +885,10 @@ auto mlx_slice_update(
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
|
||||
void mlx_set_item(
|
||||
mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [success, out] = mlx_slice_update(src, obj, v);
|
||||
if (success) {
|
||||
src.overwrite_descriptor(out);
|
||||
@ -897,8 +904,8 @@ void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
|
||||
}
|
||||
}
|
||||
|
||||
array mlx_add_item(
|
||||
const array& src,
|
||||
mx::array mlx_add_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
@ -909,8 +916,8 @@ array mlx_add_item(
|
||||
}
|
||||
}
|
||||
|
||||
array mlx_subtract_item(
|
||||
const array& src,
|
||||
mx::array mlx_subtract_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
@ -921,8 +928,8 @@ array mlx_subtract_item(
|
||||
}
|
||||
}
|
||||
|
||||
array mlx_multiply_item(
|
||||
const array& src,
|
||||
mx::array mlx_multiply_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
@ -933,8 +940,8 @@ array mlx_multiply_item(
|
||||
}
|
||||
}
|
||||
|
||||
array mlx_divide_item(
|
||||
const array& src,
|
||||
mx::array mlx_divide_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
@ -945,8 +952,8 @@ array mlx_divide_item(
|
||||
}
|
||||
}
|
||||
|
||||
array mlx_maximum_item(
|
||||
const array& src,
|
||||
mx::array mlx_maximum_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
@ -957,8 +964,8 @@ array mlx_maximum_item(
|
||||
}
|
||||
}
|
||||
|
||||
array mlx_minimum_item(
|
||||
const array& src,
|
||||
mx::array mlx_minimum_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
|
@ -7,32 +7,35 @@
|
||||
#include "mlx/array.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
array mlx_get_item(const array& src, const nb::object& obj);
|
||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v);
|
||||
array mlx_add_item(
|
||||
const array& src,
|
||||
mx::array mlx_get_item(const mx::array& src, const nb::object& obj);
|
||||
void mlx_set_item(
|
||||
mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_subtract_item(
|
||||
const array& src,
|
||||
mx::array mlx_add_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_multiply_item(
|
||||
const array& src,
|
||||
mx::array mlx_subtract_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_divide_item(
|
||||
const array& src,
|
||||
mx::array mlx_multiply_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_maximum_item(
|
||||
const array& src,
|
||||
mx::array mlx_divide_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_minimum_item(
|
||||
const array& src,
|
||||
mx::array mlx_maximum_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
mx::array mlx_minimum_item(
|
||||
const mx::array& src,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
|
@ -10,15 +10,13 @@
|
||||
|
||||
#include "mlx/linalg.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
namespace {
|
||||
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
const auto result = svd(a, s);
|
||||
nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||
const auto result = mx::linalg::svd(a, s);
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
}
|
||||
} // namespace
|
||||
@ -29,11 +27,11 @@ void init_linalg(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"norm",
|
||||
[](const array& a,
|
||||
[](const mx::array& a,
|
||||
const std::variant<std::monostate, int, double, std::string>& ord_,
|
||||
const std::variant<std::monostate, int, std::vector<int>>& axis_,
|
||||
const bool keepdims,
|
||||
const StreamOrDevice stream) {
|
||||
const mx::StreamOrDevice stream) {
|
||||
std::optional<std::vector<int>> axis = std::nullopt;
|
||||
if (auto pv = std::get_if<int>(&axis_); pv) {
|
||||
axis = std::vector<int>{*pv};
|
||||
@ -42,10 +40,10 @@ void init_linalg(nb::module_& parent_module) {
|
||||
}
|
||||
|
||||
if (std::holds_alternative<std::monostate>(ord_)) {
|
||||
return norm(a, axis, keepdims, stream);
|
||||
return mx::linalg::norm(a, axis, keepdims, stream);
|
||||
} else {
|
||||
if (auto pv = std::get_if<std::string>(&ord_); pv) {
|
||||
return norm(a, *pv, axis, keepdims, stream);
|
||||
return mx::linalg::norm(a, *pv, axis, keepdims, stream);
|
||||
}
|
||||
double ord;
|
||||
if (auto pv = std::get_if<int>(&ord_); pv) {
|
||||
@ -53,7 +51,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
} else {
|
||||
ord = std::get<double>(ord_);
|
||||
}
|
||||
return norm(a, ord, axis, keepdims, stream);
|
||||
return mx::linalg::norm(a, ord, axis, keepdims, stream);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
@ -182,7 +180,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"qr",
|
||||
&qr,
|
||||
&mx::linalg::qr,
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -239,7 +237,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"inv",
|
||||
&inv,
|
||||
&mx::linalg::inv,
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -262,7 +260,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tri_inv",
|
||||
&tri_inv,
|
||||
&mx::linalg::tri_inv,
|
||||
"a"_a,
|
||||
"upper"_a,
|
||||
nb::kw_only(),
|
||||
@ -287,7 +285,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cholesky",
|
||||
&cholesky,
|
||||
&mx::linalg::cholesky,
|
||||
"a"_a,
|
||||
"upper"_a = false,
|
||||
nb::kw_only(),
|
||||
@ -317,7 +315,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cholesky_inv",
|
||||
&cholesky_inv,
|
||||
&mx::linalg::cholesky_inv,
|
||||
"a"_a,
|
||||
"upper"_a = false,
|
||||
nb::kw_only(),
|
||||
@ -355,7 +353,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"pinv",
|
||||
&pinv,
|
||||
&mx::linalg::pinv,
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -379,7 +377,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"cross",
|
||||
&cross,
|
||||
&mx::linalg::cross,
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
"axis"_a = -1,
|
||||
@ -407,7 +405,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"eigvalsh",
|
||||
&eigvalsh,
|
||||
&mx::linalg::eigvalsh,
|
||||
"a"_a,
|
||||
"UPLO"_a = "L",
|
||||
nb::kw_only(),
|
||||
@ -442,9 +440,9 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"eigh",
|
||||
[](const array& a, const std::string UPLO, StreamOrDevice s) {
|
||||
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
|
||||
// TODO avoid cast?
|
||||
auto result = eigh(a, UPLO, s);
|
||||
auto result = mx::linalg::eigh(a, UPLO, s);
|
||||
return nb::make_tuple(result.first, result.second);
|
||||
},
|
||||
"a"_a,
|
||||
|
@ -14,9 +14,9 @@
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
@ -86,7 +86,7 @@ class ZipFileWrapper {
|
||||
// Loading
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class PyFileReader : public io::Reader {
|
||||
class PyFileReader : public mx::io::Reader {
|
||||
public:
|
||||
PyFileReader(nb::object file)
|
||||
: pyistream_(file),
|
||||
@ -168,14 +168,14 @@ class PyFileReader : public io::Reader {
|
||||
};
|
||||
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, mx::array>,
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(nb::cast<std::string>(file), s);
|
||||
return mx::load_safetensors(nb::cast<std::string>(file), s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
{
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : std::get<0>(res)) {
|
||||
@ -189,17 +189,17 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
"[load_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
||||
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(nb::cast<std::string>(file), s);
|
||||
return mx::load_gguf(nb::cast<std::string>(file), s);
|
||||
}
|
||||
|
||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
|
||||
nb::object file,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
bool own_file = nb::isinstance<nb::str>(file);
|
||||
|
||||
nb::module_ zipfile = nb::module_::import_("zipfile");
|
||||
@ -209,7 +209,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
"opened with zipfile.ZipFile");
|
||||
}
|
||||
// Output dictionary filename in zip -> loaded array
|
||||
std::unordered_map<std::string, array> array_dict;
|
||||
std::unordered_map<std::string, mx::array> array_dict;
|
||||
|
||||
// Create python ZipFile object
|
||||
ZipFileWrapper zipfile_object(zipfile, file);
|
||||
@ -218,7 +218,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
nb::object sub_file = zipfile_object.open(st);
|
||||
|
||||
// Create array from python file stream
|
||||
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
|
||||
auto arr = mx::load(std::make_shared<PyFileReader>(sub_file), s);
|
||||
|
||||
// Remove .npy from file if it is there
|
||||
auto key = st;
|
||||
@ -240,12 +240,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
return array_dict;
|
||||
}
|
||||
|
||||
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
|
||||
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
|
||||
return load(nb::cast<std::string>(file), s);
|
||||
return mx::load(nb::cast<std::string>(file), s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
||||
auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
|
||||
{
|
||||
nb::gil_scoped_release gil;
|
||||
arr.eval();
|
||||
@ -260,7 +260,7 @@ LoadOutputTypes mlx_load_helper(
|
||||
nb::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s) {
|
||||
mx::StreamOrDevice s) {
|
||||
if (!format.has_value()) {
|
||||
std::string fname;
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
@ -309,7 +309,7 @@ LoadOutputTypes mlx_load_helper(
|
||||
// Saving
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class PyFileWriter : public io::Writer {
|
||||
class PyFileWriter : public mx::io::Writer {
|
||||
public:
|
||||
PyFileWriter(nb::object file)
|
||||
: pyostream_(file),
|
||||
@ -382,15 +382,15 @@ class PyFileWriter : public io::Writer {
|
||||
nb::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(nb::object file, array a) {
|
||||
void mlx_save_helper(nb::object file, mx::array a) {
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
save(nb::cast<std::string>(file), a);
|
||||
mx::save(nb::cast<std::string>(file), a);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
nb::gil_scoped_release gil;
|
||||
save(writer, a);
|
||||
mx::save(writer, a);
|
||||
}
|
||||
|
||||
return;
|
||||
@ -419,8 +419,9 @@ void mlx_savez_helper(
|
||||
}
|
||||
|
||||
// Collect args and kwargs
|
||||
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
|
||||
auto arrays_list = nb::cast<std::vector<array>>(args);
|
||||
auto arrays_dict =
|
||||
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
|
||||
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
|
||||
|
||||
for (int i = 0; i < arrays_list.size(); i++) {
|
||||
std::string arr_name = "arr_" + std::to_string(i);
|
||||
@ -447,7 +448,7 @@ void mlx_savez_helper(
|
||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
save(writer, a);
|
||||
mx::save(writer, a);
|
||||
}
|
||||
}
|
||||
|
||||
@ -470,17 +471,18 @@ void mlx_save_safetensor_helper(
|
||||
} else {
|
||||
metadata_map = std::unordered_map<std::string, std::string>();
|
||||
}
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
mx::save_safetensors(
|
||||
nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
}
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
save_safetensors(writer, arrays_map, metadata_map);
|
||||
mx::save_safetensors(writer, arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@ -492,19 +494,20 @@ void mlx_save_gguf_helper(
|
||||
nb::object file,
|
||||
nb::dict a,
|
||||
std::optional<nb::dict> m) {
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
|
||||
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
|
||||
m.value());
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
save_gguf(nb::cast<std::string>(file), arrays_map);
|
||||
mx::save_gguf(nb::cast<std::string>(file), arrays_map);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -14,22 +14,24 @@
|
||||
#include <variant>
|
||||
#include "mlx/io.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
using LoadOutputTypes = std::variant<
|
||||
array,
|
||||
std::unordered_map<std::string, array>,
|
||||
SafetensorsLoad,
|
||||
GGUFLoad>;
|
||||
mx::array,
|
||||
std::unordered_map<std::string, mx::array>,
|
||||
mx::SafetensorsLoad,
|
||||
mx::GGUFLoad>;
|
||||
|
||||
SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s);
|
||||
mx::SafetensorsLoad mlx_load_safetensor_helper(
|
||||
nb::object file,
|
||||
mx::StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(
|
||||
nb::object file,
|
||||
nb::dict d,
|
||||
std::optional<nb::dict> m);
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s);
|
||||
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s);
|
||||
|
||||
void mlx_save_gguf_helper(
|
||||
nb::object file,
|
||||
@ -40,8 +42,8 @@ LoadOutputTypes mlx_load_helper(
|
||||
nb::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_helper(nb::object file, array a);
|
||||
mx::StreamOrDevice s);
|
||||
void mlx_save_helper(nb::object file, mx::array a);
|
||||
void mlx_savez_helper(
|
||||
nb::object file,
|
||||
nb::args args,
|
||||
|
@ -8,22 +8,21 @@
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_metal(nb::module_& m) {
|
||||
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||
metal.def(
|
||||
"is_available",
|
||||
&metal::is_available,
|
||||
&mx::metal::is_available,
|
||||
R"pbdoc(
|
||||
Check if the Metal back-end is available.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_active_memory",
|
||||
&metal::get_active_memory,
|
||||
&mx::metal::get_active_memory,
|
||||
R"pbdoc(
|
||||
Get the actively used memory in bytes.
|
||||
|
||||
@ -32,7 +31,7 @@ void init_metal(nb::module_& m) {
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_peak_memory",
|
||||
&metal::get_peak_memory,
|
||||
&mx::metal::get_peak_memory,
|
||||
R"pbdoc(
|
||||
Get the peak amount of used memory in bytes.
|
||||
|
||||
@ -41,13 +40,13 @@ void init_metal(nb::module_& m) {
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"reset_peak_memory",
|
||||
&metal::reset_peak_memory,
|
||||
&mx::metal::reset_peak_memory,
|
||||
R"pbdoc(
|
||||
Reset the peak memory to zero.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"get_cache_memory",
|
||||
&metal::get_cache_memory,
|
||||
&mx::metal::get_cache_memory,
|
||||
R"pbdoc(
|
||||
Get the cache size in bytes.
|
||||
|
||||
@ -56,7 +55,7 @@ void init_metal(nb::module_& m) {
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"set_memory_limit",
|
||||
&metal::set_memory_limit,
|
||||
&mx::metal::set_memory_limit,
|
||||
"limit"_a,
|
||||
nb::kw_only(),
|
||||
"relaxed"_a = true,
|
||||
@ -81,7 +80,7 @@ void init_metal(nb::module_& m) {
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"set_cache_limit",
|
||||
&metal::set_cache_limit,
|
||||
&mx::metal::set_cache_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the free cache limit.
|
||||
@ -101,7 +100,7 @@ void init_metal(nb::module_& m) {
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"set_wired_limit",
|
||||
&metal::set_wired_limit,
|
||||
&mx::metal::set_wired_limit,
|
||||
"limit"_a,
|
||||
R"pbdoc(
|
||||
Set the wired size limit.
|
||||
@ -133,7 +132,7 @@ void init_metal(nb::module_& m) {
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"clear_cache",
|
||||
&metal::clear_cache,
|
||||
&mx::metal::clear_cache,
|
||||
R"pbdoc(
|
||||
Clear the memory cache.
|
||||
|
||||
@ -142,7 +141,7 @@ void init_metal(nb::module_& m) {
|
||||
|
||||
metal.def(
|
||||
"start_capture",
|
||||
&metal::start_capture,
|
||||
&mx::metal::start_capture,
|
||||
"path"_a,
|
||||
R"pbdoc(
|
||||
Start a Metal capture.
|
||||
@ -153,13 +152,13 @@ void init_metal(nb::module_& m) {
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"stop_capture",
|
||||
&metal::stop_capture,
|
||||
&mx::metal::stop_capture,
|
||||
R"pbdoc(
|
||||
Stop a Metal capture.
|
||||
)pbdoc");
|
||||
metal.def(
|
||||
"device_info",
|
||||
&metal::device_info,
|
||||
&mx::metal::device_info,
|
||||
R"pbdoc(
|
||||
Get information about the GPU device and system settings.
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -12,23 +12,22 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
class PyKeySequence {
|
||||
public:
|
||||
explicit PyKeySequence(uint64_t seed) {
|
||||
state_.append(key(seed));
|
||||
state_.append(mx::random::key(seed));
|
||||
}
|
||||
|
||||
void seed(uint64_t seed) {
|
||||
state_[0] = key(seed);
|
||||
state_[0] = mx::random::key(seed);
|
||||
}
|
||||
|
||||
array next() {
|
||||
auto out = split(nb::cast<array>(state_[0]));
|
||||
mx::array next() {
|
||||
auto out = mx::random::split(nb::cast<mx::array>(state_[0]));
|
||||
state_[0] = out.first;
|
||||
return out.second;
|
||||
}
|
||||
@ -75,7 +74,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"key",
|
||||
&key,
|
||||
&mx::random::key,
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Get a PRNG key from a seed.
|
||||
@ -88,7 +87,8 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"split",
|
||||
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
nb::overload_cast<const mx::array&, int, mx::StreamOrDevice>(
|
||||
&mx::random::split),
|
||||
"key"_a,
|
||||
"num"_a = 2,
|
||||
"stream"_a = nb::none(),
|
||||
@ -109,22 +109,22 @@ void init_random(nb::module_& parent_module) {
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return uniform(
|
||||
return mx::random::uniform(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
shape,
|
||||
type.value_or(float32),
|
||||
type.value_or(mx::float32),
|
||||
key,
|
||||
s);
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@ -151,16 +151,17 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
||||
return mx::random::normal(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
"key"_a = nb::none(),
|
||||
@ -182,20 +183,20 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"multivariate_normal",
|
||||
[](const array& mean,
|
||||
const array& cov,
|
||||
[](const mx::array& mean,
|
||||
const mx::array& cov,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return multivariate_normal(
|
||||
mean, cov, shape, type.value_or(float32), key, s);
|
||||
return mx::random::multivariate_normal(
|
||||
mean, cov, shape, type.value_or(mx::float32), key, s);
|
||||
},
|
||||
"mean"_a,
|
||||
"cov"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@ -227,17 +228,22 @@ void init_random(nb::module_& parent_module) {
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return randint(
|
||||
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
|
||||
return mx::random::randint(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
shape,
|
||||
type.value_or(mx::int32),
|
||||
key,
|
||||
s);
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = int32,
|
||||
"dtype"_a.none() = mx::int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@ -263,14 +269,14 @@ void init_random(nb::module_& parent_module) {
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto p = to_array(p_);
|
||||
if (shape.has_value()) {
|
||||
return bernoulli(p, shape.value(), key, s);
|
||||
return mx::random::bernoulli(p, shape.value(), key, s);
|
||||
} else {
|
||||
return bernoulli(p, key, s);
|
||||
return mx::random::bernoulli(p, key, s);
|
||||
}
|
||||
},
|
||||
"p"_a = 0.5,
|
||||
@ -301,23 +307,24 @@ void init_random(nb::module_& parent_module) {
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
auto t = type.value_or(float32);
|
||||
auto t = type.value_or(mx::float32);
|
||||
if (shape_.has_value()) {
|
||||
return truncated_normal(lower, upper, shape_.value(), t, key, s);
|
||||
return mx::random::truncated_normal(
|
||||
lower, upper, shape_.value(), t, key, s);
|
||||
} else {
|
||||
return truncated_normal(lower, upper, t, key, s);
|
||||
return mx::random::truncated_normal(lower, upper, t, key, s);
|
||||
}
|
||||
},
|
||||
"lower"_a,
|
||||
"upper"_a,
|
||||
"shape"_a = nb::none(),
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@ -344,14 +351,14 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@ -375,22 +382,23 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"categorical",
|
||||
[](const array& logits,
|
||||
[](const mx::array& logits,
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (shape.has_value() && num_samples.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[categorical] At most one of shape or num_samples can be specified.");
|
||||
} else if (shape.has_value()) {
|
||||
return categorical(logits, axis, shape.value(), key, s);
|
||||
return mx::random::categorical(logits, axis, shape.value(), key, s);
|
||||
} else if (num_samples.has_value()) {
|
||||
return categorical(logits, axis, num_samples.value(), key, s);
|
||||
return mx::random::categorical(
|
||||
logits, axis, num_samples.value(), key, s);
|
||||
} else {
|
||||
return categorical(logits, axis, key, s);
|
||||
return mx::random::categorical(logits, axis, key, s);
|
||||
}
|
||||
},
|
||||
"logits"_a,
|
||||
@ -427,16 +435,17 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"laplace",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return laplace(shape, type.value_or(float32), loc, scale, key, s);
|
||||
return mx::random::laplace(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a.none() = float32,
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
"key"_a = nb::none(),
|
||||
@ -459,15 +468,15 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"permuation",
|
||||
[](const std::variant<nb::int_, array>& x,
|
||||
[](const std::variant<nb::int_, mx::array>& x,
|
||||
int axis,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (auto pv = std::get_if<nb::int_>(&x); pv) {
|
||||
return permutation(nb::cast<int>(*pv), key, s);
|
||||
return mx::random::permutation(nb::cast<int>(*pv), key, s);
|
||||
} else {
|
||||
return permutation(std::get<array>(x), axis, key, s);
|
||||
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
|
||||
}
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
|
@ -10,14 +10,14 @@
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
// Create the StreamContext on enter and delete on exit.
|
||||
class PyStreamContext {
|
||||
public:
|
||||
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
||||
PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
throw std::runtime_error(
|
||||
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||
@ -26,7 +26,7 @@ class PyStreamContext {
|
||||
}
|
||||
|
||||
void enter() {
|
||||
_inner = new StreamContext(_s);
|
||||
_inner = new mx::StreamContext(_s);
|
||||
}
|
||||
|
||||
void exit() {
|
||||
@ -37,39 +37,40 @@ class PyStreamContext {
|
||||
}
|
||||
|
||||
private:
|
||||
StreamOrDevice _s;
|
||||
StreamContext* _inner;
|
||||
mx::StreamOrDevice _s;
|
||||
mx::StreamContext* _inner;
|
||||
};
|
||||
|
||||
void init_stream(nb::module_& m) {
|
||||
nb::class_<Stream>(
|
||||
nb::class_<mx::Stream>(
|
||||
m,
|
||||
"Stream",
|
||||
R"pbdoc(
|
||||
A stream for running operations on a given device.
|
||||
)pbdoc")
|
||||
.def_ro("device", &Stream::device)
|
||||
.def_ro("device", &mx::Stream::device)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const Stream& s) {
|
||||
[](const mx::Stream& s) {
|
||||
std::ostringstream os;
|
||||
os << s;
|
||||
return os.str();
|
||||
})
|
||||
.def("__eq__", [](const Stream& s, const nb::object& other) {
|
||||
return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other);
|
||||
.def("__eq__", [](const mx::Stream& s, const nb::object& other) {
|
||||
return nb::isinstance<mx::Stream>(other) &&
|
||||
s == nb::cast<mx::Stream>(other);
|
||||
});
|
||||
|
||||
nb::implicitly_convertible<Device::DeviceType, Device>();
|
||||
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
|
||||
|
||||
m.def(
|
||||
"default_stream",
|
||||
&default_stream,
|
||||
&mx::default_stream,
|
||||
"device"_a,
|
||||
R"pbdoc(Get the device's default stream.)pbdoc");
|
||||
m.def(
|
||||
"set_default_stream",
|
||||
&set_default_stream,
|
||||
&mx::set_default_stream,
|
||||
"stream"_a,
|
||||
R"pbdoc(
|
||||
Set the default stream.
|
||||
@ -82,7 +83,7 @@ void init_stream(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"new_stream",
|
||||
&new_stream,
|
||||
&mx::new_stream,
|
||||
"device"_a,
|
||||
R"pbdoc(Make a new stream on the given device.)pbdoc");
|
||||
|
||||
@ -94,7 +95,7 @@ void init_stream(nb::module_& m) {
|
||||
Args:
|
||||
s: The stream or device to set as the default.
|
||||
)pbdoc")
|
||||
.def(nb::init<StreamOrDevice>(), "s"_a)
|
||||
.def(nb::init<mx::StreamOrDevice>(), "s"_a)
|
||||
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
|
||||
.def(
|
||||
"__exit__",
|
||||
@ -107,7 +108,7 @@ void init_stream(nb::module_& m) {
|
||||
"traceback"_a = nb::none());
|
||||
m.def(
|
||||
"stream",
|
||||
[](StreamOrDevice s) { return PyStreamContext(s); },
|
||||
[](mx::StreamOrDevice s) { return PyStreamContext(s); },
|
||||
"s"_a,
|
||||
R"pbdoc(
|
||||
Create a context manager to set the default device and stream.
|
||||
@ -131,8 +132,8 @@ void init_stream(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"synchronize",
|
||||
[](const std::optional<Stream>& s) {
|
||||
s ? synchronize(s.value()) : synchronize();
|
||||
[](const std::optional<mx::Stream>& s) {
|
||||
s ? mx::synchronize(s.value()) : mx::synchronize();
|
||||
},
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
|
@ -20,9 +20,12 @@
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
// Needed for printing shapes and strides.
|
||||
using mx::operator<<;
|
||||
|
||||
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
||||
@ -108,7 +111,7 @@ auto py_value_and_grad(
|
||||
}
|
||||
|
||||
// Collect the arrays
|
||||
std::vector<array> arrays;
|
||||
std::vector<mx::array> arrays;
|
||||
std::vector<int> counts(1, 0);
|
||||
for (auto i : argnums) {
|
||||
auto argsi = tree_flatten(args[i]);
|
||||
@ -127,7 +130,7 @@ auto py_value_and_grad(
|
||||
// value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
nb::object py_value_out;
|
||||
auto value_and_grads = value_and_grad(
|
||||
auto value_and_grads = mx::value_and_grad(
|
||||
[&fun,
|
||||
&args,
|
||||
&kwargs,
|
||||
@ -136,7 +139,7 @@ auto py_value_and_grad(
|
||||
&counts,
|
||||
&py_value_out,
|
||||
&error_msg_tag,
|
||||
scalar_func_only](const std::vector<array>& a) {
|
||||
scalar_func_only](const std::vector<mx::array>& a) {
|
||||
// Copy the arguments
|
||||
nb::list args_cpy;
|
||||
nb::kwargs kwargs_cpy = nb::kwargs();
|
||||
@ -165,7 +168,7 @@ auto py_value_and_grad(
|
||||
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!nb::isinstance<array>(py_value_out)) {
|
||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||
if (scalar_func_only) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
@ -193,7 +196,7 @@ auto py_value_and_grad(
|
||||
<< "we got an empty tuple.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!nb::isinstance<array>(ret[0])) {
|
||||
if (!nb::isinstance<mx::array>(ret[0])) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
@ -275,12 +278,12 @@ auto py_vmap(
|
||||
{tree, axes},
|
||||
[&flat_axes, &encountered_tuple, output_axes](
|
||||
const std::vector<nb::object>& inputs) {
|
||||
if (nb::isinstance<array>(inputs[0])) {
|
||||
if (nb::isinstance<mx::array>(inputs[0])) {
|
||||
if (inputs[1].is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (nb::isinstance<nb::int_>(inputs[1])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
const mx::array& x = nb::cast<mx::array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
@ -297,7 +300,7 @@ auto py_vmap(
|
||||
auto l = nb::cast<nb::tuple>(inputs[1]);
|
||||
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
const mx::array& x = nb::cast<mx::array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
@ -323,7 +326,7 @@ auto py_vmap(
|
||||
"[vmap] The arguments should contain only arrays");
|
||||
}
|
||||
});
|
||||
if (encountered_tuple && !nb::isinstance<array>(tree)) {
|
||||
if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
return flat_axes;
|
||||
@ -339,7 +342,7 @@ auto py_vmap(
|
||||
nb::object py_outputs;
|
||||
|
||||
auto vmap_fn =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
|
||||
// Call the python function
|
||||
py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
@ -348,12 +351,12 @@ auto py_vmap(
|
||||
};
|
||||
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
|
||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
|
||||
|
||||
// Perform the vmap
|
||||
auto outputs = detail::vmap_replace(
|
||||
auto outputs = mx::detail::vmap_replace(
|
||||
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
|
||||
|
||||
// Put the outputs back in the container
|
||||
@ -401,7 +404,7 @@ struct PyCompiledFun {
|
||||
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
// Flat array inputs
|
||||
std::vector<array> inputs;
|
||||
std::vector<mx::array> inputs;
|
||||
|
||||
// Compilation constants which includes the tree structure of the arguments
|
||||
std::vector<uint64_t> constants;
|
||||
@ -437,8 +440,8 @@ struct PyCompiledFun {
|
||||
constants.push_back(nb::cast<int64_t>(r));
|
||||
recurse(item.second);
|
||||
}
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
inputs.push_back(nb::cast<array>(obj));
|
||||
} else if (nb::isinstance<mx::array>(obj)) {
|
||||
inputs.push_back(nb::cast<mx::array>(obj));
|
||||
constants.push_back(array_identifier);
|
||||
} else if (nb::isinstance<nb::str>(obj)) {
|
||||
auto r = obj.attr("__hash__")();
|
||||
@ -461,10 +464,10 @@ struct PyCompiledFun {
|
||||
int num_args = inputs.size();
|
||||
recurse(kwargs);
|
||||
auto compile_fun = [this, &args, &kwargs, num_args](
|
||||
const std::vector<array>& a) {
|
||||
const std::vector<mx::array>& a) {
|
||||
// Put tracers into captured inputs
|
||||
std::vector<array> flat_in_captures;
|
||||
std::vector<array> trace_captures;
|
||||
std::vector<mx::array> flat_in_captures;
|
||||
std::vector<mx::array> trace_captures;
|
||||
if (!captured_inputs.is_none()) {
|
||||
flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
trace_captures.insert(
|
||||
@ -505,9 +508,9 @@ struct PyCompiledFun {
|
||||
|
||||
// Compile and call
|
||||
auto outputs =
|
||||
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
if (!captured_outputs.is_none()) {
|
||||
std::vector<array> captures(
|
||||
std::vector<mx::array> captures(
|
||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||
std::make_move_iterator(outputs.end()));
|
||||
tree_fill(captured_outputs, captures);
|
||||
@ -526,7 +529,7 @@ struct PyCompiledFun {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
mx::detail::compile_erase(fun_id);
|
||||
fun.release().dec_ref();
|
||||
captured_inputs.release().dec_ref();
|
||||
captured_outputs.release().dec_ref();
|
||||
@ -561,7 +564,7 @@ class PyCheckpointedFun {
|
||||
args_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||
auto args = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(args_structure_, inputs));
|
||||
auto [outputs, output_structure] =
|
||||
@ -579,7 +582,7 @@ class PyCheckpointedFun {
|
||||
auto [inputs, args_structure] =
|
||||
tree_flatten_with_structure(full_args, false);
|
||||
|
||||
auto outputs = checkpoint(
|
||||
auto outputs = mx::checkpoint(
|
||||
InnerFunction(fun_, args_structure, output_structure))(inputs);
|
||||
|
||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||
@ -660,12 +663,12 @@ class PyCustomFunction {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(input_structure_, inputs));
|
||||
std::vector<array> outputs;
|
||||
std::vector<mx::array> outputs;
|
||||
std::tie(outputs, *output_structure_) =
|
||||
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
|
||||
return outputs;
|
||||
@ -694,10 +697,10 @@ class PyCustomFunction {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> operator()(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<mx::array> operator()(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<mx::array>& outputs) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
auto new_inputs = nb::cast<nb::tuple>(
|
||||
@ -734,9 +737,9 @@ class PyCustomFunction {
|
||||
input_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> operator()(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
@ -759,7 +762,7 @@ class PyCustomFunction {
|
||||
int tangent_index = 0;
|
||||
auto new_tangents =
|
||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||
if (nb::isinstance<array>(element) &&
|
||||
if (nb::isinstance<mx::array>(element) &&
|
||||
have_tangents[array_index++]) {
|
||||
return nb::cast(tangents[tangent_index++]);
|
||||
} else {
|
||||
@ -789,8 +792,8 @@ class PyCustomFunction {
|
||||
input_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> operator()(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
@ -807,7 +810,7 @@ class PyCustomFunction {
|
||||
auto new_axes =
|
||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||
int axis = axes[arr_index++];
|
||||
if (nb::isinstance<array>(element) && axis >= 0) {
|
||||
if (nb::isinstance<mx::array>(element) && axis >= 0) {
|
||||
return nb::cast(axis);
|
||||
} else {
|
||||
return nb::none();
|
||||
@ -831,11 +834,11 @@ class PyCustomFunction {
|
||||
"[custom vmap] Vmap function should return a tuple with 2 items.");
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
std::vector<mx::array> outputs;
|
||||
std::vector<int> output_axes;
|
||||
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
|
||||
if (nb::isinstance<array>(objects[0])) {
|
||||
outputs.push_back(nb::cast<array>(objects[0]));
|
||||
if (nb::isinstance<mx::array>(objects[0])) {
|
||||
outputs.push_back(nb::cast<mx::array>(objects[0]));
|
||||
output_axes.push_back(
|
||||
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
|
||||
}
|
||||
@ -852,7 +855,7 @@ class PyCustomFunction {
|
||||
}
|
||||
|
||||
// Extract the inputs and their structure in capturable vars
|
||||
std::vector<array> input_arrays;
|
||||
std::vector<mx::array> input_arrays;
|
||||
nb::object input_structure;
|
||||
auto full_args = nb::make_tuple(args, kwargs);
|
||||
std::tie(input_arrays, input_structure) =
|
||||
@ -864,7 +867,7 @@ class PyCustomFunction {
|
||||
|
||||
// Make a function that calls fun_ in the forward pass and vjp_ in the
|
||||
// backward pass. Then call it immediately and return the results.
|
||||
auto f = custom_function(
|
||||
auto f = mx::custom_function(
|
||||
InnerFunction(fun_, input_structure, output_structure),
|
||||
make_vjp_function(input_structure, output_structure),
|
||||
make_jvp_function(input_structure),
|
||||
@ -1044,7 +1047,7 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
std::vector<mx::array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
eval(arrays);
|
||||
@ -1064,7 +1067,7 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"async_eval",
|
||||
[](const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
std::vector<mx::array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
nb::gil_scoped_release nogil;
|
||||
async_eval(arrays);
|
||||
@ -1100,14 +1103,14 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"jvp",
|
||||
[](const nb::callable& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents) {
|
||||
auto vfun = [&fun](const std::vector<mx::array>& primals) {
|
||||
auto out = fun(*nb::cast(primals));
|
||||
if (nb::isinstance<array>(out)) {
|
||||
return std::vector<array>{nb::cast<array>(out)};
|
||||
if (nb::isinstance<mx::array>(out)) {
|
||||
return std::vector<mx::array>{nb::cast<mx::array>(out)};
|
||||
} else {
|
||||
return nb::cast<std::vector<array>>(out);
|
||||
return nb::cast<std::vector<mx::array>>(out);
|
||||
}
|
||||
};
|
||||
return jvp(vfun, primals, tangents);
|
||||
@ -1139,14 +1142,14 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"vjp",
|
||||
[](const nb::callable& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents) {
|
||||
auto vfun = [&fun](const std::vector<mx::array>& primals) {
|
||||
auto out = fun(*nb::cast(primals));
|
||||
if (nb::isinstance<array>(out)) {
|
||||
return std::vector<array>{nb::cast<array>(out)};
|
||||
if (nb::isinstance<mx::array>(out)) {
|
||||
return std::vector<mx::array>{nb::cast<mx::array>(out)};
|
||||
} else {
|
||||
return nb::cast<std::vector<array>>(out);
|
||||
return nb::cast<std::vector<mx::array>>(out);
|
||||
}
|
||||
};
|
||||
return vjp(vfun, primals, cotangents);
|
||||
@ -1312,7 +1315,7 @@ void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](nb::object file, const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
std::vector<mx::array> arrays = tree_flatten(args);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
std::ofstream out(nb::cast<std::string>(file));
|
||||
export_to_dot(out, arrays);
|
||||
@ -1399,14 +1402,14 @@ void init_transforms(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"disable_compile",
|
||||
&disable_compile,
|
||||
&mx::disable_compile,
|
||||
R"pbdoc(
|
||||
Globally disable compilation. Setting the environment variable
|
||||
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"enable_compile",
|
||||
&enable_compile,
|
||||
&mx::enable_compile,
|
||||
R"pbdoc(
|
||||
Globally enable compilation. This will override the environment
|
||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||
@ -1420,6 +1423,6 @@ void init_transforms(nb::module_& m) {
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() {
|
||||
tree_cache().clear();
|
||||
detail::compile_clear_cache();
|
||||
mx::detail::compile_clear_cache();
|
||||
}));
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ void tree_visit_update(
|
||||
d[item.first] = recurse(item.second);
|
||||
}
|
||||
return nb::cast<nb::object>(d);
|
||||
} else if (nb::isinstance<array>(subtree)) {
|
||||
} else if (nb::isinstance<mx::array>(subtree)) {
|
||||
return visitor(subtree);
|
||||
} else {
|
||||
return nb::cast<nb::object>(subtree);
|
||||
@ -200,7 +200,7 @@ void tree_visit_update(
|
||||
// Fill a pytree (recursive dict or list of dict or list)
|
||||
// in place with the given arrays
|
||||
// Non dict or list nodes are ignored
|
||||
void tree_fill(nb::object& tree, const std::vector<array>& values) {
|
||||
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
|
||||
size_t index = 0;
|
||||
tree_visit_update(
|
||||
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
|
||||
@ -209,14 +209,14 @@ void tree_fill(nb::object& tree, const std::vector<array>& values) {
|
||||
// Replace all the arrays from the src values with the dst values in the tree
|
||||
void tree_replace(
|
||||
nb::object& tree,
|
||||
const std::vector<array>& src,
|
||||
const std::vector<array>& dst) {
|
||||
std::unordered_map<uintptr_t, array> src_to_dst;
|
||||
const std::vector<mx::array>& src,
|
||||
const std::vector<mx::array>& dst) {
|
||||
std::unordered_map<uintptr_t, mx::array> src_to_dst;
|
||||
for (int i = 0; i < src.size(); ++i) {
|
||||
src_to_dst.insert({src[i].id(), dst[i]});
|
||||
}
|
||||
tree_visit_update(tree, [&](nb::handle node) {
|
||||
auto arr = nb::cast<array>(node);
|
||||
auto arr = nb::cast<mx::array>(node);
|
||||
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
||||
return nb::cast(it->second);
|
||||
}
|
||||
@ -224,12 +224,12 @@ void tree_replace(
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||
std::vector<array> flat_tree;
|
||||
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||
std::vector<mx::array> flat_tree;
|
||||
|
||||
tree_visit(tree, [&](nb::handle obj) {
|
||||
if (nb::isinstance<array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<array>(obj));
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<mx::array>(obj));
|
||||
} else if (strict) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_flatten] The argument should contain only arrays");
|
||||
@ -241,10 +241,10 @@ std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||
|
||||
nb::object tree_unflatten(
|
||||
nb::object tree,
|
||||
const std::vector<array>& values,
|
||||
const std::vector<mx::array>& values,
|
||||
int index /* = 0 */) {
|
||||
return tree_map(tree, [&](nb::handle obj) {
|
||||
if (nb::isinstance<array>(obj)) {
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
return nb::cast(values[index++]);
|
||||
} else {
|
||||
return nb::cast<nb::object>(obj);
|
||||
@ -265,16 +265,16 @@ nb::object structure_sentinel() {
|
||||
return sentinel;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
||||
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
|
||||
nb::object tree,
|
||||
bool strict /* = true */) {
|
||||
auto sentinel = structure_sentinel();
|
||||
std::vector<array> flat_tree;
|
||||
std::vector<mx::array> flat_tree;
|
||||
auto structure = tree_map(
|
||||
tree,
|
||||
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
|
||||
if (nb::isinstance<array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<array>(obj));
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
flat_tree.push_back(nb::cast<mx::array>(obj));
|
||||
return sentinel;
|
||||
} else if (!strict) {
|
||||
return nb::cast<nb::object>(obj);
|
||||
@ -289,7 +289,7 @@ std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
||||
|
||||
nb::object tree_unflatten_from_structure(
|
||||
nb::object structure,
|
||||
const std::vector<array>& values,
|
||||
const std::vector<mx::array>& values,
|
||||
int index /* = 0 */) {
|
||||
auto sentinel = structure_sentinel();
|
||||
return tree_map(structure, [&](nb::handle obj) {
|
||||
|
@ -4,8 +4,8 @@
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
void tree_visit(
|
||||
const std::vector<nb::object>& trees,
|
||||
@ -27,7 +27,7 @@ void tree_visit_update(
|
||||
/**
|
||||
* Fill a pytree (recursive dict or list of dict or list) in place with the
|
||||
* given arrays. */
|
||||
void tree_fill(nb::object& tree, const std::vector<array>& values);
|
||||
void tree_fill(nb::object& tree, const std::vector<mx::array>& values);
|
||||
|
||||
/**
|
||||
* Replace all the arrays from the src values with the dst values in the
|
||||
@ -35,28 +35,28 @@ void tree_fill(nb::object& tree, const std::vector<array>& values);
|
||||
*/
|
||||
void tree_replace(
|
||||
nb::object& tree,
|
||||
const std::vector<array>& src,
|
||||
const std::vector<array>& dst);
|
||||
const std::vector<mx::array>& src,
|
||||
const std::vector<mx::array>& dst);
|
||||
|
||||
/**
|
||||
* Flatten a tree into a vector of arrays. If strict is true, then the
|
||||
* function will throw if the tree contains a leaf which is not an array.
|
||||
*/
|
||||
std::vector<array> tree_flatten(nb::object tree, bool strict = true);
|
||||
std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
|
||||
|
||||
/**
|
||||
* Unflatten a tree from a vector of arrays.
|
||||
*/
|
||||
nb::object tree_unflatten(
|
||||
nb::object tree,
|
||||
const std::vector<array>& values,
|
||||
const std::vector<mx::array>& values,
|
||||
int index = 0);
|
||||
|
||||
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
||||
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
|
||||
nb::object tree,
|
||||
bool strict = true);
|
||||
|
||||
nb::object tree_unflatten_from_structure(
|
||||
nb::object structure,
|
||||
const std::vector<array>& values,
|
||||
const std::vector<mx::array>& values,
|
||||
int index = 0);
|
||||
|
@ -4,22 +4,24 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "python/src/convert.h"
|
||||
|
||||
array to_array(
|
||||
mx::array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype /* = std::nullopt */) {
|
||||
std::optional<mx::Dtype> dtype /* = std::nullopt */) {
|
||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
|
||||
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
|
||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(int32);
|
||||
auto out_t = dtype.value_or(mx::int32);
|
||||
// bool_ is an exception and is always promoted
|
||||
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
|
||||
return mx::array(
|
||||
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(float32);
|
||||
return array(
|
||||
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
|
||||
auto out_t = dtype.value_or(mx::float32);
|
||||
return mx::array(
|
||||
nb::cast<float>(*pv),
|
||||
mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32);
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), complex64);
|
||||
} else if (auto pv = std::get_if<array>(&v); pv) {
|
||||
return mx::array(static_cast<mx::complex64_t>(*pv), mx::complex64);
|
||||
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
|
||||
return *pv;
|
||||
} else if (auto pv = std::get_if<
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
||||
@ -30,7 +32,7 @@ array to_array(
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<array, array> to_arrays(
|
||||
std::pair<mx::array, mx::array> to_arrays(
|
||||
const ScalarOrArray& a,
|
||||
const ScalarOrArray& b) {
|
||||
// Four cases:
|
||||
@ -39,15 +41,15 @@ std::pair<array, array> to_arrays(
|
||||
// - If b is an array but a is not, treat a as a weak python type
|
||||
// - If neither is an array convert to arrays but leave their types alone
|
||||
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||
return std::holds_alternative<array>(x) ||
|
||||
return std::holds_alternative<mx::array>(x) ||
|
||||
std::holds_alternative<nb::object>(x) &&
|
||||
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
|
||||
};
|
||||
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||
if (auto px = std::get_if<array>(&x); px) {
|
||||
if (auto px = std::get_if<mx::array>(&x); px) {
|
||||
return *px;
|
||||
} else {
|
||||
return nb::cast<array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
||||
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
||||
}
|
||||
};
|
||||
|
||||
@ -66,11 +68,11 @@ std::pair<array, array> to_arrays(
|
||||
}
|
||||
}
|
||||
|
||||
array to_array_with_accessor(nb::object obj) {
|
||||
if (nb::isinstance<array>(obj)) {
|
||||
return nb::cast<array>(obj);
|
||||
mx::array to_array_with_accessor(nb::object obj) {
|
||||
if (nb::isinstance<mx::array>(obj)) {
|
||||
return nb::cast<mx::array>(obj);
|
||||
} else if (nb::hasattr(obj, "__mlx_array__")) {
|
||||
return nb::cast<array>(obj.attr("__mlx_array__")());
|
||||
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
|
||||
|
@ -12,17 +12,16 @@
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
using ScalarOrArray = std::variant<
|
||||
nb::bool_,
|
||||
nb::int_,
|
||||
nb::float_,
|
||||
// Must be above ndarray
|
||||
array,
|
||||
mx::array,
|
||||
// Must be above complex
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||
std::complex<float>,
|
||||
@ -45,7 +44,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
// Checks if the value can be compared to an array (or is already an
|
||||
// mlx array)
|
||||
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
||||
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
||||
return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
||||
} else {
|
||||
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
||||
// and can be compared to an array
|
||||
@ -66,12 +65,12 @@ inline void throw_invalid_operation(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
array to_array(
|
||||
mx::array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype = std::nullopt);
|
||||
std::optional<mx::Dtype> dtype = std::nullopt);
|
||||
|
||||
std::pair<array, array> to_arrays(
|
||||
std::pair<mx::array, mx::array> to_arrays(
|
||||
const ScalarOrArray& a,
|
||||
const ScalarOrArray& b);
|
||||
|
||||
array to_array_with_accessor(nb::object obj);
|
||||
mx::array to_array_with_accessor(nb::object obj);
|
||||
|
Loading…
Reference in New Issue
Block a user