mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
parent
a6b426422e
commit
f110357aaa
@ -85,7 +85,7 @@ jobs:
|
|||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install nanobind==2.2.0
|
pip install nanobind==2.4.0
|
||||||
pip install numpy
|
pip install numpy
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
@ -137,7 +137,7 @@ jobs:
|
|||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install nanobind==2.2.0
|
pip install nanobind==2.4.0
|
||||||
pip install numpy
|
pip install numpy
|
||||||
pip install torch
|
pip install torch
|
||||||
pip install tensorflow
|
pip install tensorflow
|
||||||
@ -226,7 +226,7 @@ jobs:
|
|||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install nanobind==2.2.0
|
pip install nanobind==2.4.0
|
||||||
pip install --upgrade setuptools
|
pip install --upgrade setuptools
|
||||||
pip install numpy
|
pip install numpy
|
||||||
pip install twine
|
pip install twine
|
||||||
@ -291,7 +291,7 @@ jobs:
|
|||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install nanobind==2.2.0
|
pip install nanobind==2.4.0
|
||||||
pip install --upgrade setuptools
|
pip install --upgrade setuptools
|
||||||
pip install numpy
|
pip install numpy
|
||||||
pip install auditwheel
|
pip install auditwheel
|
||||||
|
@ -239,8 +239,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
|
|||||||
execute_process(
|
execute_process(
|
||||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
OUTPUT_VARIABLE NB_DIR)
|
OUTPUT_VARIABLE nanobind_ROOT)
|
||||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
|
||||||
find_package(nanobind CONFIG REQUIRED)
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||||
endif()
|
endif()
|
||||||
|
@ -18,8 +18,7 @@ find_package(
|
|||||||
execute_process(
|
execute_process(
|
||||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
OUTPUT_VARIABLE NB_DIR)
|
OUTPUT_VARIABLE nanobind_ROOT)
|
||||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
|
||||||
find_package(nanobind CONFIG REQUIRED)
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
|
|
||||||
# ----------------------------- Extensions -----------------------------
|
# ----------------------------- Extensions -----------------------------
|
||||||
|
@ -3,6 +3,6 @@ requires = [
|
|||||||
"setuptools>=42",
|
"setuptools>=42",
|
||||||
"cmake>=3.25",
|
"cmake>=3.25",
|
||||||
"mlx>=0.18.0",
|
"mlx>=0.18.0",
|
||||||
"nanobind==2.2.0",
|
"nanobind==2.4.0",
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = [
|
requires = [
|
||||||
"setuptools>=42",
|
"setuptools>=42",
|
||||||
"nanobind==2.2.0",
|
"nanobind==2.4.0",
|
||||||
"cmake>=3.25",
|
"cmake>=3.25",
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
@ -193,6 +193,15 @@ void init_array(nb::module_& m) {
|
|||||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||||
|
|
||||||
|
nb::class_<ArrayLike>(
|
||||||
|
m,
|
||||||
|
"ArrayLike",
|
||||||
|
R"pbdoc(
|
||||||
|
Any Python object which has an ``__mlx__array__`` method that
|
||||||
|
returns an :obj:`array`.
|
||||||
|
)pbdoc")
|
||||||
|
.def(nb::init_implicit<nb::object>());
|
||||||
|
|
||||||
nb::class_<ArrayPythonIterator>(
|
nb::class_<ArrayPythonIterator>(
|
||||||
m,
|
m,
|
||||||
"ArrayIterator",
|
"ArrayIterator",
|
||||||
|
@ -477,7 +477,7 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
|||||||
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
|
} else if (auto pv = std::get_if<mx::array>(&v); pv) {
|
||||||
return mx::astype(*pv, t.value_or((*pv).dtype()));
|
return mx::astype(*pv, t.value_or((*pv).dtype()));
|
||||||
} else {
|
} else {
|
||||||
auto arr = to_array_with_accessor(std::get<nb::object>(v));
|
auto arr = to_array_with_accessor(std::get<ArrayLike>(v).obj);
|
||||||
return mx::astype(arr, t.value_or(arr.dtype()));
|
return mx::astype(arr, t.value_or(arr.dtype()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,11 @@
|
|||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
|
||||||
|
struct ArrayLike {
|
||||||
|
ArrayLike(nb::object obj) : obj(obj) {};
|
||||||
|
nb::object obj;
|
||||||
|
};
|
||||||
|
|
||||||
using ArrayInitType = std::variant<
|
using ArrayInitType = std::variant<
|
||||||
nb::bool_,
|
nb::bool_,
|
||||||
nb::int_,
|
nb::int_,
|
||||||
@ -23,7 +28,7 @@ using ArrayInitType = std::variant<
|
|||||||
std::complex<float>,
|
std::complex<float>,
|
||||||
nb::list,
|
nb::list,
|
||||||
nb::tuple,
|
nb::tuple,
|
||||||
nb::object>;
|
ArrayLike>;
|
||||||
|
|
||||||
mx::array nd_array_to_mlx(
|
mx::array nd_array_to_mlx(
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||||
|
@ -28,7 +28,7 @@ mx::array to_array(
|
|||||||
pv) {
|
pv) {
|
||||||
return nd_array_to_mlx(*pv, dtype);
|
return nd_array_to_mlx(*pv, dtype);
|
||||||
} else {
|
} else {
|
||||||
return to_array_with_accessor(std::get<nb::object>(v));
|
return to_array_with_accessor(std::get<ArrayLike>(v).obj);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,14 +42,15 @@ std::pair<mx::array, mx::array> to_arrays(
|
|||||||
// - If neither is an array convert to arrays but leave their types alone
|
// - If neither is an array convert to arrays but leave their types alone
|
||||||
auto is_mlx_array = [](const ScalarOrArray& x) {
|
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||||
return std::holds_alternative<mx::array>(x) ||
|
return std::holds_alternative<mx::array>(x) ||
|
||||||
std::holds_alternative<nb::object>(x) &&
|
std::holds_alternative<ArrayLike>(x) &&
|
||||||
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
|
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
|
||||||
};
|
};
|
||||||
auto get_mlx_array = [](const ScalarOrArray& x) {
|
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||||
if (auto px = std::get_if<mx::array>(&x); px) {
|
if (auto px = std::get_if<mx::array>(&x); px) {
|
||||||
return *px;
|
return *px;
|
||||||
} else {
|
} else {
|
||||||
return nb::cast<mx::array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
return nb::cast<mx::array>(
|
||||||
|
std::get<ArrayLike>(x).obj.attr("__mlx_array__"));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "python/src/convert.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
@ -25,7 +26,7 @@ using ScalarOrArray = std::variant<
|
|||||||
// Must be above complex
|
// Must be above complex
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||||
std::complex<float>,
|
std::complex<float>,
|
||||||
nb::object>;
|
ArrayLike>;
|
||||||
|
|
||||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||||
std::vector<int> axes;
|
std::vector<int> axes;
|
||||||
@ -43,8 +44,9 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
|||||||
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||||
// Checks if the value can be compared to an array (or is already an
|
// Checks if the value can be compared to an array (or is already an
|
||||||
// mlx array)
|
// mlx array)
|
||||||
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
if (auto pv = std::get_if<ArrayLike>(&v); pv) {
|
||||||
return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
auto obj = (*pv).obj;
|
||||||
|
return nb::isinstance<mx::array>(obj) || nb::hasattr(obj, "__mlx_array__");
|
||||||
} else {
|
} else {
|
||||||
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
||||||
// and can be compared to an array
|
// and can be compared to an array
|
||||||
@ -53,7 +55,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
|
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
|
||||||
return std::get<nb::object>(v).ptr();
|
return std::get<ArrayLike>(v).obj.ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void throw_invalid_operation(
|
inline void throw_invalid_operation(
|
||||||
|
Loading…
Reference in New Issue
Block a user