mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -11,6 +11,7 @@
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "python/src/convert.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
@@ -25,7 +26,7 @@ using ScalarOrArray = std::variant<
|
||||
// Must be above complex
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||
std::complex<float>,
|
||||
nb::object>;
|
||||
ArrayLike>;
|
||||
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
@@ -43,8 +44,9 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
// Checks if the value can be compared to an array (or is already an
|
||||
// mlx array)
|
||||
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
||||
return nb::isinstance<mx::array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
||||
if (auto pv = std::get_if<ArrayLike>(&v); pv) {
|
||||
auto obj = (*pv).obj;
|
||||
return nb::isinstance<mx::array>(obj) || nb::hasattr(obj, "__mlx_array__");
|
||||
} else {
|
||||
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
||||
// and can be compared to an array
|
||||
@@ -53,7 +55,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
}
|
||||
|
||||
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
|
||||
return std::get<nb::object>(v).ptr();
|
||||
return std::get<ArrayLike>(v).obj.ptr();
|
||||
}
|
||||
|
||||
inline void throw_invalid_operation(
|
||||
|
||||
Reference in New Issue
Block a user