mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 20:07:59 +08:00
fix conversion to array (#1070)
This commit is contained in:
parent
6992498e7a
commit
9814a2ae12
@ -18,6 +18,7 @@ nanobind_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||||
|
80
python/src/utils.cpp
Normal file
80
python/src/utils.cpp
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "python/src/utils.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "python/src/convert.h"
|
||||||
|
|
||||||
|
array to_array(
|
||||||
|
const ScalarOrArray& v,
|
||||||
|
std::optional<Dtype> dtype /* = std::nullopt */) {
|
||||||
|
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||||
|
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
|
||||||
|
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||||
|
auto out_t = dtype.value_or(int32);
|
||||||
|
// bool_ is an exception and is always promoted
|
||||||
|
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
|
||||||
|
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||||
|
auto out_t = dtype.value_or(float32);
|
||||||
|
return array(
|
||||||
|
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
|
||||||
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
|
return array(static_cast<complex64_t>(*pv), complex64);
|
||||||
|
} else if (auto pv = std::get_if<array>(&v); pv) {
|
||||||
|
return *pv;
|
||||||
|
} else if (auto pv = std::get_if<
|
||||||
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
||||||
|
pv) {
|
||||||
|
return nd_array_to_mlx(*pv, dtype);
|
||||||
|
} else {
|
||||||
|
return to_array_with_accessor(std::get<nb::object>(v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, array> to_arrays(
|
||||||
|
const ScalarOrArray& a,
|
||||||
|
const ScalarOrArray& b) {
|
||||||
|
// Four cases:
|
||||||
|
// - If both a and b are arrays leave their types alone
|
||||||
|
// - If a is an array but b is not, treat b as a weak python type
|
||||||
|
// - If b is an array but a is not, treat a as a weak python type
|
||||||
|
// - If neither is an array convert to arrays but leave their types alone
|
||||||
|
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||||
|
return std::holds_alternative<array>(x) ||
|
||||||
|
std::holds_alternative<nb::object>(x) &&
|
||||||
|
nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
|
||||||
|
};
|
||||||
|
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||||
|
if (auto px = std::get_if<array>(&x); px) {
|
||||||
|
return *px;
|
||||||
|
} else {
|
||||||
|
return nb::cast<array>(std::get<nb::object>(x).attr("__mlx_array__"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (is_mlx_array(a)) {
|
||||||
|
auto arr_a = get_mlx_array(a);
|
||||||
|
if (is_mlx_array(b)) {
|
||||||
|
auto arr_b = get_mlx_array(b);
|
||||||
|
return {arr_a, arr_b};
|
||||||
|
}
|
||||||
|
return {arr_a, to_array(b, arr_a.dtype())};
|
||||||
|
} else if (is_mlx_array(b)) {
|
||||||
|
auto arr_b = get_mlx_array(b);
|
||||||
|
return {to_array(a, arr_b.dtype()), arr_b};
|
||||||
|
} else {
|
||||||
|
return {to_array(a), to_array(b)};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array to_array_with_accessor(nb::object obj) {
|
||||||
|
if (nb::isinstance<array>(obj)) {
|
||||||
|
return nb::cast<array>(obj);
|
||||||
|
} else if (nb::hasattr(obj, "__mlx_array__")) {
|
||||||
|
return nb::cast<array>(obj.attr("__mlx_array__")());
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
|
||||||
|
<< " received in array initialization.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,7 @@
|
|||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
|
#include <nanobind/ndarray.h>
|
||||||
#include <nanobind/stl/complex.h>
|
#include <nanobind/stl/complex.h>
|
||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
|
|
||||||
@ -16,8 +17,16 @@ namespace nb = nanobind;
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||||
using ScalarOrArray = std::
|
using ScalarOrArray = std::variant<
|
||||||
variant<nb::bool_, nb::int_, nb::float_, std::complex<float>, nb::object>;
|
nb::bool_,
|
||||||
|
nb::int_,
|
||||||
|
nb::float_,
|
||||||
|
// Must be above ndarray
|
||||||
|
array,
|
||||||
|
// Must be above complex
|
||||||
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||||
|
std::complex<float>,
|
||||||
|
nb::object>;
|
||||||
|
|
||||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||||
std::vector<int> axes;
|
std::vector<int> axes;
|
||||||
@ -32,19 +41,6 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
|||||||
return axes;
|
return axes;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline array to_array_with_accessor(nb::object obj) {
|
|
||||||
if (nb::isinstance<array>(obj)) {
|
|
||||||
return nb::cast<array>(obj);
|
|
||||||
} else if (nb::hasattr(obj, "__mlx_array__")) {
|
|
||||||
return nb::cast<array>(obj.attr("__mlx_array__")());
|
|
||||||
} else {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
|
|
||||||
<< " received in array initialization.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||||
// Checks if the value can be compared to an array (or is already an
|
// Checks if the value can be compared to an array (or is already an
|
||||||
// mlx array)
|
// mlx array)
|
||||||
@ -70,45 +66,12 @@ inline void throw_invalid_operation(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline array to_array(
|
array to_array(
|
||||||
const ScalarOrArray& v,
|
const ScalarOrArray& v,
|
||||||
std::optional<Dtype> dtype = std::nullopt) {
|
std::optional<Dtype> dtype = std::nullopt);
|
||||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
|
||||||
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
|
|
||||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
|
||||||
auto out_t = dtype.value_or(int32);
|
|
||||||
// bool_ is an exception and is always promoted
|
|
||||||
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
|
|
||||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
|
||||||
auto out_t = dtype.value_or(float32);
|
|
||||||
return array(
|
|
||||||
nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
|
|
||||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
|
||||||
return array(static_cast<complex64_t>(*pv), complex64);
|
|
||||||
} else {
|
|
||||||
return to_array_with_accessor(std::get<nb::object>(v));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline std::pair<array, array> to_arrays(
|
std::pair<array, array> to_arrays(
|
||||||
const ScalarOrArray& a,
|
const ScalarOrArray& a,
|
||||||
const ScalarOrArray& b) {
|
const ScalarOrArray& b);
|
||||||
// Four cases:
|
|
||||||
// - If both a and b are arrays leave their types alone
|
array to_array_with_accessor(nb::object obj);
|
||||||
// - If a is an array but b is not, treat b as a weak python type
|
|
||||||
// - If b is an array but a is not, treat a as a weak python type
|
|
||||||
// - If neither is an array convert to arrays but leave their types alone
|
|
||||||
if (auto pa = std::get_if<nb::object>(&a); pa) {
|
|
||||||
auto arr_a = to_array_with_accessor(*pa);
|
|
||||||
if (auto pb = std::get_if<nb::object>(&b); pb) {
|
|
||||||
auto arr_b = to_array_with_accessor(*pb);
|
|
||||||
return {arr_a, arr_b};
|
|
||||||
}
|
|
||||||
return {arr_a, to_array(b, arr_a.dtype())};
|
|
||||||
} else if (auto pb = std::get_if<nb::object>(&b); pb) {
|
|
||||||
auto arr_b = to_array_with_accessor(*pb);
|
|
||||||
return {to_array(a, arr_b.dtype()), arr_b};
|
|
||||||
} else {
|
|
||||||
return {to_array(a), to_array(b)};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1710,6 +1710,13 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
peak_2 = mx.metal.get_peak_memory()
|
peak_2 = mx.metal.get_peak_memory()
|
||||||
self.assertEqual(peak_1, peak_2)
|
self.assertEqual(peak_1, peak_2)
|
||||||
|
|
||||||
|
def test_add_numpy(self):
|
||||||
|
x = mx.array(1)
|
||||||
|
y = np.array(2, dtype=np.int32)
|
||||||
|
z = x + y
|
||||||
|
self.assertEqual(z.dtype, mx.int32)
|
||||||
|
self.assertEqual(z.item(), 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user