mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-23 16:46:50 +08:00
parent
a6b426422e
commit
f110357aaa
@ -85,7 +85,7 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
@ -137,7 +137,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@ -226,7 +226,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
@ -291,7 +291,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
|
@ -239,8 +239,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
OUTPUT_VARIABLE nanobind_ROOT)
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||
endif()
|
||||
|
@ -18,8 +18,7 @@ find_package(
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
OUTPUT_VARIABLE nanobind_ROOT)
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
@ -3,6 +3,6 @@ requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.25",
|
||||
"mlx>=0.18.0",
|
||||
"nanobind==2.2.0",
|
||||
"nanobind==2.4.0",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@ -1,7 +1,7 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"nanobind==2.2.0",
|
||||
"nanobind==2.4.0",
|
||||
"cmake>=3.25",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@ -193,6 +193,15 @@ void init_array(nb::module_& m) {
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
|
||||
nb::class_<ArrayLike>(
|
||||
m,
|
||||
"ArrayLike",
|
||||
R"pbdoc(
|
||||
Any Python object which has an ``__mlx__array__`` method that
|
||||
returns an :obj:`array`.
|
||||
)pbdoc")
|
||||
.def(nb::init_implicit<nb::object>());
|
||||
|
||||
nb::class_<ArrayPythonIterator>(
|
||||
m,
|
||||
"ArrayIterator",
|
||||
|
@ -477,7 +477,7 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
||||
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
|
||||
return mx::astype(*pv, t.value_or((*pv).dtype()));
|
||||
} else {
|
||||
auto arr = to_array_with_accessor(std::get<nb::object>(v));
|
||||
auto arr = to_array_with_accessor(std::get<ArrayLike>(v).obj);
|
||||
return mx::astype(arr, t.value_or(arr.dtype()));
|
||||
}
|
||||
}
|
||||
|
@ -12,6 +12,11 @@
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
struct ArrayLike {
|
||||
ArrayLike(nb::object obj) : obj(obj) {};
|
||||
nb::object obj;
|
||||
};
|
||||
|
||||
using ArrayInitType = std::variant<
|
||||
nb::bool_,
|
||||
nb::int_,
|
||||
@ -23,7 +28,7 @@ using ArrayInitType = std::variant<
|
||||
std::complex<float>,
|
||||
nb::list,
|
||||
nb::tuple,
|
||||
nb::object>;
|
||||
ArrayLike>;
|
||||
|
||||
mx::array nd_array_to_mlx(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
|
@ -28,7 +28,7 @@ mx::array to_array(
|
||||
pv) {
|
||||
return nd_array_to_mlx(*pv, dtype);
|
||||
} else {
|
||||
return to_array_with_accessor(std::get<nb::object>(v));
|
||||
return to_array_with_accessor(std::get<ArrayLike>(v).obj);
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,14 +42,15 @@ std::pair<mx::array, mx::array> to_arrays(
|
||||
// - If neither is an array convert to arrays but leave their types alone
|
||||
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||
return std::holds_alternative<mx::array>(x) ||
|
||||
std::holds_alternative<nb::object>(x) &&
|
||||
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
|
||||
std::holds_alternative<ArrayLike>(x) &&
|
||||
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
|
||||
};
|
||||
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||
if (auto px = std::get_if<mx::array>(&x); px) {
|
||||
return *px;
|
||||
} else {
|
||||
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
||||
return nb::cast<mx::array>(
|
||||
std::get<ArrayLike>(x).obj.attr("__mlx_array__"));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "python/src/convert.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
@ -25,7 +26,7 @@ using ScalarOrArray = std::variant<
|
||||
// Must be above complex
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||
std::complex<float>,
|
||||
nb::object>;
|
||||
ArrayLike>;
|
||||
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
@ -43,8 +44,9 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
// Checks if the value can be compared to an array (or is already an
|
||||
// mlx array)
|
||||
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
||||
return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
||||
if (auto pv = std::get_if<ArrayLike>(&v); pv) {
|
||||
auto obj = (*pv).obj;
|
||||
return nb::isinstance<mx::array>(obj) || nb::hasattr(obj, "__mlx_array__");
|
||||
} else {
|
||||
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
||||
// and can be compared to an array
|
||||
@ -53,7 +55,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
}
|
||||
|
||||
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
|
||||
return std::get<nb::object>(v).ptr();
|
||||
return std::get<ArrayLike>(v).obj.ptr();
|
||||
}
|
||||
|
||||
inline void throw_invalid_operation(
|
||||
|
Loading…
Reference in New Issue
Block a user