diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 83acef1e2..ae0531385 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -18,6 +18,7 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/utils.cpp b/python/src/utils.cpp new file mode 100644 index 000000000..5d1118b80 --- /dev/null +++ b/python/src/utils.cpp @@ -0,0 +1,80 @@ +// Copyright © 2024 Apple Inc. + +#include "python/src/utils.h" +#include "mlx/ops.h" +#include "python/src/convert.h" + +array to_array( + const ScalarOrArray& v, + std::optional dtype /* = std::nullopt */) { + if (auto pv = std::get_if(&v); pv) { + return array(nb::cast(*pv), dtype.value_or(bool_)); + } else if (auto pv = std::get_if(&v); pv) { + auto out_t = dtype.value_or(int32); + // bool_ is an exception and is always promoted + return array(nb::cast(*pv), (out_t == bool_) ? int32 : out_t); + } else if (auto pv = std::get_if(&v); pv) { + auto out_t = dtype.value_or(float32); + return array( + nb::cast(*pv), issubdtype(out_t, floating) ? out_t : float32); + } else if (auto pv = std::get_if>(&v); pv) { + return array(static_cast(*pv), complex64); + } else if (auto pv = std::get_if(&v); pv) { + return *pv; + } else if (auto pv = std::get_if< + nb::ndarray>(&v); + pv) { + return nd_array_to_mlx(*pv, dtype); + } else { + return to_array_with_accessor(std::get(v)); + } +} + +std::pair to_arrays( + const ScalarOrArray& a, + const ScalarOrArray& b) { + // Four cases: + // - If both a and b are arrays leave their types alone + // - If a is an array but b is not, treat b as a weak python type + // - If b is an array but a is not, treat a as a weak python type + // - If neither is an array convert to arrays but leave their types alone + auto is_mlx_array = [](const ScalarOrArray& x) { + return std::holds_alternative(x) || + std::holds_alternative(x) && + nb::hasattr(std::get(x), "__mlx_array__"); + }; + auto get_mlx_array = [](const ScalarOrArray& x) { + if (auto px = std::get_if(&x); px) { + return *px; + } else { + return nb::cast(std::get(x).attr("__mlx_array__")); + } + }; + + if (is_mlx_array(a)) { + auto arr_a = get_mlx_array(a); + if (is_mlx_array(b)) { + auto arr_b = get_mlx_array(b); + return {arr_a, arr_b}; + } + return {arr_a, to_array(b, arr_a.dtype())}; + } else if (is_mlx_array(b)) { + auto arr_b = get_mlx_array(b); + return {to_array(a, arr_b.dtype()), arr_b}; + } else { + return {to_array(a), to_array(b)}; + } +} + +array to_array_with_accessor(nb::object obj) { + if (nb::isinstance(obj)) { + return nb::cast(obj); + } else if (nb::hasattr(obj, "__mlx_array__")) { + return nb::cast(obj.attr("__mlx_array__")()); + } else { + std::ostringstream msg; + msg << "Invalid type " << nb::type_name(obj.type()).c_str() + << " received in array initialization."; + throw std::invalid_argument(msg.str()); + } +} diff --git a/python/src/utils.h b/python/src/utils.h index bd29df2d6..3d5b1af97 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -16,8 +17,16 @@ namespace nb = nanobind; using namespace mlx::core; using IntOrVec = std::variant>; -using ScalarOrArray = std:: - variant, nb::object>; +using ScalarOrArray = std::variant< + nb::bool_, + nb::int_, + nb::float_, + // Must be above ndarray + array, + // Must be above complex + nb::ndarray, + std::complex, + nb::object>; inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { std::vector axes; @@ -32,19 +41,6 @@ inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { return axes; } -inline array to_array_with_accessor(nb::object obj) { - if (nb::isinstance(obj)) { - return nb::cast(obj); - } else if (nb::hasattr(obj, "__mlx_array__")) { - return nb::cast(obj.attr("__mlx_array__")()); - } else { - std::ostringstream msg; - msg << "Invalid type " << nb::type_name(obj.type()).c_str() - << " received in array initialization."; - throw std::invalid_argument(msg.str()); - } -} - inline bool is_comparable_with_array(const ScalarOrArray& v) { // Checks if the value can be compared to an array (or is already an // mlx array) @@ -70,45 +66,12 @@ inline void throw_invalid_operation( throw std::invalid_argument(msg.str()); } -inline array to_array( +array to_array( const ScalarOrArray& v, - std::optional dtype = std::nullopt) { - if (auto pv = std::get_if(&v); pv) { - return array(nb::cast(*pv), dtype.value_or(bool_)); - } else if (auto pv = std::get_if(&v); pv) { - auto out_t = dtype.value_or(int32); - // bool_ is an exception and is always promoted - return array(nb::cast(*pv), (out_t == bool_) ? int32 : out_t); - } else if (auto pv = std::get_if(&v); pv) { - auto out_t = dtype.value_or(float32); - return array( - nb::cast(*pv), issubdtype(out_t, floating) ? out_t : float32); - } else if (auto pv = std::get_if>(&v); pv) { - return array(static_cast(*pv), complex64); - } else { - return to_array_with_accessor(std::get(v)); - } -} + std::optional dtype = std::nullopt); -inline std::pair to_arrays( +std::pair to_arrays( const ScalarOrArray& a, - const ScalarOrArray& b) { - // Four cases: - // - If both a and b are arrays leave their types alone - // - If a is an array but b is not, treat b as a weak python type - // - If b is an array but a is not, treat a as a weak python type - // - If neither is an array convert to arrays but leave their types alone - if (auto pa = std::get_if(&a); pa) { - auto arr_a = to_array_with_accessor(*pa); - if (auto pb = std::get_if(&b); pb) { - auto arr_b = to_array_with_accessor(*pb); - return {arr_a, arr_b}; - } - return {arr_a, to_array(b, arr_a.dtype())}; - } else if (auto pb = std::get_if(&b); pb) { - auto arr_b = to_array_with_accessor(*pb); - return {to_array(a, arr_b.dtype()), arr_b}; - } else { - return {to_array(a), to_array(b)}; - } -} + const ScalarOrArray& b); + +array to_array_with_accessor(nb::object obj); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4faa3ec1c..fc97a833f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1710,6 +1710,13 @@ class TestArray(mlx_tests.MLXTestCase): peak_2 = mx.metal.get_peak_memory() self.assertEqual(peak_1, peak_2) + def test_add_numpy(self): + x = mx.array(1) + y = np.array(2, dtype=np.int32) + z = x + y + self.assertEqual(z.dtype, mx.int32) + self.assertEqual(z.item(), 3) + if __name__ == "__main__": unittest.main()