mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
Remove "using namespace mlx::core" in python/src (#1689)
This commit is contained in:
parent
f3dfa36a3a
commit
0bf19037ca
File diff suppressed because it is too large
Load Diff
@ -14,37 +14,37 @@
|
|||||||
#define Py_bf_releasebuffer 2
|
#define Py_bf_releasebuffer 2
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
std::string buffer_format(const array& a) {
|
std::string buffer_format(const mx::array& a) {
|
||||||
// https://docs.python.org/3.10/library/struct.html#format-characters
|
// https://docs.python.org/3.10/library/struct.html#format-characters
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case mx::bool_:
|
||||||
return "?";
|
return "?";
|
||||||
case uint8:
|
case mx::uint8:
|
||||||
return "B";
|
return "B";
|
||||||
case uint16:
|
case mx::uint16:
|
||||||
return "H";
|
return "H";
|
||||||
case uint32:
|
case mx::uint32:
|
||||||
return "I";
|
return "I";
|
||||||
case uint64:
|
case mx::uint64:
|
||||||
return "Q";
|
return "Q";
|
||||||
case int8:
|
case mx::int8:
|
||||||
return "b";
|
return "b";
|
||||||
case int16:
|
case mx::int16:
|
||||||
return "h";
|
return "h";
|
||||||
case int32:
|
case mx::int32:
|
||||||
return "i";
|
return "i";
|
||||||
case int64:
|
case mx::int64:
|
||||||
return "q";
|
return "q";
|
||||||
case float16:
|
case mx::float16:
|
||||||
return "e";
|
return "e";
|
||||||
case float32:
|
case mx::float32:
|
||||||
return "f";
|
return "f";
|
||||||
case bfloat16:
|
case mx::bfloat16:
|
||||||
return "B";
|
return "B";
|
||||||
case complex64:
|
case mx::complex64:
|
||||||
return "Zf\0";
|
return "Zf\0";
|
||||||
default: {
|
default: {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
@ -84,7 +84,7 @@ struct buffer_info {
|
|||||||
|
|
||||||
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
|
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
|
||||||
std::memset(view, 0, sizeof(Py_buffer));
|
std::memset(view, 0, sizeof(Py_buffer));
|
||||||
auto a = nb::cast<array>(nb::handle(obj));
|
auto a = nb::cast<mx::array>(nb::handle(obj));
|
||||||
|
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
|
@ -16,7 +16,7 @@ enum PyScalarT {
|
|||||||
|
|
||||||
namespace nanobind {
|
namespace nanobind {
|
||||||
template <>
|
template <>
|
||||||
struct ndarray_traits<float16_t> {
|
struct ndarray_traits<mx::float16_t> {
|
||||||
static constexpr bool is_complex = false;
|
static constexpr bool is_complex = false;
|
||||||
static constexpr bool is_float = true;
|
static constexpr bool is_float = true;
|
||||||
static constexpr bool is_bool = false;
|
static constexpr bool is_bool = false;
|
||||||
@ -36,21 +36,21 @@ int check_shape_dim(int64_t dim) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array nd_array_to_mlx_contiguous(
|
mx::array nd_array_to_mlx_contiguous(
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||||
const Shape& shape,
|
const mx::Shape& shape,
|
||||||
Dtype dtype) {
|
mx::Dtype dtype) {
|
||||||
// Make a copy of the numpy buffer
|
// Make a copy of the numpy buffer
|
||||||
// Get buffer ptr pass to array constructor
|
// Get buffer ptr pass to array constructor
|
||||||
auto data_ptr = nd_array.data();
|
auto data_ptr = nd_array.data();
|
||||||
return array(static_cast<const T*>(data_ptr), shape, dtype);
|
return mx::array(static_cast<const T*>(data_ptr), shape, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
array nd_array_to_mlx(
|
mx::array nd_array_to_mlx(
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||||
std::optional<Dtype> dtype) {
|
std::optional<mx::Dtype> dtype) {
|
||||||
// Compute the shape and size
|
// Compute the shape and size
|
||||||
Shape shape;
|
mx::Shape shape;
|
||||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||||
}
|
}
|
||||||
@ -59,49 +59,49 @@ array nd_array_to_mlx(
|
|||||||
// Copy data and make array
|
// Copy data and make array
|
||||||
if (type == nb::dtype<bool>()) {
|
if (type == nb::dtype<bool>()) {
|
||||||
return nd_array_to_mlx_contiguous<bool>(
|
return nd_array_to_mlx_contiguous<bool>(
|
||||||
nd_array, shape, dtype.value_or(bool_));
|
nd_array, shape, dtype.value_or(mx::bool_));
|
||||||
} else if (type == nb::dtype<uint8_t>()) {
|
} else if (type == nb::dtype<uint8_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<uint8_t>(
|
return nd_array_to_mlx_contiguous<uint8_t>(
|
||||||
nd_array, shape, dtype.value_or(uint8));
|
nd_array, shape, dtype.value_or(mx::uint8));
|
||||||
} else if (type == nb::dtype<uint16_t>()) {
|
} else if (type == nb::dtype<uint16_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<uint16_t>(
|
return nd_array_to_mlx_contiguous<uint16_t>(
|
||||||
nd_array, shape, dtype.value_or(uint16));
|
nd_array, shape, dtype.value_or(mx::uint16));
|
||||||
} else if (type == nb::dtype<uint32_t>()) {
|
} else if (type == nb::dtype<uint32_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<uint32_t>(
|
return nd_array_to_mlx_contiguous<uint32_t>(
|
||||||
nd_array, shape, dtype.value_or(uint32));
|
nd_array, shape, dtype.value_or(mx::uint32));
|
||||||
} else if (type == nb::dtype<uint64_t>()) {
|
} else if (type == nb::dtype<uint64_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<uint64_t>(
|
return nd_array_to_mlx_contiguous<uint64_t>(
|
||||||
nd_array, shape, dtype.value_or(uint64));
|
nd_array, shape, dtype.value_or(mx::uint64));
|
||||||
} else if (type == nb::dtype<int8_t>()) {
|
} else if (type == nb::dtype<int8_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<int8_t>(
|
return nd_array_to_mlx_contiguous<int8_t>(
|
||||||
nd_array, shape, dtype.value_or(int8));
|
nd_array, shape, dtype.value_or(mx::int8));
|
||||||
} else if (type == nb::dtype<int16_t>()) {
|
} else if (type == nb::dtype<int16_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<int16_t>(
|
return nd_array_to_mlx_contiguous<int16_t>(
|
||||||
nd_array, shape, dtype.value_or(int16));
|
nd_array, shape, dtype.value_or(mx::int16));
|
||||||
} else if (type == nb::dtype<int32_t>()) {
|
} else if (type == nb::dtype<int32_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<int32_t>(
|
return nd_array_to_mlx_contiguous<int32_t>(
|
||||||
nd_array, shape, dtype.value_or(int32));
|
nd_array, shape, dtype.value_or(mx::int32));
|
||||||
} else if (type == nb::dtype<int64_t>()) {
|
} else if (type == nb::dtype<int64_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<int64_t>(
|
return nd_array_to_mlx_contiguous<int64_t>(
|
||||||
nd_array, shape, dtype.value_or(int64));
|
nd_array, shape, dtype.value_or(mx::int64));
|
||||||
} else if (type == nb::dtype<float16_t>()) {
|
} else if (type == nb::dtype<mx::float16_t>()) {
|
||||||
return nd_array_to_mlx_contiguous<float16_t>(
|
return nd_array_to_mlx_contiguous<mx::float16_t>(
|
||||||
nd_array, shape, dtype.value_or(float16));
|
nd_array, shape, dtype.value_or(mx::float16));
|
||||||
} else if (type == nb::bfloat16) {
|
} else if (type == nb::bfloat16) {
|
||||||
return nd_array_to_mlx_contiguous<bfloat16_t>(
|
return nd_array_to_mlx_contiguous<mx::bfloat16_t>(
|
||||||
nd_array, shape, dtype.value_or(bfloat16));
|
nd_array, shape, dtype.value_or(mx::bfloat16));
|
||||||
} else if (type == nb::dtype<float>()) {
|
} else if (type == nb::dtype<float>()) {
|
||||||
return nd_array_to_mlx_contiguous<float>(
|
return nd_array_to_mlx_contiguous<float>(
|
||||||
nd_array, shape, dtype.value_or(float32));
|
nd_array, shape, dtype.value_or(mx::float32));
|
||||||
} else if (type == nb::dtype<double>()) {
|
} else if (type == nb::dtype<double>()) {
|
||||||
return nd_array_to_mlx_contiguous<double>(
|
return nd_array_to_mlx_contiguous<double>(
|
||||||
nd_array, shape, dtype.value_or(float32));
|
nd_array, shape, dtype.value_or(mx::float32));
|
||||||
} else if (type == nb::dtype<std::complex<float>>()) {
|
} else if (type == nb::dtype<std::complex<float>>()) {
|
||||||
return nd_array_to_mlx_contiguous<complex64_t>(
|
return nd_array_to_mlx_contiguous<mx::complex64_t>(
|
||||||
nd_array, shape, dtype.value_or(complex64));
|
nd_array, shape, dtype.value_or(mx::complex64));
|
||||||
} else if (type == nb::dtype<std::complex<double>>()) {
|
} else if (type == nb::dtype<std::complex<double>>()) {
|
||||||
return nd_array_to_mlx_contiguous<complex128_t>(
|
return nd_array_to_mlx_contiguous<mx::complex128_t>(
|
||||||
nd_array, shape, dtype.value_or(complex64));
|
nd_array, shape, dtype.value_or(mx::complex64));
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument("Cannot convert numpy array to mlx array.");
|
throw std::invalid_argument("Cannot convert numpy array to mlx array.");
|
||||||
}
|
}
|
||||||
@ -109,7 +109,7 @@ array nd_array_to_mlx(
|
|||||||
|
|
||||||
template <typename T, typename... NDParams>
|
template <typename T, typename... NDParams>
|
||||||
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
||||||
array a,
|
mx::array a,
|
||||||
std::optional<nb::dlpack::dtype> t = {}) {
|
std::optional<nb::dlpack::dtype> t = {}) {
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
@ -126,48 +126,48 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename... NDParams>
|
template <typename... NDParams>
|
||||||
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case mx::bool_:
|
||||||
return mlx_to_nd_array_impl<bool, NDParams...>(a);
|
return mlx_to_nd_array_impl<bool, NDParams...>(a);
|
||||||
case uint8:
|
case mx::uint8:
|
||||||
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
|
||||||
case uint16:
|
case mx::uint16:
|
||||||
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
|
||||||
case uint32:
|
case mx::uint32:
|
||||||
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
|
||||||
case uint64:
|
case mx::uint64:
|
||||||
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
|
||||||
case int8:
|
case mx::int8:
|
||||||
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
|
||||||
case int16:
|
case mx::int16:
|
||||||
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
|
||||||
case int32:
|
case mx::int32:
|
||||||
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
|
||||||
case int64:
|
case mx::int64:
|
||||||
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
|
||||||
case float16:
|
case mx::float16:
|
||||||
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
return mlx_to_nd_array_impl<mx::float16_t, NDParams...>(a);
|
||||||
case bfloat16:
|
case mx::bfloat16:
|
||||||
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
||||||
case float32:
|
case mx::float32:
|
||||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||||
case complex64:
|
case mx::complex64:
|
||||||
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||||
default:
|
default:
|
||||||
throw nb::type_error("type cannot be converted to NumPy.");
|
throw nb::type_error("type cannot be converted to NumPy.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
|
nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a) {
|
||||||
return mlx_to_nd_array<nb::numpy>(a);
|
return mlx_to_nd_array<nb::numpy>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
nb::ndarray<> mlx_to_dlpack(const array& a) {
|
nb::ndarray<> mlx_to_dlpack(const mx::array& a) {
|
||||||
return mlx_to_nd_array<>(a);
|
return mlx_to_nd_array<>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
nb::object to_scalar(array& a) {
|
nb::object to_scalar(mx::array& a) {
|
||||||
if (a.size() != 1) {
|
if (a.size() != 1) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[convert] Only length-1 arrays can be converted to Python scalars.");
|
"[convert] Only length-1 arrays can be converted to Python scalars.");
|
||||||
@ -177,31 +177,31 @@ nb::object to_scalar(array& a) {
|
|||||||
a.eval();
|
a.eval();
|
||||||
}
|
}
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case mx::bool_:
|
||||||
return nb::cast(a.item<bool>());
|
return nb::cast(a.item<bool>());
|
||||||
case uint8:
|
case mx::uint8:
|
||||||
return nb::cast(a.item<uint8_t>());
|
return nb::cast(a.item<uint8_t>());
|
||||||
case uint16:
|
case mx::uint16:
|
||||||
return nb::cast(a.item<uint16_t>());
|
return nb::cast(a.item<uint16_t>());
|
||||||
case uint32:
|
case mx::uint32:
|
||||||
return nb::cast(a.item<uint32_t>());
|
return nb::cast(a.item<uint32_t>());
|
||||||
case uint64:
|
case mx::uint64:
|
||||||
return nb::cast(a.item<uint64_t>());
|
return nb::cast(a.item<uint64_t>());
|
||||||
case int8:
|
case mx::int8:
|
||||||
return nb::cast(a.item<int8_t>());
|
return nb::cast(a.item<int8_t>());
|
||||||
case int16:
|
case mx::int16:
|
||||||
return nb::cast(a.item<int16_t>());
|
return nb::cast(a.item<int16_t>());
|
||||||
case int32:
|
case mx::int32:
|
||||||
return nb::cast(a.item<int32_t>());
|
return nb::cast(a.item<int32_t>());
|
||||||
case int64:
|
case mx::int64:
|
||||||
return nb::cast(a.item<int64_t>());
|
return nb::cast(a.item<int64_t>());
|
||||||
case float16:
|
case mx::float16:
|
||||||
return nb::cast(static_cast<float>(a.item<float16_t>()));
|
return nb::cast(static_cast<float>(a.item<mx::float16_t>()));
|
||||||
case float32:
|
case mx::float32:
|
||||||
return nb::cast(a.item<float>());
|
return nb::cast(a.item<float>());
|
||||||
case bfloat16:
|
case mx::bfloat16:
|
||||||
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
return nb::cast(static_cast<float>(a.item<mx::bfloat16_t>()));
|
||||||
case complex64:
|
case mx::complex64:
|
||||||
return nb::cast(a.item<std::complex<float>>());
|
return nb::cast(a.item<std::complex<float>>());
|
||||||
default:
|
default:
|
||||||
throw nb::type_error("type cannot be converted to Python scalar.");
|
throw nb::type_error("type cannot be converted to Python scalar.");
|
||||||
@ -209,7 +209,7 @@ nb::object to_scalar(array& a) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U = T>
|
template <typename T, typename U = T>
|
||||||
nb::list to_list(array& a, size_t index, int dim) {
|
nb::list to_list(mx::array& a, size_t index, int dim) {
|
||||||
nb::list pl;
|
nb::list pl;
|
||||||
auto stride = a.strides()[dim];
|
auto stride = a.strides()[dim];
|
||||||
for (int i = 0; i < a.shape(dim); ++i) {
|
for (int i = 0; i < a.shape(dim); ++i) {
|
||||||
@ -223,7 +223,7 @@ nb::list to_list(array& a, size_t index, int dim) {
|
|||||||
return pl;
|
return pl;
|
||||||
}
|
}
|
||||||
|
|
||||||
nb::object tolist(array& a) {
|
nb::object tolist(mx::array& a) {
|
||||||
if (a.ndim() == 0) {
|
if (a.ndim() == 0) {
|
||||||
return to_scalar(a);
|
return to_scalar(a);
|
||||||
}
|
}
|
||||||
@ -232,31 +232,31 @@ nb::object tolist(array& a) {
|
|||||||
a.eval();
|
a.eval();
|
||||||
}
|
}
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case bool_:
|
case mx::bool_:
|
||||||
return to_list<bool>(a, 0, 0);
|
return to_list<bool>(a, 0, 0);
|
||||||
case uint8:
|
case mx::uint8:
|
||||||
return to_list<uint8_t>(a, 0, 0);
|
return to_list<uint8_t>(a, 0, 0);
|
||||||
case uint16:
|
case mx::uint16:
|
||||||
return to_list<uint16_t>(a, 0, 0);
|
return to_list<uint16_t>(a, 0, 0);
|
||||||
case uint32:
|
case mx::uint32:
|
||||||
return to_list<uint32_t>(a, 0, 0);
|
return to_list<uint32_t>(a, 0, 0);
|
||||||
case uint64:
|
case mx::uint64:
|
||||||
return to_list<uint64_t>(a, 0, 0);
|
return to_list<uint64_t>(a, 0, 0);
|
||||||
case int8:
|
case mx::int8:
|
||||||
return to_list<int8_t>(a, 0, 0);
|
return to_list<int8_t>(a, 0, 0);
|
||||||
case int16:
|
case mx::int16:
|
||||||
return to_list<int16_t>(a, 0, 0);
|
return to_list<int16_t>(a, 0, 0);
|
||||||
case int32:
|
case mx::int32:
|
||||||
return to_list<int32_t>(a, 0, 0);
|
return to_list<int32_t>(a, 0, 0);
|
||||||
case int64:
|
case mx::int64:
|
||||||
return to_list<int64_t>(a, 0, 0);
|
return to_list<int64_t>(a, 0, 0);
|
||||||
case float16:
|
case mx::float16:
|
||||||
return to_list<float16_t, float>(a, 0, 0);
|
return to_list<mx::float16_t, float>(a, 0, 0);
|
||||||
case float32:
|
case mx::float32:
|
||||||
return to_list<float>(a, 0, 0);
|
return to_list<float>(a, 0, 0);
|
||||||
case bfloat16:
|
case mx::bfloat16:
|
||||||
return to_list<bfloat16_t, float>(a, 0, 0);
|
return to_list<mx::bfloat16_t, float>(a, 0, 0);
|
||||||
case complex64:
|
case mx::complex64:
|
||||||
return to_list<std::complex<float>>(a, 0, 0);
|
return to_list<std::complex<float>>(a, 0, 0);
|
||||||
default:
|
default:
|
||||||
throw nb::type_error("data type cannot be converted to Python list.");
|
throw nb::type_error("data type cannot be converted to Python list.");
|
||||||
@ -279,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
PyScalarT validate_shape(
|
PyScalarT validate_shape(
|
||||||
T list,
|
T list,
|
||||||
const Shape& shape,
|
const mx::Shape& shape,
|
||||||
int idx,
|
int idx,
|
||||||
bool& all_python_primitive_elements) {
|
bool& all_python_primitive_elements) {
|
||||||
if (idx >= shape.size()) {
|
if (idx >= shape.size()) {
|
||||||
@ -307,9 +307,9 @@ PyScalarT validate_shape(
|
|||||||
shape,
|
shape,
|
||||||
idx + 1,
|
idx + 1,
|
||||||
all_python_primitive_elements);
|
all_python_primitive_elements);
|
||||||
} else if (nb::isinstance<array>(l)) {
|
} else if (nb::isinstance<mx::array>(l)) {
|
||||||
all_python_primitive_elements = false;
|
all_python_primitive_elements = false;
|
||||||
auto arr = nb::cast<array>(l);
|
auto arr = nb::cast<mx::array>(l);
|
||||||
if (arr.ndim() + idx + 1 == shape.size() &&
|
if (arr.ndim() + idx + 1 == shape.size() &&
|
||||||
std::equal(
|
std::equal(
|
||||||
arr.shape().cbegin(),
|
arr.shape().cbegin(),
|
||||||
@ -347,7 +347,7 @@ PyScalarT validate_shape(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void get_shape(T list, Shape& shape) {
|
void get_shape(T list, mx::Shape& shape) {
|
||||||
shape.push_back(check_shape_dim(nb::len(list)));
|
shape.push_back(check_shape_dim(nb::len(list)));
|
||||||
if (shape.back() > 0) {
|
if (shape.back() > 0) {
|
||||||
auto l = list.begin();
|
auto l = list.begin();
|
||||||
@ -355,8 +355,8 @@ void get_shape(T list, Shape& shape) {
|
|||||||
return get_shape(nb::cast<nb::list>(*l), shape);
|
return get_shape(nb::cast<nb::list>(*l), shape);
|
||||||
} else if (nb::isinstance<nb::tuple>(*l)) {
|
} else if (nb::isinstance<nb::tuple>(*l)) {
|
||||||
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
||||||
} else if (nb::isinstance<array>(*l)) {
|
} else if (nb::isinstance<mx::array>(*l)) {
|
||||||
auto arr = nb::cast<array>(*l);
|
auto arr = nb::cast<mx::array>(*l);
|
||||||
for (int i = 0; i < arr.ndim(); i++) {
|
for (int i = 0; i < arr.ndim(); i++) {
|
||||||
shape.push_back(arr.shape(i));
|
shape.push_back(arr.shape(i));
|
||||||
}
|
}
|
||||||
@ -366,54 +366,55 @@ void get_shape(T list, Shape& shape) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array array_from_list_impl(
|
mx::array array_from_list_impl(
|
||||||
T pl,
|
T pl,
|
||||||
const PyScalarT& inferred_type,
|
const PyScalarT& inferred_type,
|
||||||
std::optional<Dtype> specified_type,
|
std::optional<mx::Dtype> specified_type,
|
||||||
const Shape& shape) {
|
const mx::Shape& shape) {
|
||||||
// Make the array
|
// Make the array
|
||||||
switch (inferred_type) {
|
switch (inferred_type) {
|
||||||
case pybool: {
|
case pybool: {
|
||||||
std::vector<bool> vals;
|
std::vector<bool> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_));
|
||||||
}
|
}
|
||||||
case pyint: {
|
case pyint: {
|
||||||
auto dtype = specified_type.value_or(int32);
|
auto dtype = specified_type.value_or(mx::int32);
|
||||||
if (dtype == int64) {
|
if (dtype == mx::int64) {
|
||||||
std::vector<int64_t> vals;
|
std::vector<int64_t> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, dtype);
|
return mx::array(vals.begin(), shape, dtype);
|
||||||
} else if (dtype == uint64) {
|
} else if (dtype == mx::uint64) {
|
||||||
std::vector<uint64_t> vals;
|
std::vector<uint64_t> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, dtype);
|
return mx::array(vals.begin(), shape, dtype);
|
||||||
} else if (dtype == uint32) {
|
} else if (dtype == mx::uint32) {
|
||||||
std::vector<uint32_t> vals;
|
std::vector<uint32_t> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, dtype);
|
return mx::array(vals.begin(), shape, dtype);
|
||||||
} else if (issubdtype(dtype, inexact)) {
|
} else if (mx::issubdtype(dtype, mx::inexact)) {
|
||||||
std::vector<float> vals;
|
std::vector<float> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, dtype);
|
return mx::array(vals.begin(), shape, dtype);
|
||||||
} else {
|
} else {
|
||||||
std::vector<int> vals;
|
std::vector<int> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, dtype);
|
return mx::array(vals.begin(), shape, dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case pyfloat: {
|
case pyfloat: {
|
||||||
std::vector<float> vals;
|
std::vector<float> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(vals.begin(), shape, specified_type.value_or(float32));
|
return mx::array(
|
||||||
|
vals.begin(), shape, specified_type.value_or(mx::float32));
|
||||||
}
|
}
|
||||||
case pycomplex: {
|
case pycomplex: {
|
||||||
std::vector<std::complex<float>> vals;
|
std::vector<std::complex<float>> vals;
|
||||||
fill_vector(pl, vals);
|
fill_vector(pl, vals);
|
||||||
return array(
|
return mx::array(
|
||||||
reinterpret_cast<complex64_t*>(vals.data()),
|
reinterpret_cast<mx::complex64_t*>(vals.data()),
|
||||||
shape,
|
shape,
|
||||||
specified_type.value_or(complex64));
|
specified_type.value_or(mx::complex64));
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -425,9 +426,9 @@ array array_from_list_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
|
mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {
|
||||||
// Compute the shape
|
// Compute the shape
|
||||||
Shape shape;
|
mx::Shape shape;
|
||||||
get_shape(pl, shape);
|
get_shape(pl, shape);
|
||||||
|
|
||||||
// Validate the shape and type
|
// Validate the shape and type
|
||||||
@ -440,30 +441,31 @@ array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// `pl` contains mlx arrays
|
// `pl` contains mlx arrays
|
||||||
std::vector<array> arrays;
|
std::vector<mx::array> arrays;
|
||||||
for (auto l : pl) {
|
for (auto l : pl) {
|
||||||
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
|
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
|
||||||
}
|
}
|
||||||
return stack(arrays);
|
return mx::stack(arrays);
|
||||||
}
|
}
|
||||||
|
|
||||||
array array_from_list(nb::list pl, std::optional<Dtype> dtype) {
|
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype) {
|
||||||
return array_from_list_impl(pl, dtype);
|
return array_from_list_impl(pl, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype) {
|
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {
|
||||||
return array_from_list_impl(pl, dtype);
|
return array_from_list_impl(pl, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
array create_array(ArrayInitType v, std::optional<Dtype> t) {
|
mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
||||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||||
return array(nb::cast<bool>(*pv), t.value_or(bool_));
|
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
|
||||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||||
return array(nb::cast<int>(*pv), t.value_or(int32));
|
return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
|
||||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||||
return array(nb::cast<float>(*pv), t.value_or(float32));
|
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
|
||||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
return mx::array(
|
||||||
|
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
|
||||||
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
|
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
|
||||||
return array_from_list(*pv, t);
|
return array_from_list(*pv, t);
|
||||||
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
|
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
|
||||||
@ -472,10 +474,10 @@ array create_array(ArrayInitType v, std::optional<Dtype> t) {
|
|||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
||||||
pv) {
|
pv) {
|
||||||
return nd_array_to_mlx(*pv, t);
|
return nd_array_to_mlx(*pv, t);
|
||||||
} else if (auto pv = std::get_if<array>(&v); pv) {
|
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
|
||||||
return astype(*pv, t.value_or((*pv).dtype()));
|
return mx::astype(*pv, t.value_or((*pv).dtype()));
|
||||||
} else {
|
} else {
|
||||||
auto arr = to_array_with_accessor(std::get<nb::object>(v));
|
auto arr = to_array_with_accessor(std::get<nb::object>(v));
|
||||||
return astype(arr, t.value_or(arr.dtype()));
|
return mx::astype(arr, t.value_or(arr.dtype()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,15 +9,15 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
using ArrayInitType = std::variant<
|
using ArrayInitType = std::variant<
|
||||||
nb::bool_,
|
nb::bool_,
|
||||||
nb::int_,
|
nb::int_,
|
||||||
nb::float_,
|
nb::float_,
|
||||||
// Must be above ndarray
|
// Must be above ndarray
|
||||||
array,
|
mx::array,
|
||||||
// Must be above complex
|
// Must be above complex
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||||
std::complex<float>,
|
std::complex<float>,
|
||||||
@ -25,17 +25,17 @@ using ArrayInitType = std::variant<
|
|||||||
nb::tuple,
|
nb::tuple,
|
||||||
nb::object>;
|
nb::object>;
|
||||||
|
|
||||||
array nd_array_to_mlx(
|
mx::array nd_array_to_mlx(
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||||
std::optional<Dtype> dtype);
|
std::optional<mx::Dtype> dtype);
|
||||||
|
|
||||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
|
nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);
|
||||||
nb::ndarray<> mlx_to_dlpack(const array& a);
|
nb::ndarray<> mlx_to_dlpack(const mx::array& a);
|
||||||
|
|
||||||
nb::object to_scalar(array& a);
|
nb::object to_scalar(mx::array& a);
|
||||||
|
|
||||||
nb::object tolist(array& a);
|
nb::object tolist(mx::array& a);
|
||||||
|
|
||||||
array create_array(ArrayInitType v, std::optional<Dtype> t);
|
mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
|
||||||
array array_from_list(nb::list pl, std::optional<Dtype> dtype);
|
mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
|
||||||
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype);
|
mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);
|
||||||
|
@ -8,51 +8,54 @@
|
|||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
void init_device(nb::module_& m) {
|
void init_device(nb::module_& m) {
|
||||||
auto device_class = nb::class_<Device>(
|
auto device_class = nb::class_<mx::Device>(
|
||||||
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
|
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
|
||||||
nb::enum_<Device::DeviceType>(m, "DeviceType")
|
nb::enum_<mx::Device::DeviceType>(m, "DeviceType")
|
||||||
.value("cpu", Device::DeviceType::cpu)
|
.value("cpu", mx::Device::DeviceType::cpu)
|
||||||
.value("gpu", Device::DeviceType::gpu)
|
.value("gpu", mx::Device::DeviceType::gpu)
|
||||||
.export_values()
|
.export_values()
|
||||||
.def("__eq__", [](const Device::DeviceType& d, const nb::object& other) {
|
.def(
|
||||||
if (!nb::isinstance<Device>(other) &&
|
"__eq__",
|
||||||
!nb::isinstance<Device::DeviceType>(other)) {
|
[](const mx::Device::DeviceType& d, const nb::object& other) {
|
||||||
|
if (!nb::isinstance<mx::Device>(other) &&
|
||||||
|
!nb::isinstance<mx::Device::DeviceType>(other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return d == nb::cast<Device>(other);
|
return d == nb::cast<mx::Device>(other);
|
||||||
});
|
});
|
||||||
|
|
||||||
device_class.def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
device_class
|
||||||
.def_ro("type", &Device::type)
|
.def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||||
|
.def_ro("type", &mx::Device::type)
|
||||||
.def(
|
.def(
|
||||||
"__repr__",
|
"__repr__",
|
||||||
[](const Device& d) {
|
[](const mx::Device& d) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << d;
|
os << d;
|
||||||
return os.str();
|
return os.str();
|
||||||
})
|
})
|
||||||
.def("__eq__", [](const Device& d, const nb::object& other) {
|
.def("__eq__", [](const mx::Device& d, const nb::object& other) {
|
||||||
if (!nb::isinstance<Device>(other) &&
|
if (!nb::isinstance<mx::Device>(other) &&
|
||||||
!nb::isinstance<Device::DeviceType>(other)) {
|
!nb::isinstance<mx::Device::DeviceType>(other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return d == nb::cast<Device>(other);
|
return d == nb::cast<mx::Device>(other);
|
||||||
});
|
});
|
||||||
|
|
||||||
nb::implicitly_convertible<Device::DeviceType, Device>();
|
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"default_device",
|
"default_device",
|
||||||
&default_device,
|
&mx::default_device,
|
||||||
R"pbdoc(Get the default device.)pbdoc");
|
R"pbdoc(Get the default device.)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"set_default_device",
|
"set_default_device",
|
||||||
&set_default_device,
|
&mx::set_default_device,
|
||||||
"device"_a,
|
"device"_a,
|
||||||
R"pbdoc(Set the default device.)pbdoc");
|
R"pbdoc(Set the default device.)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -9,26 +9,27 @@
|
|||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
void init_distributed(nb::module_& parent_module) {
|
void init_distributed(nb::module_& parent_module) {
|
||||||
auto m = parent_module.def_submodule(
|
auto m = parent_module.def_submodule(
|
||||||
"distributed", "mlx.core.distributed: Communication operations");
|
"distributed", "mlx.core.distributed: Communication operations");
|
||||||
|
|
||||||
nb::class_<distributed::Group>(
|
nb::class_<mx::distributed::Group>(
|
||||||
m,
|
m,
|
||||||
"Group",
|
"Group",
|
||||||
R"pbcopy(
|
R"pbcopy(
|
||||||
An :class:`mlx.core.distributed.Group` represents a group of independent mlx
|
An :class:`mlx.core.distributed.Group` represents a group of independent mlx
|
||||||
processes that can communicate.
|
processes that can communicate.
|
||||||
)pbcopy")
|
)pbcopy")
|
||||||
.def("rank", &distributed::Group::rank, "Get the rank of this process")
|
.def(
|
||||||
.def("size", &distributed::Group::size, "Get the size of the group")
|
"rank", &mx::distributed::Group::rank, "Get the rank of this process")
|
||||||
|
.def("size", &mx::distributed::Group::size, "Get the size of the group")
|
||||||
.def(
|
.def(
|
||||||
"split",
|
"split",
|
||||||
&distributed::Group::split,
|
&mx::distributed::Group::split,
|
||||||
"color"_a,
|
"color"_a,
|
||||||
"key"_a = -1,
|
"key"_a = -1,
|
||||||
nb::sig("def split(self, color: int, key: int = -1) -> Group"),
|
nb::sig("def split(self, color: int, key: int = -1) -> Group"),
|
||||||
@ -48,14 +49,14 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"is_available",
|
"is_available",
|
||||||
&distributed::is_available,
|
&mx::distributed::is_available,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Check if a communication backend is available.
|
Check if a communication backend is available.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"init",
|
"init",
|
||||||
&distributed::init,
|
&mx::distributed::init,
|
||||||
"strict"_a = false,
|
"strict"_a = false,
|
||||||
nb::sig("def init(strict: bool = False) -> Group"),
|
nb::sig("def init(strict: bool = False) -> Group"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
@ -72,7 +73,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"all_sum",
|
"all_sum",
|
||||||
&distributed::all_sum,
|
&mx::distributed::all_sum,
|
||||||
"x"_a,
|
"x"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"group"_a = nb::none(),
|
"group"_a = nb::none(),
|
||||||
@ -98,7 +99,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"all_gather",
|
"all_gather",
|
||||||
&distributed::all_gather,
|
&mx::distributed::all_gather,
|
||||||
"x"_a,
|
"x"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"group"_a = nb::none(),
|
"group"_a = nb::none(),
|
||||||
@ -125,7 +126,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"send",
|
"send",
|
||||||
&distributed::send,
|
&mx::distributed::send,
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"dst"_a,
|
"dst"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -152,7 +153,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"recv",
|
"recv",
|
||||||
&distributed::recv,
|
&mx::distributed::recv,
|
||||||
"shape"_a,
|
"shape"_a,
|
||||||
"dtype"_a,
|
"dtype"_a,
|
||||||
"src"_a,
|
"src"_a,
|
||||||
@ -181,7 +182,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"recv_like",
|
"recv_like",
|
||||||
&distributed::recv_like,
|
&mx::distributed::recv_like,
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"src"_a,
|
"src"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
|
@ -13,9 +13,9 @@
|
|||||||
#include "mlx/fast.h"
|
#include "mlx/fast.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
void init_fast(nb::module_& parent_module) {
|
void init_fast(nb::module_& parent_module) {
|
||||||
auto m =
|
auto m =
|
||||||
@ -23,7 +23,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"rms_norm",
|
"rms_norm",
|
||||||
&fast::rms_norm,
|
&mx::fast::rms_norm,
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"weight"_a,
|
"weight"_a,
|
||||||
"eps"_a,
|
"eps"_a,
|
||||||
@ -49,7 +49,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"layer_norm",
|
"layer_norm",
|
||||||
&fast::layer_norm,
|
&mx::fast::layer_norm,
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"weight"_a.none(),
|
"weight"_a.none(),
|
||||||
"bias"_a.none(),
|
"bias"_a.none(),
|
||||||
@ -79,7 +79,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"rope",
|
"rope",
|
||||||
&fast::rope,
|
&mx::fast::rope,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"dims"_a,
|
"dims"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -114,7 +114,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"scaled_dot_product_attention",
|
"scaled_dot_product_attention",
|
||||||
&fast::scaled_dot_product_attention,
|
&mx::fast::scaled_dot_product_attention,
|
||||||
"q"_a,
|
"q"_a,
|
||||||
"k"_a,
|
"k"_a,
|
||||||
"v"_a,
|
"v"_a,
|
||||||
@ -170,7 +170,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
const std::string& header,
|
const std::string& header,
|
||||||
bool ensure_row_contiguous,
|
bool ensure_row_contiguous,
|
||||||
bool atomic_outputs) {
|
bool atomic_outputs) {
|
||||||
auto kernel = fast::metal_kernel(
|
auto kernel = mx::fast::metal_kernel(
|
||||||
name,
|
name,
|
||||||
input_names,
|
input_names,
|
||||||
output_names,
|
output_names,
|
||||||
@ -182,7 +182,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
[kernel = std::move(kernel)](
|
[kernel = std::move(kernel)](
|
||||||
const std::vector<ScalarOrArray>& inputs_,
|
const std::vector<ScalarOrArray>& inputs_,
|
||||||
const std::vector<std::vector<int>>& output_shapes,
|
const std::vector<std::vector<int>>& output_shapes,
|
||||||
const std::vector<Dtype>& output_dtypes,
|
const std::vector<mx::Dtype>& output_dtypes,
|
||||||
std::tuple<int, int, int> grid,
|
std::tuple<int, int, int> grid,
|
||||||
std::tuple<int, int, int> threadgroup,
|
std::tuple<int, int, int> threadgroup,
|
||||||
const std::optional<
|
const std::optional<
|
||||||
@ -190,12 +190,12 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
template_args_ = std::nullopt,
|
template_args_ = std::nullopt,
|
||||||
std::optional<float> init_value = std::nullopt,
|
std::optional<float> init_value = std::nullopt,
|
||||||
bool verbose = false,
|
bool verbose = false,
|
||||||
StreamOrDevice s = {}) {
|
mx::StreamOrDevice s = {}) {
|
||||||
std::vector<array> inputs;
|
std::vector<mx::array> inputs;
|
||||||
for (const auto& value : inputs_) {
|
for (const auto& value : inputs_) {
|
||||||
inputs.push_back(to_array(value, std::nullopt));
|
inputs.push_back(to_array(value, std::nullopt));
|
||||||
}
|
}
|
||||||
std::vector<std::pair<std::string, fast::TemplateArg>>
|
std::vector<std::pair<std::string, mx::fast::TemplateArg>>
|
||||||
template_args;
|
template_args;
|
||||||
if (template_args_) {
|
if (template_args_) {
|
||||||
for (const auto& [name, value] : template_args_.value()) {
|
for (const auto& [name, value] : template_args_.value()) {
|
||||||
@ -206,8 +206,8 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
} else if (nb::isinstance<int>(value)) {
|
} else if (nb::isinstance<int>(value)) {
|
||||||
int int_val = nb::cast<int>(value);
|
int int_val = nb::cast<int>(value);
|
||||||
template_args.emplace_back(name, int_val);
|
template_args.emplace_back(name, int_val);
|
||||||
} else if (nb::isinstance<Dtype>(value)) {
|
} else if (nb::isinstance<mx::Dtype>(value)) {
|
||||||
Dtype dtype = nb::cast<Dtype>(value);
|
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
|
||||||
template_args.emplace_back(name, dtype);
|
template_args.emplace_back(name, dtype);
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -9,24 +9,23 @@
|
|||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
void init_fft(nb::module_& parent_module) {
|
void init_fft(nb::module_& parent_module) {
|
||||||
auto m = parent_module.def_submodule(
|
auto m = parent_module.def_submodule(
|
||||||
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
||||||
m.def(
|
m.def(
|
||||||
"fft",
|
"fft",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<int>& n,
|
const std::optional<int>& n,
|
||||||
int axis,
|
int axis,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (n.has_value()) {
|
if (n.has_value()) {
|
||||||
return fft::fft(a, n.value(), axis, s);
|
return mx::fft::fft(a, n.value(), axis, s);
|
||||||
} else {
|
} else {
|
||||||
return fft::fft(a, axis, s);
|
return mx::fft::fft(a, axis, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -49,14 +48,14 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"ifft",
|
"ifft",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<int>& n,
|
const std::optional<int>& n,
|
||||||
int axis,
|
int axis,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (n.has_value()) {
|
if (n.has_value()) {
|
||||||
return fft::ifft(a, n.value(), axis, s);
|
return mx::fft::ifft(a, n.value(), axis, s);
|
||||||
} else {
|
} else {
|
||||||
return fft::ifft(a, axis, s);
|
return mx::fft::ifft(a, axis, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -79,19 +78,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"fft2",
|
"fft2",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::fftn(a, n.value(), axes.value(), s);
|
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::fftn(a, axes.value(), s);
|
return mx::fft::fftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::fftn(a, s);
|
return mx::fft::fftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -115,19 +114,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"ifft2",
|
"ifft2",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::ifftn(a, axes.value(), s);
|
return mx::fft::ifftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::ifftn(a, s);
|
return mx::fft::ifftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -151,19 +150,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"fftn",
|
"fftn",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::fftn(a, n.value(), axes.value(), s);
|
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::fftn(a, axes.value(), s);
|
return mx::fft::fftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::fftn(a, s);
|
return mx::fft::fftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -188,19 +187,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"ifftn",
|
"ifftn",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::ifftn(a, n.value(), axes.value(), s);
|
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::ifftn(a, axes.value(), s);
|
return mx::fft::ifftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::ifftn(a, s);
|
return mx::fft::ifftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -225,14 +224,14 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"rfft",
|
"rfft",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<int>& n,
|
const std::optional<int>& n,
|
||||||
int axis,
|
int axis,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (n.has_value()) {
|
if (n.has_value()) {
|
||||||
return fft::rfft(a, n.value(), axis, s);
|
return mx::fft::rfft(a, n.value(), axis, s);
|
||||||
} else {
|
} else {
|
||||||
return fft::rfft(a, axis, s);
|
return mx::fft::rfft(a, axis, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -260,14 +259,14 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"irfft",
|
"irfft",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<int>& n,
|
const std::optional<int>& n,
|
||||||
int axis,
|
int axis,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (n.has_value()) {
|
if (n.has_value()) {
|
||||||
return fft::irfft(a, n.value(), axis, s);
|
return mx::fft::irfft(a, n.value(), axis, s);
|
||||||
} else {
|
} else {
|
||||||
return fft::irfft(a, axis, s);
|
return mx::fft::irfft(a, axis, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -294,19 +293,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"rfft2",
|
"rfft2",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::rfftn(a, axes.value(), s);
|
return mx::fft::rfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::rfftn(a, s);
|
return mx::fft::rfftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -336,19 +335,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"irfft2",
|
"irfft2",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::irfftn(a, axes.value(), s);
|
return mx::fft::irfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::irfftn(a, s);
|
return mx::fft::irfftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -378,19 +377,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"rfftn",
|
"rfftn",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::rfftn(a, n.value(), axes.value(), s);
|
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::rfftn(a, axes.value(), s);
|
return mx::fft::rfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::rfftn(a, s);
|
return mx::fft::rfftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
@ -420,19 +419,19 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"irfftn",
|
"irfftn",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::optional<std::vector<int>>& n,
|
const std::optional<std::vector<int>>& n,
|
||||||
const std::optional<std::vector<int>>& axes,
|
const std::optional<std::vector<int>>& axes,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (axes.has_value() && n.has_value()) {
|
if (axes.has_value() && n.has_value()) {
|
||||||
return fft::irfftn(a, n.value(), axes.value(), s);
|
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
||||||
} else if (axes.has_value()) {
|
} else if (axes.has_value()) {
|
||||||
return fft::irfftn(a, axes.value(), s);
|
return mx::fft::irfftn(a, axes.value(), s);
|
||||||
} else if (n.has_value()) {
|
} else if (n.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
||||||
} else {
|
} else {
|
||||||
return fft::irfftn(a, s);
|
return mx::fft::irfftn(a, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
|
@ -43,20 +43,20 @@ void get_slice_params(
|
|||||||
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
array get_int_index(nb::object idx, int axis_size) {
|
mx::array get_int_index(nb::object idx, int axis_size) {
|
||||||
int idx_ = nb::cast<int>(idx);
|
int idx_ = nb::cast<int>(idx);
|
||||||
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
|
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
|
||||||
|
|
||||||
return array(idx_, uint32);
|
return mx::array(idx_, mx::uint32);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_valid_index_type(const nb::object& obj) {
|
bool is_valid_index_type(const nb::object& obj) {
|
||||||
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
||||||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) ||
|
nb::isinstance<mx::array>(obj) || obj.is_none() ||
|
||||||
nb::isinstance<nb::list>(obj);
|
nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
|
||||||
// Check input and raise error if 0 dim for parity with np
|
// Check input and raise error if 0 dim for parity with np
|
||||||
if (src.ndim() == 0) {
|
if (src.ndim() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -77,14 +77,14 @@ array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
|||||||
return slice(src, starts, ends, strides);
|
return slice(src, starts, ends, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_get_item_array(const array& src, const array& indices) {
|
mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
|
||||||
// Check input and raise error if 0 dim for parity with np
|
// Check input and raise error if 0 dim for parity with np
|
||||||
if (src.ndim() == 0) {
|
if (src.ndim() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"too many indices for array: array is 0-dimensional");
|
"too many indices for array: array is 0-dimensional");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (indices.dtype() == bool_) {
|
if (indices.dtype() == mx::bool_) {
|
||||||
throw std::invalid_argument("boolean indices are not yet supported");
|
throw std::invalid_argument("boolean indices are not yet supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
|
|||||||
return take(src, indices, 0);
|
return take(src, indices, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_get_item_int(const array& src, const nb::int_& idx) {
|
mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) {
|
||||||
// Check input and raise error if 0 dim for parity with np
|
// Check input and raise error if 0 dim for parity with np
|
||||||
if (src.ndim() == 0) {
|
if (src.ndim() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -105,13 +105,13 @@ array mlx_get_item_int(const array& src, const nb::int_& idx) {
|
|||||||
return take(src, get_int_index(idx, src.shape(0)), 0);
|
return take(src, get_int_index(idx, src.shape(0)), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_gather_nd(
|
mx::array mlx_gather_nd(
|
||||||
array src,
|
mx::array src,
|
||||||
const std::vector<nb::object>& indices,
|
const std::vector<nb::object>& indices,
|
||||||
bool gather_first,
|
bool gather_first,
|
||||||
int& max_dims) {
|
int& max_dims) {
|
||||||
max_dims = 0;
|
max_dims = 0;
|
||||||
std::vector<array> gather_indices;
|
std::vector<mx::array> gather_indices;
|
||||||
std::vector<bool> is_slice(indices.size(), false);
|
std::vector<bool> is_slice(indices.size(), false);
|
||||||
int num_slices = 0;
|
int num_slices = 0;
|
||||||
// gather all the arrays
|
// gather all the arrays
|
||||||
@ -127,13 +127,13 @@ array mlx_gather_nd(
|
|||||||
start = (start < 0) ? start + src.shape(i) : start;
|
start = (start < 0) ? start + src.shape(i) : start;
|
||||||
end = (end < 0) ? end + src.shape(i) : end;
|
end = (end < 0) ? end + src.shape(i) : end;
|
||||||
|
|
||||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
gather_indices.push_back(arange(start, end, stride, mx::uint32));
|
||||||
num_slices++;
|
num_slices++;
|
||||||
is_slice[i] = true;
|
is_slice[i] = true;
|
||||||
} else if (nb::isinstance<nb::int_>(idx)) {
|
} else if (nb::isinstance<nb::int_>(idx)) {
|
||||||
gather_indices.push_back(get_int_index(idx, src.shape(i)));
|
gather_indices.push_back(get_int_index(idx, src.shape(i)));
|
||||||
} else if (nb::isinstance<array>(idx)) {
|
} else if (nb::isinstance<mx::array>(idx)) {
|
||||||
auto arr = nb::cast<array>(idx);
|
auto arr = nb::cast<mx::array>(idx);
|
||||||
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
|
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
|
||||||
gather_indices.push_back(arr);
|
gather_indices.push_back(arr);
|
||||||
}
|
}
|
||||||
@ -144,7 +144,7 @@ array mlx_gather_nd(
|
|||||||
int slice_index = 0;
|
int slice_index = 0;
|
||||||
for (int i = 0; i < gather_indices.size(); i++) {
|
for (int i = 0; i < gather_indices.size(); i++) {
|
||||||
if (is_slice[i]) {
|
if (is_slice[i]) {
|
||||||
Shape index_shape(max_dims + num_slices, 1);
|
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||||
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||||
slice_index++;
|
slice_index++;
|
||||||
@ -158,7 +158,7 @@ array mlx_gather_nd(
|
|||||||
// reshape them so that the int/array indices are last
|
// reshape them so that the int/array indices are last
|
||||||
for (int i = 0; i < gather_indices.size(); i++) {
|
for (int i = 0; i < gather_indices.size(); i++) {
|
||||||
if (i < num_slices) {
|
if (i < num_slices) {
|
||||||
Shape index_shape(max_dims + num_slices, 1);
|
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||||
index_shape[i] = gather_indices[i].shape(0);
|
index_shape[i] = gather_indices[i].shape(0);
|
||||||
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||||
}
|
}
|
||||||
@ -241,7 +241,7 @@ auto mlx_expand_ellipsis(
|
|||||||
return std::make_pair(non_none_indices, indices);
|
return std::make_pair(non_none_indices, indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
||||||
// No indices make this a noop
|
// No indices make this a noop
|
||||||
if (entries.size() == 0) {
|
if (entries.size() == 0) {
|
||||||
return src;
|
return src;
|
||||||
@ -281,7 +281,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
|||||||
bool have_non_array = false;
|
bool have_non_array = false;
|
||||||
bool gather_first = false;
|
bool gather_first = false;
|
||||||
for (auto& idx : indices) {
|
for (auto& idx : indices) {
|
||||||
if (nb::isinstance<array>(idx) || (nb::isinstance<nb::int_>(idx))) {
|
if (nb::isinstance<mx::array>(idx) || (nb::isinstance<nb::int_>(idx))) {
|
||||||
if (have_array && have_non_array) {
|
if (have_array && have_non_array) {
|
||||||
gather_first = true;
|
gather_first = true;
|
||||||
break;
|
break;
|
||||||
@ -294,7 +294,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
|||||||
|
|
||||||
int n_arr = 0;
|
int n_arr = 0;
|
||||||
for (auto& idx : indices) {
|
for (auto& idx : indices) {
|
||||||
n_arr += nb::isinstance<array>(idx);
|
n_arr += nb::isinstance<mx::array>(idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
have_array &= n_arr > 0;
|
have_array &= n_arr > 0;
|
||||||
@ -304,7 +304,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
|||||||
// Then find the last array
|
// Then find the last array
|
||||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
||||||
auto& idx = indices[last_array];
|
auto& idx = indices[last_array];
|
||||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -340,7 +340,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
|||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < indices.size(); i++) {
|
for (int i = 0; i < indices.size(); i++) {
|
||||||
auto& idx = indices[i];
|
auto& idx = indices[i];
|
||||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||||
break;
|
break;
|
||||||
} else if (idx.is_none()) {
|
} else if (idx.is_none()) {
|
||||||
remaining_indices.push_back(idx);
|
remaining_indices.push_back(idx);
|
||||||
@ -426,11 +426,11 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
|||||||
return src;
|
return src;
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_get_item(const array& src, const nb::object& obj) {
|
mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
|
||||||
if (nb::isinstance<nb::slice>(obj)) {
|
if (nb::isinstance<nb::slice>(obj)) {
|
||||||
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
|
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
|
||||||
} else if (nb::isinstance<array>(obj)) {
|
} else if (nb::isinstance<mx::array>(obj)) {
|
||||||
return mlx_get_item_array(src, nb::cast<array>(obj));
|
return mlx_get_item_array(src, nb::cast<mx::array>(obj));
|
||||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||||
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
|
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
|
||||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||||
@ -448,10 +448,11 @@ array mlx_get_item(const array& src, const nb::object& obj) {
|
|||||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||||
const array& src,
|
mlx_scatter_args_int(
|
||||||
|
const mx::array& src,
|
||||||
const nb::int_& idx,
|
const nb::int_& idx,
|
||||||
const array& update) {
|
const mx::array& update) {
|
||||||
if (src.ndim() == 0) {
|
if (src.ndim() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"too many indices for array: array is 0-dimensional");
|
"too many indices for array: array is 0-dimensional");
|
||||||
@ -473,10 +474,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
|||||||
{0}};
|
{0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||||
const array& src,
|
mlx_scatter_args_array(
|
||||||
const array& indices,
|
const mx::array& src,
|
||||||
const array& update) {
|
const mx::array& indices,
|
||||||
|
const mx::array& update) {
|
||||||
if (src.ndim() == 0) {
|
if (src.ndim() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"too many indices for array: array is 0-dimensional");
|
"too many indices for array: array is 0-dimensional");
|
||||||
@ -500,10 +502,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
|||||||
return {{indices}, up, {0}};
|
return {{indices}, up, {0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||||
const array& src,
|
mlx_scatter_args_slice(
|
||||||
|
const mx::array& src,
|
||||||
const nb::slice& in_slice,
|
const nb::slice& in_slice,
|
||||||
const array& update) {
|
const mx::array& update) {
|
||||||
// Check input and raise error if 0 dim for parity with np
|
// Check input and raise error if 0 dim for parity with np
|
||||||
if (src.ndim() == 0) {
|
if (src.ndim() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -539,7 +542,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
|||||||
auto up = reshape(update, up_shape);
|
auto up = reshape(update, up_shape);
|
||||||
|
|
||||||
// Build array to mark start of slice
|
// Build array to mark start of slice
|
||||||
auto idx = array({start}, {1}, uint32);
|
auto idx = mx::array({start}, {1}, mx::uint32);
|
||||||
|
|
||||||
// Get slice size
|
// Get slice size
|
||||||
int slice_size = (end - start);
|
int slice_size = (end - start);
|
||||||
@ -551,20 +554,21 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
|||||||
|
|
||||||
up = broadcast_to(up, up_shape_broadcast);
|
up = broadcast_to(up, up_shape_broadcast);
|
||||||
|
|
||||||
auto indices = std::vector<array>{idx};
|
auto indices = std::vector<mx::array>{idx};
|
||||||
auto axes = std::vector<int>{0};
|
auto axes = std::vector<int>{0};
|
||||||
|
|
||||||
return {indices, up, axes};
|
return {indices, up, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
return mlx_scatter_args_array(
|
return mlx_scatter_args_array(
|
||||||
src, arange(start, end, stride, uint32), update);
|
src, arange(start, end, stride, mx::uint32), update);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||||
const array& src,
|
mlx_scatter_args_nd(
|
||||||
|
const mx::array& src,
|
||||||
const nb::tuple& entries,
|
const nb::tuple& entries,
|
||||||
const array& update) {
|
const mx::array& update) {
|
||||||
// Expand ellipses into a series of ':' slices
|
// Expand ellipses into a series of ':' slices
|
||||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||||
|
|
||||||
@ -623,12 +627,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
num_simple_slices_post++;
|
num_simple_slices_post++;
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (nb::isinstance<array>(idx)) {
|
} else if (nb::isinstance<mx::array>(idx)) {
|
||||||
have_array = true;
|
have_array = true;
|
||||||
if (have_array && have_non_array) {
|
if (have_array && have_non_array) {
|
||||||
arrays_first = true;
|
arrays_first = true;
|
||||||
}
|
}
|
||||||
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
|
max_dim = std::max(nb::cast<mx::array>(idx).ndim(), max_dim);
|
||||||
num_arrays++;
|
num_arrays++;
|
||||||
num_simple_slices_post = 0;
|
num_simple_slices_post = 0;
|
||||||
}
|
}
|
||||||
@ -643,7 +647,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
|
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
|
||||||
|
|
||||||
// Go over each index type and translate to the needed scatter args
|
// Go over each index type and translate to the needed scatter args
|
||||||
std::vector<array> arr_indices;
|
std::vector<mx::array> arr_indices;
|
||||||
int slice_num = 0;
|
int slice_num = 0;
|
||||||
int array_num = 0;
|
int array_num = 0;
|
||||||
int ax = 0;
|
int ax = 0;
|
||||||
@ -668,7 +672,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
|
|
||||||
// If it's a simple slice, we only need to add the start index
|
// If it's a simple slice, we only need to add the start index
|
||||||
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
|
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
|
||||||
auto idx = array({start}, idx_shape, uint32);
|
auto idx = mx::array({start}, idx_shape, mx::uint32);
|
||||||
slice_shapes.push_back(end - start);
|
slice_shapes.push_back(end - start);
|
||||||
arr_indices.push_back(idx);
|
arr_indices.push_back(idx);
|
||||||
|
|
||||||
@ -677,7 +681,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
}
|
}
|
||||||
// Otherwise we expand the slice into indices using arange
|
// Otherwise we expand the slice into indices using arange
|
||||||
else {
|
else {
|
||||||
auto idx = arange(start, end, stride, uint32);
|
auto idx = arange(start, end, stride, mx::uint32);
|
||||||
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
auto loc = slice_num + (arrays_first ? max_dim : 0);
|
||||||
idx_shape[loc] = idx.size();
|
idx_shape[loc] = idx.size();
|
||||||
arr_indices.push_back(reshape(idx, idx_shape));
|
arr_indices.push_back(reshape(idx, idx_shape));
|
||||||
@ -696,9 +700,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
} else if (pyidx.is_none()) {
|
} else if (pyidx.is_none()) {
|
||||||
// We only use the None's for bookeeping dimensions
|
// We only use the None's for bookeeping dimensions
|
||||||
slice_num++;
|
slice_num++;
|
||||||
} else if (nb::isinstance<array>(pyidx)) {
|
} else if (nb::isinstance<mx::array>(pyidx)) {
|
||||||
ax++;
|
ax++;
|
||||||
auto idx = nb::cast<array>(pyidx);
|
auto idx = nb::cast<mx::array>(pyidx);
|
||||||
std::vector<int> idx_shape(idx_ndim, 1);
|
std::vector<int> idx_shape(idx_ndim, 1);
|
||||||
|
|
||||||
// Place the arrays in the correct dimension
|
// Place the arrays in the correct dimension
|
||||||
@ -748,16 +752,16 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
return {arr_indices, up, axes};
|
return {arr_indices, up, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::vector<array>, array, std::vector<int>>
|
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||||
mlx_compute_scatter_args(
|
mlx_compute_scatter_args(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
auto vals = to_array(v, src.dtype());
|
auto vals = to_array(v, src.dtype());
|
||||||
if (nb::isinstance<nb::slice>(obj)) {
|
if (nb::isinstance<nb::slice>(obj)) {
|
||||||
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
|
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
|
||||||
} else if (nb::isinstance<array>(obj)) {
|
} else if (nb::isinstance<mx::array>(obj)) {
|
||||||
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
|
return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);
|
||||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||||
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
|
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
|
||||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||||
@ -773,7 +777,7 @@ mlx_compute_scatter_args(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto mlx_slice_update(
|
auto mlx_slice_update(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
// Can't route to slice update if not slice or tuple
|
// Can't route to slice update if not slice or tuple
|
||||||
@ -784,7 +788,7 @@ auto mlx_slice_update(
|
|||||||
if (nb::isinstance<nb::tuple>(obj)) {
|
if (nb::isinstance<nb::tuple>(obj)) {
|
||||||
// Can't route to slice update if any arrays are present
|
// Can't route to slice update if any arrays are present
|
||||||
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
||||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) {
|
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::list>(idx)) {
|
||||||
return std::make_pair(false, src);
|
return std::make_pair(false, src);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -881,7 +885,10 @@ auto mlx_slice_update(
|
|||||||
return std::make_pair(true, out);
|
return std::make_pair(true, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
|
void mlx_set_item(
|
||||||
|
mx::array& src,
|
||||||
|
const nb::object& obj,
|
||||||
|
const ScalarOrArray& v) {
|
||||||
auto [success, out] = mlx_slice_update(src, obj, v);
|
auto [success, out] = mlx_slice_update(src, obj, v);
|
||||||
if (success) {
|
if (success) {
|
||||||
src.overwrite_descriptor(out);
|
src.overwrite_descriptor(out);
|
||||||
@ -897,8 +904,8 @@ void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_add_item(
|
mx::array mlx_add_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
@ -909,8 +916,8 @@ array mlx_add_item(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_subtract_item(
|
mx::array mlx_subtract_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
@ -921,8 +928,8 @@ array mlx_subtract_item(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_multiply_item(
|
mx::array mlx_multiply_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
@ -933,8 +940,8 @@ array mlx_multiply_item(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_divide_item(
|
mx::array mlx_divide_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
@ -945,8 +952,8 @@ array mlx_divide_item(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_maximum_item(
|
mx::array mlx_maximum_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
@ -957,8 +964,8 @@ array mlx_maximum_item(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_minimum_item(
|
mx::array mlx_minimum_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v) {
|
const ScalarOrArray& v) {
|
||||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
|
@ -7,32 +7,35 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
array mlx_get_item(const array& src, const nb::object& obj);
|
mx::array mlx_get_item(const mx::array& src, const nb::object& obj);
|
||||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v);
|
void mlx_set_item(
|
||||||
array mlx_add_item(
|
mx::array& src,
|
||||||
const array& src,
|
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v);
|
const ScalarOrArray& v);
|
||||||
array mlx_subtract_item(
|
mx::array mlx_add_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v);
|
const ScalarOrArray& v);
|
||||||
array mlx_multiply_item(
|
mx::array mlx_subtract_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v);
|
const ScalarOrArray& v);
|
||||||
array mlx_divide_item(
|
mx::array mlx_multiply_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v);
|
const ScalarOrArray& v);
|
||||||
array mlx_maximum_item(
|
mx::array mlx_divide_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v);
|
const ScalarOrArray& v);
|
||||||
array mlx_minimum_item(
|
mx::array mlx_maximum_item(
|
||||||
const array& src,
|
const mx::array& src,
|
||||||
|
const nb::object& obj,
|
||||||
|
const ScalarOrArray& v);
|
||||||
|
mx::array mlx_minimum_item(
|
||||||
|
const mx::array& src,
|
||||||
const nb::object& obj,
|
const nb::object& obj,
|
||||||
const ScalarOrArray& v);
|
const ScalarOrArray& v);
|
||||||
|
@ -10,15 +10,13 @@
|
|||||||
|
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
using namespace mlx::core;
|
|
||||||
using namespace mlx::core::linalg;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||||
const auto result = svd(a, s);
|
const auto result = mx::linalg::svd(a, s);
|
||||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -29,11 +27,11 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"norm",
|
"norm",
|
||||||
[](const array& a,
|
[](const mx::array& a,
|
||||||
const std::variant<std::monostate, int, double, std::string>& ord_,
|
const std::variant<std::monostate, int, double, std::string>& ord_,
|
||||||
const std::variant<std::monostate, int, std::vector<int>>& axis_,
|
const std::variant<std::monostate, int, std::vector<int>>& axis_,
|
||||||
const bool keepdims,
|
const bool keepdims,
|
||||||
const StreamOrDevice stream) {
|
const mx::StreamOrDevice stream) {
|
||||||
std::optional<std::vector<int>> axis = std::nullopt;
|
std::optional<std::vector<int>> axis = std::nullopt;
|
||||||
if (auto pv = std::get_if<int>(&axis_); pv) {
|
if (auto pv = std::get_if<int>(&axis_); pv) {
|
||||||
axis = std::vector<int>{*pv};
|
axis = std::vector<int>{*pv};
|
||||||
@ -42,10 +40,10 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (std::holds_alternative<std::monostate>(ord_)) {
|
if (std::holds_alternative<std::monostate>(ord_)) {
|
||||||
return norm(a, axis, keepdims, stream);
|
return mx::linalg::norm(a, axis, keepdims, stream);
|
||||||
} else {
|
} else {
|
||||||
if (auto pv = std::get_if<std::string>(&ord_); pv) {
|
if (auto pv = std::get_if<std::string>(&ord_); pv) {
|
||||||
return norm(a, *pv, axis, keepdims, stream);
|
return mx::linalg::norm(a, *pv, axis, keepdims, stream);
|
||||||
}
|
}
|
||||||
double ord;
|
double ord;
|
||||||
if (auto pv = std::get_if<int>(&ord_); pv) {
|
if (auto pv = std::get_if<int>(&ord_); pv) {
|
||||||
@ -53,7 +51,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
} else {
|
} else {
|
||||||
ord = std::get<double>(ord_);
|
ord = std::get<double>(ord_);
|
||||||
}
|
}
|
||||||
return norm(a, ord, axis, keepdims, stream);
|
return mx::linalg::norm(a, ord, axis, keepdims, stream);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
@ -182,7 +180,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"qr",
|
"qr",
|
||||||
&qr,
|
&mx::linalg::qr,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -239,7 +237,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"inv",
|
"inv",
|
||||||
&inv,
|
&mx::linalg::inv,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -262,7 +260,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"tri_inv",
|
"tri_inv",
|
||||||
&tri_inv,
|
&mx::linalg::tri_inv,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"upper"_a,
|
"upper"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -287,7 +285,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"cholesky",
|
"cholesky",
|
||||||
&cholesky,
|
&mx::linalg::cholesky,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"upper"_a = false,
|
"upper"_a = false,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -317,7 +315,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"cholesky_inv",
|
"cholesky_inv",
|
||||||
&cholesky_inv,
|
&mx::linalg::cholesky_inv,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"upper"_a = false,
|
"upper"_a = false,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -355,7 +353,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"pinv",
|
"pinv",
|
||||||
&pinv,
|
&mx::linalg::pinv,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -379,7 +377,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"cross",
|
"cross",
|
||||||
&cross,
|
&mx::linalg::cross,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"b"_a,
|
"b"_a,
|
||||||
"axis"_a = -1,
|
"axis"_a = -1,
|
||||||
@ -407,7 +405,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"eigvalsh",
|
"eigvalsh",
|
||||||
&eigvalsh,
|
&mx::linalg::eigvalsh,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"UPLO"_a = "L",
|
"UPLO"_a = "L",
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -442,9 +440,9 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"eigh",
|
"eigh",
|
||||||
[](const array& a, const std::string UPLO, StreamOrDevice s) {
|
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
|
||||||
// TODO avoid cast?
|
// TODO avoid cast?
|
||||||
auto result = eigh(a, UPLO, s);
|
auto result = mx::linalg::eigh(a, UPLO, s);
|
||||||
return nb::make_tuple(result.first, result.second);
|
return nb::make_tuple(result.first, result.second);
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
|
@ -14,9 +14,9 @@
|
|||||||
#include "python/src/load.h"
|
#include "python/src/load.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Helpers
|
// Helpers
|
||||||
@ -86,7 +86,7 @@ class ZipFileWrapper {
|
|||||||
// Loading
|
// Loading
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
class PyFileReader : public io::Reader {
|
class PyFileReader : public mx::io::Reader {
|
||||||
public:
|
public:
|
||||||
PyFileReader(nb::object file)
|
PyFileReader(nb::object file)
|
||||||
: pyistream_(file),
|
: pyistream_(file),
|
||||||
@ -168,14 +168,14 @@ class PyFileReader : public io::Reader {
|
|||||||
};
|
};
|
||||||
|
|
||||||
std::pair<
|
std::pair<
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, mx::array>,
|
||||||
std::unordered_map<std::string, std::string>>
|
std::unordered_map<std::string, std::string>>
|
||||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) {
|
||||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||||
return load_safetensors(nb::cast<std::string>(file), s);
|
return mx::load_safetensors(nb::cast<std::string>(file), s);
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release gil;
|
nb::gil_scoped_release gil;
|
||||||
for (auto& [key, arr] : std::get<0>(res)) {
|
for (auto& [key, arr] : std::get<0>(res)) {
|
||||||
@ -189,17 +189,17 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
|||||||
"[load_safetensors] Input must be a file-like object, or string");
|
"[load_safetensors] Input must be a file-like object, or string");
|
||||||
}
|
}
|
||||||
|
|
||||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) {
|
||||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||||
return load_gguf(nb::cast<std::string>(file), s);
|
return mx::load_gguf(nb::cast<std::string>(file), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
std::unordered_map<std::string, mx::array> mlx_load_npz_helper(
|
||||||
nb::object file,
|
nb::object file,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
bool own_file = nb::isinstance<nb::str>(file);
|
bool own_file = nb::isinstance<nb::str>(file);
|
||||||
|
|
||||||
nb::module_ zipfile = nb::module_::import_("zipfile");
|
nb::module_ zipfile = nb::module_::import_("zipfile");
|
||||||
@ -209,7 +209,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
|||||||
"opened with zipfile.ZipFile");
|
"opened with zipfile.ZipFile");
|
||||||
}
|
}
|
||||||
// Output dictionary filename in zip -> loaded array
|
// Output dictionary filename in zip -> loaded array
|
||||||
std::unordered_map<std::string, array> array_dict;
|
std::unordered_map<std::string, mx::array> array_dict;
|
||||||
|
|
||||||
// Create python ZipFile object
|
// Create python ZipFile object
|
||||||
ZipFileWrapper zipfile_object(zipfile, file);
|
ZipFileWrapper zipfile_object(zipfile, file);
|
||||||
@ -218,7 +218,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
|||||||
nb::object sub_file = zipfile_object.open(st);
|
nb::object sub_file = zipfile_object.open(st);
|
||||||
|
|
||||||
// Create array from python file stream
|
// Create array from python file stream
|
||||||
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
|
auto arr = mx::load(std::make_shared<PyFileReader>(sub_file), s);
|
||||||
|
|
||||||
// Remove .npy from file if it is there
|
// Remove .npy from file if it is there
|
||||||
auto key = st;
|
auto key = st;
|
||||||
@ -240,12 +240,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
|||||||
return array_dict;
|
return array_dict;
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
|
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) {
|
||||||
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
|
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
|
||||||
return load(nb::cast<std::string>(file), s);
|
return mx::load(nb::cast<std::string>(file), s);
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
auto arr = mx::load(std::make_shared<PyFileReader>(file), s);
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release gil;
|
nb::gil_scoped_release gil;
|
||||||
arr.eval();
|
arr.eval();
|
||||||
@ -260,7 +260,7 @@ LoadOutputTypes mlx_load_helper(
|
|||||||
nb::object file,
|
nb::object file,
|
||||||
std::optional<std::string> format,
|
std::optional<std::string> format,
|
||||||
bool return_metadata,
|
bool return_metadata,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
if (!format.has_value()) {
|
if (!format.has_value()) {
|
||||||
std::string fname;
|
std::string fname;
|
||||||
if (nb::isinstance<nb::str>(file)) {
|
if (nb::isinstance<nb::str>(file)) {
|
||||||
@ -309,7 +309,7 @@ LoadOutputTypes mlx_load_helper(
|
|||||||
// Saving
|
// Saving
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
class PyFileWriter : public io::Writer {
|
class PyFileWriter : public mx::io::Writer {
|
||||||
public:
|
public:
|
||||||
PyFileWriter(nb::object file)
|
PyFileWriter(nb::object file)
|
||||||
: pyostream_(file),
|
: pyostream_(file),
|
||||||
@ -382,15 +382,15 @@ class PyFileWriter : public io::Writer {
|
|||||||
nb::object tell_func_;
|
nb::object tell_func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void mlx_save_helper(nb::object file, array a) {
|
void mlx_save_helper(nb::object file, mx::array a) {
|
||||||
if (nb::isinstance<nb::str>(file)) {
|
if (nb::isinstance<nb::str>(file)) {
|
||||||
save(nb::cast<std::string>(file), a);
|
mx::save(nb::cast<std::string>(file), a);
|
||||||
return;
|
return;
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
auto writer = std::make_shared<PyFileWriter>(file);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release gil;
|
nb::gil_scoped_release gil;
|
||||||
save(writer, a);
|
mx::save(writer, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
@ -419,8 +419,9 @@ void mlx_savez_helper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Collect args and kwargs
|
// Collect args and kwargs
|
||||||
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
|
auto arrays_dict =
|
||||||
auto arrays_list = nb::cast<std::vector<array>>(args);
|
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
|
||||||
|
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
|
||||||
|
|
||||||
for (int i = 0; i < arrays_list.size(); i++) {
|
for (int i = 0; i < arrays_list.size(); i++) {
|
||||||
std::string arr_name = "arr_" + std::to_string(i);
|
std::string arr_name = "arr_" + std::to_string(i);
|
||||||
@ -447,7 +448,7 @@ void mlx_savez_helper(
|
|||||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
save(writer, a);
|
mx::save(writer, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -470,17 +471,18 @@ void mlx_save_safetensor_helper(
|
|||||||
} else {
|
} else {
|
||||||
metadata_map = std::unordered_map<std::string, std::string>();
|
metadata_map = std::unordered_map<std::string, std::string>();
|
||||||
}
|
}
|
||||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
|
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d);
|
||||||
if (nb::isinstance<nb::str>(file)) {
|
if (nb::isinstance<nb::str>(file)) {
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
|
mx::save_safetensors(
|
||||||
|
nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||||
}
|
}
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
auto writer = std::make_shared<PyFileWriter>(file);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
save_safetensors(writer, arrays_map, metadata_map);
|
mx::save_safetensors(writer, arrays_map, metadata_map);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@ -492,19 +494,20 @@ void mlx_save_gguf_helper(
|
|||||||
nb::object file,
|
nb::object file,
|
||||||
nb::dict a,
|
nb::dict a,
|
||||||
std::optional<nb::dict> m) {
|
std::optional<nb::dict> m) {
|
||||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
|
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a);
|
||||||
if (nb::isinstance<nb::str>(file)) {
|
if (nb::isinstance<nb::str>(file)) {
|
||||||
if (m) {
|
if (m) {
|
||||||
auto metadata_map =
|
auto metadata_map =
|
||||||
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
|
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>(
|
||||||
|
m.value());
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
|
mx::save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
save_gguf(nb::cast<std::string>(file), arrays_map);
|
mx::save_gguf(nb::cast<std::string>(file), arrays_map);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -14,22 +14,24 @@
|
|||||||
#include <variant>
|
#include <variant>
|
||||||
#include "mlx/io.h"
|
#include "mlx/io.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
using LoadOutputTypes = std::variant<
|
using LoadOutputTypes = std::variant<
|
||||||
array,
|
mx::array,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, mx::array>,
|
||||||
SafetensorsLoad,
|
mx::SafetensorsLoad,
|
||||||
GGUFLoad>;
|
mx::GGUFLoad>;
|
||||||
|
|
||||||
SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s);
|
mx::SafetensorsLoad mlx_load_safetensor_helper(
|
||||||
|
nb::object file,
|
||||||
|
mx::StreamOrDevice s);
|
||||||
void mlx_save_safetensor_helper(
|
void mlx_save_safetensor_helper(
|
||||||
nb::object file,
|
nb::object file,
|
||||||
nb::dict d,
|
nb::dict d,
|
||||||
std::optional<nb::dict> m);
|
std::optional<nb::dict> m);
|
||||||
|
|
||||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s);
|
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s);
|
||||||
|
|
||||||
void mlx_save_gguf_helper(
|
void mlx_save_gguf_helper(
|
||||||
nb::object file,
|
nb::object file,
|
||||||
@ -40,8 +42,8 @@ LoadOutputTypes mlx_load_helper(
|
|||||||
nb::object file,
|
nb::object file,
|
||||||
std::optional<std::string> format,
|
std::optional<std::string> format,
|
||||||
bool return_metadata,
|
bool return_metadata,
|
||||||
StreamOrDevice s);
|
mx::StreamOrDevice s);
|
||||||
void mlx_save_helper(nb::object file, array a);
|
void mlx_save_helper(nb::object file, mx::array a);
|
||||||
void mlx_savez_helper(
|
void mlx_savez_helper(
|
||||||
nb::object file,
|
nb::object file,
|
||||||
nb::args args,
|
nb::args args,
|
||||||
|
@ -8,22 +8,21 @@
|
|||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
void init_metal(nb::module_& m) {
|
void init_metal(nb::module_& m) {
|
||||||
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||||
metal.def(
|
metal.def(
|
||||||
"is_available",
|
"is_available",
|
||||||
&metal::is_available,
|
&mx::metal::is_available,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Check if the Metal back-end is available.
|
Check if the Metal back-end is available.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"get_active_memory",
|
"get_active_memory",
|
||||||
&metal::get_active_memory,
|
&mx::metal::get_active_memory,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Get the actively used memory in bytes.
|
Get the actively used memory in bytes.
|
||||||
|
|
||||||
@ -32,7 +31,7 @@ void init_metal(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"get_peak_memory",
|
"get_peak_memory",
|
||||||
&metal::get_peak_memory,
|
&mx::metal::get_peak_memory,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Get the peak amount of used memory in bytes.
|
Get the peak amount of used memory in bytes.
|
||||||
|
|
||||||
@ -41,13 +40,13 @@ void init_metal(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"reset_peak_memory",
|
"reset_peak_memory",
|
||||||
&metal::reset_peak_memory,
|
&mx::metal::reset_peak_memory,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Reset the peak memory to zero.
|
Reset the peak memory to zero.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"get_cache_memory",
|
"get_cache_memory",
|
||||||
&metal::get_cache_memory,
|
&mx::metal::get_cache_memory,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Get the cache size in bytes.
|
Get the cache size in bytes.
|
||||||
|
|
||||||
@ -56,7 +55,7 @@ void init_metal(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"set_memory_limit",
|
"set_memory_limit",
|
||||||
&metal::set_memory_limit,
|
&mx::metal::set_memory_limit,
|
||||||
"limit"_a,
|
"limit"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"relaxed"_a = true,
|
"relaxed"_a = true,
|
||||||
@ -81,7 +80,7 @@ void init_metal(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"set_cache_limit",
|
"set_cache_limit",
|
||||||
&metal::set_cache_limit,
|
&mx::metal::set_cache_limit,
|
||||||
"limit"_a,
|
"limit"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Set the free cache limit.
|
Set the free cache limit.
|
||||||
@ -101,7 +100,7 @@ void init_metal(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"set_wired_limit",
|
"set_wired_limit",
|
||||||
&metal::set_wired_limit,
|
&mx::metal::set_wired_limit,
|
||||||
"limit"_a,
|
"limit"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Set the wired size limit.
|
Set the wired size limit.
|
||||||
@ -133,7 +132,7 @@ void init_metal(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"clear_cache",
|
"clear_cache",
|
||||||
&metal::clear_cache,
|
&mx::metal::clear_cache,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Clear the memory cache.
|
Clear the memory cache.
|
||||||
|
|
||||||
@ -142,7 +141,7 @@ void init_metal(nb::module_& m) {
|
|||||||
|
|
||||||
metal.def(
|
metal.def(
|
||||||
"start_capture",
|
"start_capture",
|
||||||
&metal::start_capture,
|
&mx::metal::start_capture,
|
||||||
"path"_a,
|
"path"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Start a Metal capture.
|
Start a Metal capture.
|
||||||
@ -153,13 +152,13 @@ void init_metal(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"stop_capture",
|
"stop_capture",
|
||||||
&metal::stop_capture,
|
&mx::metal::stop_capture,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Stop a Metal capture.
|
Stop a Metal capture.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"device_info",
|
"device_info",
|
||||||
&metal::device_info,
|
&mx::metal::device_info,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Get information about the GPU device and system settings.
|
Get information about the GPU device and system settings.
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -12,23 +12,22 @@
|
|||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/random.h"
|
#include "mlx/random.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
|
||||||
using namespace mlx::core::random;
|
|
||||||
|
|
||||||
class PyKeySequence {
|
class PyKeySequence {
|
||||||
public:
|
public:
|
||||||
explicit PyKeySequence(uint64_t seed) {
|
explicit PyKeySequence(uint64_t seed) {
|
||||||
state_.append(key(seed));
|
state_.append(mx::random::key(seed));
|
||||||
}
|
}
|
||||||
|
|
||||||
void seed(uint64_t seed) {
|
void seed(uint64_t seed) {
|
||||||
state_[0] = key(seed);
|
state_[0] = mx::random::key(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
array next() {
|
mx::array next() {
|
||||||
auto out = split(nb::cast<array>(state_[0]));
|
auto out = mx::random::split(nb::cast<mx::array>(state_[0]));
|
||||||
state_[0] = out.first;
|
state_[0] = out.first;
|
||||||
return out.second;
|
return out.second;
|
||||||
}
|
}
|
||||||
@ -75,7 +74,7 @@ void init_random(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"key",
|
"key",
|
||||||
&key,
|
&mx::random::key,
|
||||||
"seed"_a,
|
"seed"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Get a PRNG key from a seed.
|
Get a PRNG key from a seed.
|
||||||
@ -88,7 +87,8 @@ void init_random(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"split",
|
"split",
|
||||||
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
nb::overload_cast<const mx::array&, int, mx::StreamOrDevice>(
|
||||||
|
&mx::random::split),
|
||||||
"key"_a,
|
"key"_a,
|
||||||
"num"_a = 2,
|
"num"_a = 2,
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -109,22 +109,22 @@ void init_random(nb::module_& parent_module) {
|
|||||||
[](const ScalarOrArray& low,
|
[](const ScalarOrArray& low,
|
||||||
const ScalarOrArray& high,
|
const ScalarOrArray& high,
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
std::optional<Dtype> type,
|
std::optional<mx::Dtype> type,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
return uniform(
|
return mx::random::uniform(
|
||||||
to_array(low),
|
to_array(low),
|
||||||
to_array(high),
|
to_array(high),
|
||||||
shape,
|
shape,
|
||||||
type.value_or(float32),
|
type.value_or(mx::float32),
|
||||||
key,
|
key,
|
||||||
s);
|
s);
|
||||||
},
|
},
|
||||||
"low"_a = 0,
|
"low"_a = 0,
|
||||||
"high"_a = 1,
|
"high"_a = 1,
|
||||||
"shape"_a = std::vector<int>{},
|
"shape"_a = std::vector<int>{},
|
||||||
"dtype"_a.none() = float32,
|
"dtype"_a.none() = mx::float32,
|
||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -151,16 +151,17 @@ void init_random(nb::module_& parent_module) {
|
|||||||
m.def(
|
m.def(
|
||||||
"normal",
|
"normal",
|
||||||
[](const std::vector<int>& shape,
|
[](const std::vector<int>& shape,
|
||||||
std::optional<Dtype> type,
|
std::optional<mx::Dtype> type,
|
||||||
float loc,
|
float loc,
|
||||||
float scale,
|
float scale,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
return mx::random::normal(
|
||||||
|
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||||
},
|
},
|
||||||
"shape"_a = std::vector<int>{},
|
"shape"_a = std::vector<int>{},
|
||||||
"dtype"_a.none() = float32,
|
"dtype"_a.none() = mx::float32,
|
||||||
"loc"_a = 0.0,
|
"loc"_a = 0.0,
|
||||||
"scale"_a = 1.0,
|
"scale"_a = 1.0,
|
||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
@ -182,20 +183,20 @@ void init_random(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"multivariate_normal",
|
"multivariate_normal",
|
||||||
[](const array& mean,
|
[](const mx::array& mean,
|
||||||
const array& cov,
|
const mx::array& cov,
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
std::optional<Dtype> type,
|
std::optional<mx::Dtype> type,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
return multivariate_normal(
|
return mx::random::multivariate_normal(
|
||||||
mean, cov, shape, type.value_or(float32), key, s);
|
mean, cov, shape, type.value_or(mx::float32), key, s);
|
||||||
},
|
},
|
||||||
"mean"_a,
|
"mean"_a,
|
||||||
"cov"_a,
|
"cov"_a,
|
||||||
"shape"_a = std::vector<int>{},
|
"shape"_a = std::vector<int>{},
|
||||||
"dtype"_a.none() = float32,
|
"dtype"_a.none() = mx::float32,
|
||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -227,17 +228,22 @@ void init_random(nb::module_& parent_module) {
|
|||||||
[](const ScalarOrArray& low,
|
[](const ScalarOrArray& low,
|
||||||
const ScalarOrArray& high,
|
const ScalarOrArray& high,
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape,
|
||||||
std::optional<Dtype> type,
|
std::optional<mx::Dtype> type,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
return randint(
|
return mx::random::randint(
|
||||||
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
|
to_array(low),
|
||||||
|
to_array(high),
|
||||||
|
shape,
|
||||||
|
type.value_or(mx::int32),
|
||||||
|
key,
|
||||||
|
s);
|
||||||
},
|
},
|
||||||
"low"_a,
|
"low"_a,
|
||||||
"high"_a,
|
"high"_a,
|
||||||
"shape"_a = std::vector<int>{},
|
"shape"_a = std::vector<int>{},
|
||||||
"dtype"_a.none() = int32,
|
"dtype"_a.none() = mx::int32,
|
||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -263,14 +269,14 @@ void init_random(nb::module_& parent_module) {
|
|||||||
"bernoulli",
|
"bernoulli",
|
||||||
[](const ScalarOrArray& p_,
|
[](const ScalarOrArray& p_,
|
||||||
const std::optional<std::vector<int>> shape,
|
const std::optional<std::vector<int>> shape,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
auto p = to_array(p_);
|
auto p = to_array(p_);
|
||||||
if (shape.has_value()) {
|
if (shape.has_value()) {
|
||||||
return bernoulli(p, shape.value(), key, s);
|
return mx::random::bernoulli(p, shape.value(), key, s);
|
||||||
} else {
|
} else {
|
||||||
return bernoulli(p, key, s);
|
return mx::random::bernoulli(p, key, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"p"_a = 0.5,
|
"p"_a = 0.5,
|
||||||
@ -301,23 +307,24 @@ void init_random(nb::module_& parent_module) {
|
|||||||
[](const ScalarOrArray& lower_,
|
[](const ScalarOrArray& lower_,
|
||||||
const ScalarOrArray& upper_,
|
const ScalarOrArray& upper_,
|
||||||
const std::optional<std::vector<int>> shape_,
|
const std::optional<std::vector<int>> shape_,
|
||||||
std::optional<Dtype> type,
|
std::optional<mx::Dtype> type,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
auto lower = to_array(lower_);
|
auto lower = to_array(lower_);
|
||||||
auto upper = to_array(upper_);
|
auto upper = to_array(upper_);
|
||||||
auto t = type.value_or(float32);
|
auto t = type.value_or(mx::float32);
|
||||||
if (shape_.has_value()) {
|
if (shape_.has_value()) {
|
||||||
return truncated_normal(lower, upper, shape_.value(), t, key, s);
|
return mx::random::truncated_normal(
|
||||||
|
lower, upper, shape_.value(), t, key, s);
|
||||||
} else {
|
} else {
|
||||||
return truncated_normal(lower, upper, t, key, s);
|
return mx::random::truncated_normal(lower, upper, t, key, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"lower"_a,
|
"lower"_a,
|
||||||
"upper"_a,
|
"upper"_a,
|
||||||
"shape"_a = nb::none(),
|
"shape"_a = nb::none(),
|
||||||
"dtype"_a.none() = float32,
|
"dtype"_a.none() = mx::float32,
|
||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -344,14 +351,14 @@ void init_random(nb::module_& parent_module) {
|
|||||||
m.def(
|
m.def(
|
||||||
"gumbel",
|
"gumbel",
|
||||||
[](const std::vector<int>& shape,
|
[](const std::vector<int>& shape,
|
||||||
std::optional<Dtype> type,
|
std::optional<mx::Dtype> type,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
return gumbel(shape, type.value_or(float32), key, s);
|
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
|
||||||
},
|
},
|
||||||
"shape"_a = std::vector<int>{},
|
"shape"_a = std::vector<int>{},
|
||||||
"dtype"_a.none() = float32,
|
"dtype"_a.none() = mx::float32,
|
||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -375,22 +382,23 @@ void init_random(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"categorical",
|
"categorical",
|
||||||
[](const array& logits,
|
[](const mx::array& logits,
|
||||||
int axis,
|
int axis,
|
||||||
const std::optional<std::vector<int>> shape,
|
const std::optional<std::vector<int>> shape,
|
||||||
const std::optional<int> num_samples,
|
const std::optional<int> num_samples,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
if (shape.has_value() && num_samples.has_value()) {
|
if (shape.has_value() && num_samples.has_value()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[categorical] At most one of shape or num_samples can be specified.");
|
"[categorical] At most one of shape or num_samples can be specified.");
|
||||||
} else if (shape.has_value()) {
|
} else if (shape.has_value()) {
|
||||||
return categorical(logits, axis, shape.value(), key, s);
|
return mx::random::categorical(logits, axis, shape.value(), key, s);
|
||||||
} else if (num_samples.has_value()) {
|
} else if (num_samples.has_value()) {
|
||||||
return categorical(logits, axis, num_samples.value(), key, s);
|
return mx::random::categorical(
|
||||||
|
logits, axis, num_samples.value(), key, s);
|
||||||
} else {
|
} else {
|
||||||
return categorical(logits, axis, key, s);
|
return mx::random::categorical(logits, axis, key, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"logits"_a,
|
"logits"_a,
|
||||||
@ -427,16 +435,17 @@ void init_random(nb::module_& parent_module) {
|
|||||||
m.def(
|
m.def(
|
||||||
"laplace",
|
"laplace",
|
||||||
[](const std::vector<int>& shape,
|
[](const std::vector<int>& shape,
|
||||||
std::optional<Dtype> type,
|
std::optional<mx::Dtype> type,
|
||||||
float loc,
|
float loc,
|
||||||
float scale,
|
float scale,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
return laplace(shape, type.value_or(float32), loc, scale, key, s);
|
return mx::random::laplace(
|
||||||
|
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||||
},
|
},
|
||||||
"shape"_a = std::vector<int>{},
|
"shape"_a = std::vector<int>{},
|
||||||
"dtype"_a.none() = float32,
|
"dtype"_a.none() = mx::float32,
|
||||||
"loc"_a = 0.0,
|
"loc"_a = 0.0,
|
||||||
"scale"_a = 1.0,
|
"scale"_a = 1.0,
|
||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
@ -459,15 +468,15 @@ void init_random(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"permuation",
|
"permuation",
|
||||||
[](const std::variant<nb::int_, array>& x,
|
[](const std::variant<nb::int_, mx::array>& x,
|
||||||
int axis,
|
int axis,
|
||||||
const std::optional<array>& key_,
|
const std::optional<mx::array>& key_,
|
||||||
StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
auto key = key_ ? key_.value() : default_key().next();
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
if (auto pv = std::get_if<nb::int_>(&x); pv) {
|
if (auto pv = std::get_if<nb::int_>(&x); pv) {
|
||||||
return permutation(nb::cast<int>(*pv), key, s);
|
return mx::random::permutation(nb::cast<int>(*pv), key, s);
|
||||||
} else {
|
} else {
|
||||||
return permutation(std::get<array>(x), axis, key, s);
|
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"shape"_a = std::vector<int>{},
|
"shape"_a = std::vector<int>{},
|
||||||
|
@ -10,14 +10,14 @@
|
|||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
// Create the StreamContext on enter and delete on exit.
|
// Create the StreamContext on enter and delete on exit.
|
||||||
class PyStreamContext {
|
class PyStreamContext {
|
||||||
public:
|
public:
|
||||||
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) {
|
||||||
if (std::holds_alternative<std::monostate>(s)) {
|
if (std::holds_alternative<std::monostate>(s)) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[StreamContext] Invalid argument, please specify a stream or device.");
|
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||||
@ -26,7 +26,7 @@ class PyStreamContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void enter() {
|
void enter() {
|
||||||
_inner = new StreamContext(_s);
|
_inner = new mx::StreamContext(_s);
|
||||||
}
|
}
|
||||||
|
|
||||||
void exit() {
|
void exit() {
|
||||||
@ -37,39 +37,40 @@ class PyStreamContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
StreamOrDevice _s;
|
mx::StreamOrDevice _s;
|
||||||
StreamContext* _inner;
|
mx::StreamContext* _inner;
|
||||||
};
|
};
|
||||||
|
|
||||||
void init_stream(nb::module_& m) {
|
void init_stream(nb::module_& m) {
|
||||||
nb::class_<Stream>(
|
nb::class_<mx::Stream>(
|
||||||
m,
|
m,
|
||||||
"Stream",
|
"Stream",
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A stream for running operations on a given device.
|
A stream for running operations on a given device.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def_ro("device", &Stream::device)
|
.def_ro("device", &mx::Stream::device)
|
||||||
.def(
|
.def(
|
||||||
"__repr__",
|
"__repr__",
|
||||||
[](const Stream& s) {
|
[](const mx::Stream& s) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << s;
|
os << s;
|
||||||
return os.str();
|
return os.str();
|
||||||
})
|
})
|
||||||
.def("__eq__", [](const Stream& s, const nb::object& other) {
|
.def("__eq__", [](const mx::Stream& s, const nb::object& other) {
|
||||||
return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other);
|
return nb::isinstance<mx::Stream>(other) &&
|
||||||
|
s == nb::cast<mx::Stream>(other);
|
||||||
});
|
});
|
||||||
|
|
||||||
nb::implicitly_convertible<Device::DeviceType, Device>();
|
nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"default_stream",
|
"default_stream",
|
||||||
&default_stream,
|
&mx::default_stream,
|
||||||
"device"_a,
|
"device"_a,
|
||||||
R"pbdoc(Get the device's default stream.)pbdoc");
|
R"pbdoc(Get the device's default stream.)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"set_default_stream",
|
"set_default_stream",
|
||||||
&set_default_stream,
|
&mx::set_default_stream,
|
||||||
"stream"_a,
|
"stream"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Set the default stream.
|
Set the default stream.
|
||||||
@ -82,7 +83,7 @@ void init_stream(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"new_stream",
|
"new_stream",
|
||||||
&new_stream,
|
&mx::new_stream,
|
||||||
"device"_a,
|
"device"_a,
|
||||||
R"pbdoc(Make a new stream on the given device.)pbdoc");
|
R"pbdoc(Make a new stream on the given device.)pbdoc");
|
||||||
|
|
||||||
@ -94,7 +95,7 @@ void init_stream(nb::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
s: The stream or device to set as the default.
|
s: The stream or device to set as the default.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def(nb::init<StreamOrDevice>(), "s"_a)
|
.def(nb::init<mx::StreamOrDevice>(), "s"_a)
|
||||||
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
|
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
|
||||||
.def(
|
.def(
|
||||||
"__exit__",
|
"__exit__",
|
||||||
@ -107,7 +108,7 @@ void init_stream(nb::module_& m) {
|
|||||||
"traceback"_a = nb::none());
|
"traceback"_a = nb::none());
|
||||||
m.def(
|
m.def(
|
||||||
"stream",
|
"stream",
|
||||||
[](StreamOrDevice s) { return PyStreamContext(s); },
|
[](mx::StreamOrDevice s) { return PyStreamContext(s); },
|
||||||
"s"_a,
|
"s"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Create a context manager to set the default device and stream.
|
Create a context manager to set the default device and stream.
|
||||||
@ -131,8 +132,8 @@ void init_stream(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"synchronize",
|
"synchronize",
|
||||||
[](const std::optional<Stream>& s) {
|
[](const std::optional<mx::Stream>& s) {
|
||||||
s ? synchronize(s.value()) : synchronize();
|
s ? mx::synchronize(s.value()) : mx::synchronize();
|
||||||
},
|
},
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
|
@ -20,9 +20,12 @@
|
|||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
#include "python/src/trees.h"
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
|
||||||
|
// Needed for printing shapes and strides.
|
||||||
|
using mx::operator<<;
|
||||||
|
|
||||||
using IntOrVec = std::variant<int, std::vector<int>>;
|
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||||
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
||||||
@ -108,7 +111,7 @@ auto py_value_and_grad(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Collect the arrays
|
// Collect the arrays
|
||||||
std::vector<array> arrays;
|
std::vector<mx::array> arrays;
|
||||||
std::vector<int> counts(1, 0);
|
std::vector<int> counts(1, 0);
|
||||||
for (auto i : argnums) {
|
for (auto i : argnums) {
|
||||||
auto argsi = tree_flatten(args[i]);
|
auto argsi = tree_flatten(args[i]);
|
||||||
@ -127,7 +130,7 @@ auto py_value_and_grad(
|
|||||||
// value_out will hold the output of the python function in order to be
|
// value_out will hold the output of the python function in order to be
|
||||||
// able to reconstruct the python tree of extra return values
|
// able to reconstruct the python tree of extra return values
|
||||||
nb::object py_value_out;
|
nb::object py_value_out;
|
||||||
auto value_and_grads = value_and_grad(
|
auto value_and_grads = mx::value_and_grad(
|
||||||
[&fun,
|
[&fun,
|
||||||
&args,
|
&args,
|
||||||
&kwargs,
|
&kwargs,
|
||||||
@ -136,7 +139,7 @@ auto py_value_and_grad(
|
|||||||
&counts,
|
&counts,
|
||||||
&py_value_out,
|
&py_value_out,
|
||||||
&error_msg_tag,
|
&error_msg_tag,
|
||||||
scalar_func_only](const std::vector<array>& a) {
|
scalar_func_only](const std::vector<mx::array>& a) {
|
||||||
// Copy the arguments
|
// Copy the arguments
|
||||||
nb::list args_cpy;
|
nb::list args_cpy;
|
||||||
nb::kwargs kwargs_cpy = nb::kwargs();
|
nb::kwargs kwargs_cpy = nb::kwargs();
|
||||||
@ -165,7 +168,7 @@ auto py_value_and_grad(
|
|||||||
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
||||||
|
|
||||||
// Validate the return value of the python function
|
// Validate the return value of the python function
|
||||||
if (!nb::isinstance<array>(py_value_out)) {
|
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||||
if (scalar_func_only) {
|
if (scalar_func_only) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << error_msg_tag << " The return value of the function "
|
msg << error_msg_tag << " The return value of the function "
|
||||||
@ -193,7 +196,7 @@ auto py_value_and_grad(
|
|||||||
<< "we got an empty tuple.";
|
<< "we got an empty tuple.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (!nb::isinstance<array>(ret[0])) {
|
if (!nb::isinstance<mx::array>(ret[0])) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << error_msg_tag << " The return value of the function "
|
msg << error_msg_tag << " The return value of the function "
|
||||||
<< "whose gradient we want to compute should be either a "
|
<< "whose gradient we want to compute should be either a "
|
||||||
@ -275,12 +278,12 @@ auto py_vmap(
|
|||||||
{tree, axes},
|
{tree, axes},
|
||||||
[&flat_axes, &encountered_tuple, output_axes](
|
[&flat_axes, &encountered_tuple, output_axes](
|
||||||
const std::vector<nb::object>& inputs) {
|
const std::vector<nb::object>& inputs) {
|
||||||
if (nb::isinstance<array>(inputs[0])) {
|
if (nb::isinstance<mx::array>(inputs[0])) {
|
||||||
if (inputs[1].is_none()) {
|
if (inputs[1].is_none()) {
|
||||||
flat_axes.push_back(-1);
|
flat_axes.push_back(-1);
|
||||||
} else if (nb::isinstance<nb::int_>(inputs[1])) {
|
} else if (nb::isinstance<nb::int_>(inputs[1])) {
|
||||||
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
|
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
|
||||||
const array& x = nb::cast<array>(inputs[0]);
|
const mx::array& x = nb::cast<mx::array>(inputs[0]);
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += x.ndim() + output_axes;
|
axis += x.ndim() + output_axes;
|
||||||
}
|
}
|
||||||
@ -297,7 +300,7 @@ auto py_vmap(
|
|||||||
auto l = nb::cast<nb::tuple>(inputs[1]);
|
auto l = nb::cast<nb::tuple>(inputs[1]);
|
||||||
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
|
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
|
||||||
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
|
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
|
||||||
const array& x = nb::cast<array>(inputs[0]);
|
const mx::array& x = nb::cast<mx::array>(inputs[0]);
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += x.ndim() + output_axes;
|
axis += x.ndim() + output_axes;
|
||||||
}
|
}
|
||||||
@ -323,7 +326,7 @@ auto py_vmap(
|
|||||||
"[vmap] The arguments should contain only arrays");
|
"[vmap] The arguments should contain only arrays");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
if (encountered_tuple && !nb::isinstance<array>(tree)) {
|
if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {
|
||||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||||
}
|
}
|
||||||
return flat_axes;
|
return flat_axes;
|
||||||
@ -339,7 +342,7 @@ auto py_vmap(
|
|||||||
nb::object py_outputs;
|
nb::object py_outputs;
|
||||||
|
|
||||||
auto vmap_fn =
|
auto vmap_fn =
|
||||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
|
||||||
// Call the python function
|
// Call the python function
|
||||||
py_outputs = fun(*tree_unflatten(args, a));
|
py_outputs = fun(*tree_unflatten(args, a));
|
||||||
|
|
||||||
@ -348,12 +351,12 @@ auto py_vmap(
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto [trace_inputs, trace_outputs] =
|
auto [trace_inputs, trace_outputs] =
|
||||||
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||||
|
|
||||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
|
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
|
||||||
|
|
||||||
// Perform the vmap
|
// Perform the vmap
|
||||||
auto outputs = detail::vmap_replace(
|
auto outputs = mx::detail::vmap_replace(
|
||||||
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
|
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
|
||||||
|
|
||||||
// Put the outputs back in the container
|
// Put the outputs back in the container
|
||||||
@ -401,7 +404,7 @@ struct PyCompiledFun {
|
|||||||
|
|
||||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||||
// Flat array inputs
|
// Flat array inputs
|
||||||
std::vector<array> inputs;
|
std::vector<mx::array> inputs;
|
||||||
|
|
||||||
// Compilation constants which includes the tree structure of the arguments
|
// Compilation constants which includes the tree structure of the arguments
|
||||||
std::vector<uint64_t> constants;
|
std::vector<uint64_t> constants;
|
||||||
@ -437,8 +440,8 @@ struct PyCompiledFun {
|
|||||||
constants.push_back(nb::cast<int64_t>(r));
|
constants.push_back(nb::cast<int64_t>(r));
|
||||||
recurse(item.second);
|
recurse(item.second);
|
||||||
}
|
}
|
||||||
} else if (nb::isinstance<array>(obj)) {
|
} else if (nb::isinstance<mx::array>(obj)) {
|
||||||
inputs.push_back(nb::cast<array>(obj));
|
inputs.push_back(nb::cast<mx::array>(obj));
|
||||||
constants.push_back(array_identifier);
|
constants.push_back(array_identifier);
|
||||||
} else if (nb::isinstance<nb::str>(obj)) {
|
} else if (nb::isinstance<nb::str>(obj)) {
|
||||||
auto r = obj.attr("__hash__")();
|
auto r = obj.attr("__hash__")();
|
||||||
@ -461,10 +464,10 @@ struct PyCompiledFun {
|
|||||||
int num_args = inputs.size();
|
int num_args = inputs.size();
|
||||||
recurse(kwargs);
|
recurse(kwargs);
|
||||||
auto compile_fun = [this, &args, &kwargs, num_args](
|
auto compile_fun = [this, &args, &kwargs, num_args](
|
||||||
const std::vector<array>& a) {
|
const std::vector<mx::array>& a) {
|
||||||
// Put tracers into captured inputs
|
// Put tracers into captured inputs
|
||||||
std::vector<array> flat_in_captures;
|
std::vector<mx::array> flat_in_captures;
|
||||||
std::vector<array> trace_captures;
|
std::vector<mx::array> trace_captures;
|
||||||
if (!captured_inputs.is_none()) {
|
if (!captured_inputs.is_none()) {
|
||||||
flat_in_captures = tree_flatten(captured_inputs, false);
|
flat_in_captures = tree_flatten(captured_inputs, false);
|
||||||
trace_captures.insert(
|
trace_captures.insert(
|
||||||
@ -505,9 +508,9 @@ struct PyCompiledFun {
|
|||||||
|
|
||||||
// Compile and call
|
// Compile and call
|
||||||
auto outputs =
|
auto outputs =
|
||||||
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||||
if (!captured_outputs.is_none()) {
|
if (!captured_outputs.is_none()) {
|
||||||
std::vector<array> captures(
|
std::vector<mx::array> captures(
|
||||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||||
std::make_move_iterator(outputs.end()));
|
std::make_move_iterator(outputs.end()));
|
||||||
tree_fill(captured_outputs, captures);
|
tree_fill(captured_outputs, captures);
|
||||||
@ -526,7 +529,7 @@ struct PyCompiledFun {
|
|||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
tree_cache().erase(fun_id);
|
tree_cache().erase(fun_id);
|
||||||
detail::compile_erase(fun_id);
|
mx::detail::compile_erase(fun_id);
|
||||||
fun.release().dec_ref();
|
fun.release().dec_ref();
|
||||||
captured_inputs.release().dec_ref();
|
captured_inputs.release().dec_ref();
|
||||||
captured_outputs.release().dec_ref();
|
captured_outputs.release().dec_ref();
|
||||||
@ -561,7 +564,7 @@ class PyCheckpointedFun {
|
|||||||
args_structure_.release().dec_ref();
|
args_structure_.release().dec_ref();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||||
auto args = nb::cast<nb::tuple>(
|
auto args = nb::cast<nb::tuple>(
|
||||||
tree_unflatten_from_structure(args_structure_, inputs));
|
tree_unflatten_from_structure(args_structure_, inputs));
|
||||||
auto [outputs, output_structure] =
|
auto [outputs, output_structure] =
|
||||||
@ -579,7 +582,7 @@ class PyCheckpointedFun {
|
|||||||
auto [inputs, args_structure] =
|
auto [inputs, args_structure] =
|
||||||
tree_flatten_with_structure(full_args, false);
|
tree_flatten_with_structure(full_args, false);
|
||||||
|
|
||||||
auto outputs = checkpoint(
|
auto outputs = mx::checkpoint(
|
||||||
InnerFunction(fun_, args_structure, output_structure))(inputs);
|
InnerFunction(fun_, args_structure, output_structure))(inputs);
|
||||||
|
|
||||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||||
@ -660,12 +663,12 @@ class PyCustomFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
auto new_inputs = nb::cast<nb::tuple>(
|
auto new_inputs = nb::cast<nb::tuple>(
|
||||||
tree_unflatten_from_structure(input_structure_, inputs));
|
tree_unflatten_from_structure(input_structure_, inputs));
|
||||||
std::vector<array> outputs;
|
std::vector<mx::array> outputs;
|
||||||
std::tie(outputs, *output_structure_) =
|
std::tie(outputs, *output_structure_) =
|
||||||
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
|
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
|
||||||
return outputs;
|
return outputs;
|
||||||
@ -694,10 +697,10 @@ class PyCustomFunction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> operator()(
|
std::vector<mx::array> operator()(
|
||||||
const std::vector<array>& primals,
|
const std::vector<mx::array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<mx::array>& cotangents,
|
||||||
const std::vector<array>& outputs) {
|
const std::vector<mx::array>& outputs) {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
auto new_inputs = nb::cast<nb::tuple>(
|
auto new_inputs = nb::cast<nb::tuple>(
|
||||||
@ -734,9 +737,9 @@ class PyCustomFunction {
|
|||||||
input_structure_.release().dec_ref();
|
input_structure_.release().dec_ref();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> operator()(
|
std::vector<mx::array> operator()(
|
||||||
const std::vector<array>& primals,
|
const std::vector<mx::array>& primals,
|
||||||
const std::vector<array>& tangents,
|
const std::vector<mx::array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
@ -759,7 +762,7 @@ class PyCustomFunction {
|
|||||||
int tangent_index = 0;
|
int tangent_index = 0;
|
||||||
auto new_tangents =
|
auto new_tangents =
|
||||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||||
if (nb::isinstance<array>(element) &&
|
if (nb::isinstance<mx::array>(element) &&
|
||||||
have_tangents[array_index++]) {
|
have_tangents[array_index++]) {
|
||||||
return nb::cast(tangents[tangent_index++]);
|
return nb::cast(tangents[tangent_index++]);
|
||||||
} else {
|
} else {
|
||||||
@ -789,8 +792,8 @@ class PyCustomFunction {
|
|||||||
input_structure_.release().dec_ref();
|
input_structure_.release().dec_ref();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> operator()(
|
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<mx::array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
@ -807,7 +810,7 @@ class PyCustomFunction {
|
|||||||
auto new_axes =
|
auto new_axes =
|
||||||
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
|
||||||
int axis = axes[arr_index++];
|
int axis = axes[arr_index++];
|
||||||
if (nb::isinstance<array>(element) && axis >= 0) {
|
if (nb::isinstance<mx::array>(element) && axis >= 0) {
|
||||||
return nb::cast(axis);
|
return nb::cast(axis);
|
||||||
} else {
|
} else {
|
||||||
return nb::none();
|
return nb::none();
|
||||||
@ -831,11 +834,11 @@ class PyCustomFunction {
|
|||||||
"[custom vmap] Vmap function should return a tuple with 2 items.");
|
"[custom vmap] Vmap function should return a tuple with 2 items.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> outputs;
|
std::vector<mx::array> outputs;
|
||||||
std::vector<int> output_axes;
|
std::vector<int> output_axes;
|
||||||
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
|
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
|
||||||
if (nb::isinstance<array>(objects[0])) {
|
if (nb::isinstance<mx::array>(objects[0])) {
|
||||||
outputs.push_back(nb::cast<array>(objects[0]));
|
outputs.push_back(nb::cast<mx::array>(objects[0]));
|
||||||
output_axes.push_back(
|
output_axes.push_back(
|
||||||
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
|
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
|
||||||
}
|
}
|
||||||
@ -852,7 +855,7 @@ class PyCustomFunction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract the inputs and their structure in capturable vars
|
// Extract the inputs and their structure in capturable vars
|
||||||
std::vector<array> input_arrays;
|
std::vector<mx::array> input_arrays;
|
||||||
nb::object input_structure;
|
nb::object input_structure;
|
||||||
auto full_args = nb::make_tuple(args, kwargs);
|
auto full_args = nb::make_tuple(args, kwargs);
|
||||||
std::tie(input_arrays, input_structure) =
|
std::tie(input_arrays, input_structure) =
|
||||||
@ -864,7 +867,7 @@ class PyCustomFunction {
|
|||||||
|
|
||||||
// Make a function that calls fun_ in the forward pass and vjp_ in the
|
// Make a function that calls fun_ in the forward pass and vjp_ in the
|
||||||
// backward pass. Then call it immediately and return the results.
|
// backward pass. Then call it immediately and return the results.
|
||||||
auto f = custom_function(
|
auto f = mx::custom_function(
|
||||||
InnerFunction(fun_, input_structure, output_structure),
|
InnerFunction(fun_, input_structure, output_structure),
|
||||||
make_vjp_function(input_structure, output_structure),
|
make_vjp_function(input_structure, output_structure),
|
||||||
make_jvp_function(input_structure),
|
make_jvp_function(input_structure),
|
||||||
@ -1044,7 +1047,7 @@ void init_transforms(nb::module_& m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"eval",
|
"eval",
|
||||||
[](const nb::args& args) {
|
[](const nb::args& args) {
|
||||||
std::vector<array> arrays = tree_flatten(args, false);
|
std::vector<mx::array> arrays = tree_flatten(args, false);
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
eval(arrays);
|
eval(arrays);
|
||||||
@ -1064,7 +1067,7 @@ void init_transforms(nb::module_& m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"async_eval",
|
"async_eval",
|
||||||
[](const nb::args& args) {
|
[](const nb::args& args) {
|
||||||
std::vector<array> arrays = tree_flatten(args, false);
|
std::vector<mx::array> arrays = tree_flatten(args, false);
|
||||||
{
|
{
|
||||||
nb::gil_scoped_release nogil;
|
nb::gil_scoped_release nogil;
|
||||||
async_eval(arrays);
|
async_eval(arrays);
|
||||||
@ -1100,14 +1103,14 @@ void init_transforms(nb::module_& m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"jvp",
|
"jvp",
|
||||||
[](const nb::callable& fun,
|
[](const nb::callable& fun,
|
||||||
const std::vector<array>& primals,
|
const std::vector<mx::array>& primals,
|
||||||
const std::vector<array>& tangents) {
|
const std::vector<mx::array>& tangents) {
|
||||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
auto vfun = [&fun](const std::vector<mx::array>& primals) {
|
||||||
auto out = fun(*nb::cast(primals));
|
auto out = fun(*nb::cast(primals));
|
||||||
if (nb::isinstance<array>(out)) {
|
if (nb::isinstance<mx::array>(out)) {
|
||||||
return std::vector<array>{nb::cast<array>(out)};
|
return std::vector<mx::array>{nb::cast<mx::array>(out)};
|
||||||
} else {
|
} else {
|
||||||
return nb::cast<std::vector<array>>(out);
|
return nb::cast<std::vector<mx::array>>(out);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
return jvp(vfun, primals, tangents);
|
return jvp(vfun, primals, tangents);
|
||||||
@ -1139,14 +1142,14 @@ void init_transforms(nb::module_& m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"vjp",
|
"vjp",
|
||||||
[](const nb::callable& fun,
|
[](const nb::callable& fun,
|
||||||
const std::vector<array>& primals,
|
const std::vector<mx::array>& primals,
|
||||||
const std::vector<array>& cotangents) {
|
const std::vector<mx::array>& cotangents) {
|
||||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
auto vfun = [&fun](const std::vector<mx::array>& primals) {
|
||||||
auto out = fun(*nb::cast(primals));
|
auto out = fun(*nb::cast(primals));
|
||||||
if (nb::isinstance<array>(out)) {
|
if (nb::isinstance<mx::array>(out)) {
|
||||||
return std::vector<array>{nb::cast<array>(out)};
|
return std::vector<mx::array>{nb::cast<mx::array>(out)};
|
||||||
} else {
|
} else {
|
||||||
return nb::cast<std::vector<array>>(out);
|
return nb::cast<std::vector<mx::array>>(out);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
return vjp(vfun, primals, cotangents);
|
return vjp(vfun, primals, cotangents);
|
||||||
@ -1312,7 +1315,7 @@ void init_transforms(nb::module_& m) {
|
|||||||
m.def(
|
m.def(
|
||||||
"export_to_dot",
|
"export_to_dot",
|
||||||
[](nb::object file, const nb::args& args) {
|
[](nb::object file, const nb::args& args) {
|
||||||
std::vector<array> arrays = tree_flatten(args);
|
std::vector<mx::array> arrays = tree_flatten(args);
|
||||||
if (nb::isinstance<nb::str>(file)) {
|
if (nb::isinstance<nb::str>(file)) {
|
||||||
std::ofstream out(nb::cast<std::string>(file));
|
std::ofstream out(nb::cast<std::string>(file));
|
||||||
export_to_dot(out, arrays);
|
export_to_dot(out, arrays);
|
||||||
@ -1399,14 +1402,14 @@ void init_transforms(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"disable_compile",
|
"disable_compile",
|
||||||
&disable_compile,
|
&mx::disable_compile,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Globally disable compilation. Setting the environment variable
|
Globally disable compilation. Setting the environment variable
|
||||||
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"enable_compile",
|
"enable_compile",
|
||||||
&enable_compile,
|
&mx::enable_compile,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Globally enable compilation. This will override the environment
|
Globally enable compilation. This will override the environment
|
||||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||||
@ -1420,6 +1423,6 @@ void init_transforms(nb::module_& m) {
|
|||||||
auto atexit = nb::module_::import_("atexit");
|
auto atexit = nb::module_::import_("atexit");
|
||||||
atexit.attr("register")(nb::cpp_function([]() {
|
atexit.attr("register")(nb::cpp_function([]() {
|
||||||
tree_cache().clear();
|
tree_cache().clear();
|
||||||
detail::compile_clear_cache();
|
mx::detail::compile_clear_cache();
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
@ -188,7 +188,7 @@ void tree_visit_update(
|
|||||||
d[item.first] = recurse(item.second);
|
d[item.first] = recurse(item.second);
|
||||||
}
|
}
|
||||||
return nb::cast<nb::object>(d);
|
return nb::cast<nb::object>(d);
|
||||||
} else if (nb::isinstance<array>(subtree)) {
|
} else if (nb::isinstance<mx::array>(subtree)) {
|
||||||
return visitor(subtree);
|
return visitor(subtree);
|
||||||
} else {
|
} else {
|
||||||
return nb::cast<nb::object>(subtree);
|
return nb::cast<nb::object>(subtree);
|
||||||
@ -200,7 +200,7 @@ void tree_visit_update(
|
|||||||
// Fill a pytree (recursive dict or list of dict or list)
|
// Fill a pytree (recursive dict or list of dict or list)
|
||||||
// in place with the given arrays
|
// in place with the given arrays
|
||||||
// Non dict or list nodes are ignored
|
// Non dict or list nodes are ignored
|
||||||
void tree_fill(nb::object& tree, const std::vector<array>& values) {
|
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
tree_visit_update(
|
tree_visit_update(
|
||||||
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
|
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
|
||||||
@ -209,14 +209,14 @@ void tree_fill(nb::object& tree, const std::vector<array>& values) {
|
|||||||
// Replace all the arrays from the src values with the dst values in the tree
|
// Replace all the arrays from the src values with the dst values in the tree
|
||||||
void tree_replace(
|
void tree_replace(
|
||||||
nb::object& tree,
|
nb::object& tree,
|
||||||
const std::vector<array>& src,
|
const std::vector<mx::array>& src,
|
||||||
const std::vector<array>& dst) {
|
const std::vector<mx::array>& dst) {
|
||||||
std::unordered_map<uintptr_t, array> src_to_dst;
|
std::unordered_map<uintptr_t, mx::array> src_to_dst;
|
||||||
for (int i = 0; i < src.size(); ++i) {
|
for (int i = 0; i < src.size(); ++i) {
|
||||||
src_to_dst.insert({src[i].id(), dst[i]});
|
src_to_dst.insert({src[i].id(), dst[i]});
|
||||||
}
|
}
|
||||||
tree_visit_update(tree, [&](nb::handle node) {
|
tree_visit_update(tree, [&](nb::handle node) {
|
||||||
auto arr = nb::cast<array>(node);
|
auto arr = nb::cast<mx::array>(node);
|
||||||
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
||||||
return nb::cast(it->second);
|
return nb::cast(it->second);
|
||||||
}
|
}
|
||||||
@ -224,12 +224,12 @@ void tree_replace(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
||||||
std::vector<array> flat_tree;
|
std::vector<mx::array> flat_tree;
|
||||||
|
|
||||||
tree_visit(tree, [&](nb::handle obj) {
|
tree_visit(tree, [&](nb::handle obj) {
|
||||||
if (nb::isinstance<array>(obj)) {
|
if (nb::isinstance<mx::array>(obj)) {
|
||||||
flat_tree.push_back(nb::cast<array>(obj));
|
flat_tree.push_back(nb::cast<mx::array>(obj));
|
||||||
} else if (strict) {
|
} else if (strict) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tree_flatten] The argument should contain only arrays");
|
"[tree_flatten] The argument should contain only arrays");
|
||||||
@ -241,10 +241,10 @@ std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
|||||||
|
|
||||||
nb::object tree_unflatten(
|
nb::object tree_unflatten(
|
||||||
nb::object tree,
|
nb::object tree,
|
||||||
const std::vector<array>& values,
|
const std::vector<mx::array>& values,
|
||||||
int index /* = 0 */) {
|
int index /* = 0 */) {
|
||||||
return tree_map(tree, [&](nb::handle obj) {
|
return tree_map(tree, [&](nb::handle obj) {
|
||||||
if (nb::isinstance<array>(obj)) {
|
if (nb::isinstance<mx::array>(obj)) {
|
||||||
return nb::cast(values[index++]);
|
return nb::cast(values[index++]);
|
||||||
} else {
|
} else {
|
||||||
return nb::cast<nb::object>(obj);
|
return nb::cast<nb::object>(obj);
|
||||||
@ -265,16 +265,16 @@ nb::object structure_sentinel() {
|
|||||||
return sentinel;
|
return sentinel;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
|
||||||
nb::object tree,
|
nb::object tree,
|
||||||
bool strict /* = true */) {
|
bool strict /* = true */) {
|
||||||
auto sentinel = structure_sentinel();
|
auto sentinel = structure_sentinel();
|
||||||
std::vector<array> flat_tree;
|
std::vector<mx::array> flat_tree;
|
||||||
auto structure = tree_map(
|
auto structure = tree_map(
|
||||||
tree,
|
tree,
|
||||||
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
|
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
|
||||||
if (nb::isinstance<array>(obj)) {
|
if (nb::isinstance<mx::array>(obj)) {
|
||||||
flat_tree.push_back(nb::cast<array>(obj));
|
flat_tree.push_back(nb::cast<mx::array>(obj));
|
||||||
return sentinel;
|
return sentinel;
|
||||||
} else if (!strict) {
|
} else if (!strict) {
|
||||||
return nb::cast<nb::object>(obj);
|
return nb::cast<nb::object>(obj);
|
||||||
@ -289,7 +289,7 @@ std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
|||||||
|
|
||||||
nb::object tree_unflatten_from_structure(
|
nb::object tree_unflatten_from_structure(
|
||||||
nb::object structure,
|
nb::object structure,
|
||||||
const std::vector<array>& values,
|
const std::vector<mx::array>& values,
|
||||||
int index /* = 0 */) {
|
int index /* = 0 */) {
|
||||||
auto sentinel = structure_sentinel();
|
auto sentinel = structure_sentinel();
|
||||||
return tree_map(structure, [&](nb::handle obj) {
|
return tree_map(structure, [&](nb::handle obj) {
|
||||||
|
@ -4,8 +4,8 @@
|
|||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
void tree_visit(
|
void tree_visit(
|
||||||
const std::vector<nb::object>& trees,
|
const std::vector<nb::object>& trees,
|
||||||
@ -27,7 +27,7 @@ void tree_visit_update(
|
|||||||
/**
|
/**
|
||||||
* Fill a pytree (recursive dict or list of dict or list) in place with the
|
* Fill a pytree (recursive dict or list of dict or list) in place with the
|
||||||
* given arrays. */
|
* given arrays. */
|
||||||
void tree_fill(nb::object& tree, const std::vector<array>& values);
|
void tree_fill(nb::object& tree, const std::vector<mx::array>& values);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Replace all the arrays from the src values with the dst values in the
|
* Replace all the arrays from the src values with the dst values in the
|
||||||
@ -35,28 +35,28 @@ void tree_fill(nb::object& tree, const std::vector<array>& values);
|
|||||||
*/
|
*/
|
||||||
void tree_replace(
|
void tree_replace(
|
||||||
nb::object& tree,
|
nb::object& tree,
|
||||||
const std::vector<array>& src,
|
const std::vector<mx::array>& src,
|
||||||
const std::vector<array>& dst);
|
const std::vector<mx::array>& dst);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Flatten a tree into a vector of arrays. If strict is true, then the
|
* Flatten a tree into a vector of arrays. If strict is true, then the
|
||||||
* function will throw if the tree contains a leaf which is not an array.
|
* function will throw if the tree contains a leaf which is not an array.
|
||||||
*/
|
*/
|
||||||
std::vector<array> tree_flatten(nb::object tree, bool strict = true);
|
std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unflatten a tree from a vector of arrays.
|
* Unflatten a tree from a vector of arrays.
|
||||||
*/
|
*/
|
||||||
nb::object tree_unflatten(
|
nb::object tree_unflatten(
|
||||||
nb::object tree,
|
nb::object tree,
|
||||||
const std::vector<array>& values,
|
const std::vector<mx::array>& values,
|
||||||
int index = 0);
|
int index = 0);
|
||||||
|
|
||||||
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
|
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
|
||||||
nb::object tree,
|
nb::object tree,
|
||||||
bool strict = true);
|
bool strict = true);
|
||||||
|
|
||||||
nb::object tree_unflatten_from_structure(
|
nb::object tree_unflatten_from_structure(
|
||||||
nb::object structure,
|
nb::object structure,
|
||||||
const std::vector<array>& values,
|
const std::vector<mx::array>& values,
|
||||||
int index = 0);
|
int index = 0);
|
||||||
|
@ -4,22 +4,24 @@
|
|||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "python/src/convert.h"
|
#include "python/src/convert.h"
|
||||||
|
|
||||||
array to_array(
|
mx::array to_array(
|
||||||
const ScalarOrArray& v,
|
const ScalarOrArray& v,
|
||||||
std::optional<Dtype> dtype /* = std::nullopt */) {
|
std::optional<mx::Dtype> dtype /* = std::nullopt */) {
|
||||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||||
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
|
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
|
||||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||||
auto out_t = dtype.value_or(int32);
|
auto out_t = dtype.value_or(mx::int32);
|
||||||
// bool_ is an exception and is always promoted
|
// bool_ is an exception and is always promoted
|
||||||
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
|
return mx::array(
|
||||||
|
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
|
||||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||||
auto out_t = dtype.value_or(float32);
|
auto out_t = dtype.value_or(mx::float32);
|
||||||
return array(
|
return mx::array(
|
||||||
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
|
nb::cast<float>(*pv),
|
||||||
|
mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32);
|
||||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
return array(static_cast<complex64_t>(*pv), complex64);
|
return mx::array(static_cast<mx::complex64_t>(*pv), mx::complex64);
|
||||||
} else if (auto pv = std::get_if<array>(&v); pv) {
|
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
|
||||||
return *pv;
|
return *pv;
|
||||||
} else if (auto pv = std::get_if<
|
} else if (auto pv = std::get_if<
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
||||||
@ -30,7 +32,7 @@ array to_array(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<array, array> to_arrays(
|
std::pair<mx::array, mx::array> to_arrays(
|
||||||
const ScalarOrArray& a,
|
const ScalarOrArray& a,
|
||||||
const ScalarOrArray& b) {
|
const ScalarOrArray& b) {
|
||||||
// Four cases:
|
// Four cases:
|
||||||
@ -39,15 +41,15 @@ std::pair<array, array> to_arrays(
|
|||||||
// - If b is an array but a is not, treat a as a weak python type
|
// - If b is an array but a is not, treat a as a weak python type
|
||||||
// - If neither is an array convert to arrays but leave their types alone
|
// - If neither is an array convert to arrays but leave their types alone
|
||||||
auto is_mlx_array = [](const ScalarOrArray& x) {
|
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||||
return std::holds_alternative<array>(x) ||
|
return std::holds_alternative<mx::array>(x) ||
|
||||||
std::holds_alternative<nb::object>(x) &&
|
std::holds_alternative<nb::object>(x) &&
|
||||||
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
|
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
|
||||||
};
|
};
|
||||||
auto get_mlx_array = [](const ScalarOrArray& x) {
|
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||||
if (auto px = std::get_if<array>(&x); px) {
|
if (auto px = std::get_if<mx::array>(&x); px) {
|
||||||
return *px;
|
return *px;
|
||||||
} else {
|
} else {
|
||||||
return nb::cast<array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -66,11 +68,11 @@ std::pair<array, array> to_arrays(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array to_array_with_accessor(nb::object obj) {
|
mx::array to_array_with_accessor(nb::object obj) {
|
||||||
if (nb::isinstance<array>(obj)) {
|
if (nb::isinstance<mx::array>(obj)) {
|
||||||
return nb::cast<array>(obj);
|
return nb::cast<mx::array>(obj);
|
||||||
} else if (nb::hasattr(obj, "__mlx_array__")) {
|
} else if (nb::hasattr(obj, "__mlx_array__")) {
|
||||||
return nb::cast<array>(obj.attr("__mlx_array__")());
|
return nb::cast<mx::array>(obj.attr("__mlx_array__")());
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
|
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
|
||||||
|
@ -12,17 +12,16 @@
|
|||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||||
using ScalarOrArray = std::variant<
|
using ScalarOrArray = std::variant<
|
||||||
nb::bool_,
|
nb::bool_,
|
||||||
nb::int_,
|
nb::int_,
|
||||||
nb::float_,
|
nb::float_,
|
||||||
// Must be above ndarray
|
// Must be above ndarray
|
||||||
array,
|
mx::array,
|
||||||
// Must be above complex
|
// Must be above complex
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||||
std::complex<float>,
|
std::complex<float>,
|
||||||
@ -45,7 +44,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
|||||||
// Checks if the value can be compared to an array (or is already an
|
// Checks if the value can be compared to an array (or is already an
|
||||||
// mlx array)
|
// mlx array)
|
||||||
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
||||||
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
||||||
} else {
|
} else {
|
||||||
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
||||||
// and can be compared to an array
|
// and can be compared to an array
|
||||||
@ -66,12 +65,12 @@ inline void throw_invalid_operation(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
array to_array(
|
mx::array to_array(
|
||||||
const ScalarOrArray& v,
|
const ScalarOrArray& v,
|
||||||
std::optional<Dtype> dtype = std::nullopt);
|
std::optional<mx::Dtype> dtype = std::nullopt);
|
||||||
|
|
||||||
std::pair<array, array> to_arrays(
|
std::pair<mx::array, mx::array> to_arrays(
|
||||||
const ScalarOrArray& a,
|
const ScalarOrArray& a,
|
||||||
const ScalarOrArray& b);
|
const ScalarOrArray& b);
|
||||||
|
|
||||||
array to_array_with_accessor(nb::object obj);
|
mx::array to_array_with_accessor(nb::object obj);
|
||||||
|
Loading…
Reference in New Issue
Block a user