mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix conversion to array (#1070)
This commit is contained in:
		@@ -18,6 +18,7 @@ nanobind_add_module(
 | 
				
			|||||||
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
 | 
					  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
 | 
				
			||||||
  ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
 | 
					  ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
 | 
				
			||||||
  ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
 | 
					  ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
 | 
				
			||||||
 | 
					  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
 | 
					if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										80
									
								
								python/src/utils.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								python/src/utils.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,80 @@
 | 
				
			|||||||
 | 
					// Copyright © 2024 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "python/src/utils.h"
 | 
				
			||||||
 | 
					#include "mlx/ops.h"
 | 
				
			||||||
 | 
					#include "python/src/convert.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					array to_array(
 | 
				
			||||||
 | 
					    const ScalarOrArray& v,
 | 
				
			||||||
 | 
					    std::optional<Dtype> dtype /* = std::nullopt */) {
 | 
				
			||||||
 | 
					  if (auto pv = std::get_if<nb::bool_>(&v); pv) {
 | 
				
			||||||
 | 
					    return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
 | 
				
			||||||
 | 
					  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
 | 
				
			||||||
 | 
					    auto out_t = dtype.value_or(int32);
 | 
				
			||||||
 | 
					    // bool_ is an exception and is always promoted
 | 
				
			||||||
 | 
					    return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
 | 
				
			||||||
 | 
					  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
 | 
				
			||||||
 | 
					    auto out_t = dtype.value_or(float32);
 | 
				
			||||||
 | 
					    return array(
 | 
				
			||||||
 | 
					        nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
 | 
				
			||||||
 | 
					  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
 | 
				
			||||||
 | 
					    return array(static_cast<complex64_t>(*pv), complex64);
 | 
				
			||||||
 | 
					  } else if (auto pv = std::get_if<array>(&v); pv) {
 | 
				
			||||||
 | 
					    return *pv;
 | 
				
			||||||
 | 
					  } else if (auto pv = std::get_if<
 | 
				
			||||||
 | 
					                 nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
 | 
				
			||||||
 | 
					             pv) {
 | 
				
			||||||
 | 
					    return nd_array_to_mlx(*pv, dtype);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    return to_array_with_accessor(std::get<nb::object>(v));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::pair<array, array> to_arrays(
 | 
				
			||||||
 | 
					    const ScalarOrArray& a,
 | 
				
			||||||
 | 
					    const ScalarOrArray& b) {
 | 
				
			||||||
 | 
					  // Four cases:
 | 
				
			||||||
 | 
					  // - If both a and b are arrays leave their types alone
 | 
				
			||||||
 | 
					  // - If a is an array but b is not, treat b as a weak python type
 | 
				
			||||||
 | 
					  // - If b is an array but a is not, treat a as a weak python type
 | 
				
			||||||
 | 
					  // - If neither is an array convert to arrays but leave their types alone
 | 
				
			||||||
 | 
					  auto is_mlx_array = [](const ScalarOrArray& x) {
 | 
				
			||||||
 | 
					    return std::holds_alternative<array>(x) ||
 | 
				
			||||||
 | 
					        std::holds_alternative<nb::object>(x) &&
 | 
				
			||||||
 | 
					        nb::hasattr(std::get<nb::object>(x), "__mlx_array__");
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					  auto get_mlx_array = [](const ScalarOrArray& x) {
 | 
				
			||||||
 | 
					    if (auto px = std::get_if<array>(&x); px) {
 | 
				
			||||||
 | 
					      return *px;
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      return nb::cast<array>(std::get<nb::object>(x).attr("__mlx_array__"));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (is_mlx_array(a)) {
 | 
				
			||||||
 | 
					    auto arr_a = get_mlx_array(a);
 | 
				
			||||||
 | 
					    if (is_mlx_array(b)) {
 | 
				
			||||||
 | 
					      auto arr_b = get_mlx_array(b);
 | 
				
			||||||
 | 
					      return {arr_a, arr_b};
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return {arr_a, to_array(b, arr_a.dtype())};
 | 
				
			||||||
 | 
					  } else if (is_mlx_array(b)) {
 | 
				
			||||||
 | 
					    auto arr_b = get_mlx_array(b);
 | 
				
			||||||
 | 
					    return {to_array(a, arr_b.dtype()), arr_b};
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    return {to_array(a), to_array(b)};
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					array to_array_with_accessor(nb::object obj) {
 | 
				
			||||||
 | 
					  if (nb::isinstance<array>(obj)) {
 | 
				
			||||||
 | 
					    return nb::cast<array>(obj);
 | 
				
			||||||
 | 
					  } else if (nb::hasattr(obj, "__mlx_array__")) {
 | 
				
			||||||
 | 
					    return nb::cast<array>(obj.attr("__mlx_array__")());
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    std::ostringstream msg;
 | 
				
			||||||
 | 
					    msg << "Invalid type  " << nb::type_name(obj.type()).c_str()
 | 
				
			||||||
 | 
					        << " received in array initialization.";
 | 
				
			||||||
 | 
					    throw std::invalid_argument(msg.str());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -6,6 +6,7 @@
 | 
				
			|||||||
#include <variant>
 | 
					#include <variant>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <nanobind/nanobind.h>
 | 
					#include <nanobind/nanobind.h>
 | 
				
			||||||
 | 
					#include <nanobind/ndarray.h>
 | 
				
			||||||
#include <nanobind/stl/complex.h>
 | 
					#include <nanobind/stl/complex.h>
 | 
				
			||||||
#include <nanobind/stl/variant.h>
 | 
					#include <nanobind/stl/variant.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -16,8 +17,16 @@ namespace nb = nanobind;
 | 
				
			|||||||
using namespace mlx::core;
 | 
					using namespace mlx::core;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
 | 
					using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
 | 
				
			||||||
using ScalarOrArray = std::
 | 
					using ScalarOrArray = std::variant<
 | 
				
			||||||
    variant<nb::bool_, nb::int_, nb::float_, std::complex<float>, nb::object>;
 | 
					    nb::bool_,
 | 
				
			||||||
 | 
					    nb::int_,
 | 
				
			||||||
 | 
					    nb::float_,
 | 
				
			||||||
 | 
					    // Must be above ndarray
 | 
				
			||||||
 | 
					    array,
 | 
				
			||||||
 | 
					    // Must be above complex
 | 
				
			||||||
 | 
					    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
 | 
				
			||||||
 | 
					    std::complex<float>,
 | 
				
			||||||
 | 
					    nb::object>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
 | 
					inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
 | 
				
			||||||
  std::vector<int> axes;
 | 
					  std::vector<int> axes;
 | 
				
			||||||
@@ -32,19 +41,6 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
 | 
				
			|||||||
  return axes;
 | 
					  return axes;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
inline array to_array_with_accessor(nb::object obj) {
 | 
					 | 
				
			||||||
  if (nb::isinstance<array>(obj)) {
 | 
					 | 
				
			||||||
    return nb::cast<array>(obj);
 | 
					 | 
				
			||||||
  } else if (nb::hasattr(obj, "__mlx_array__")) {
 | 
					 | 
				
			||||||
    return nb::cast<array>(obj.attr("__mlx_array__")());
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    std::ostringstream msg;
 | 
					 | 
				
			||||||
    msg << "Invalid type  " << nb::type_name(obj.type()).c_str()
 | 
					 | 
				
			||||||
        << " received in array initialization.";
 | 
					 | 
				
			||||||
    throw std::invalid_argument(msg.str());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
inline bool is_comparable_with_array(const ScalarOrArray& v) {
 | 
					inline bool is_comparable_with_array(const ScalarOrArray& v) {
 | 
				
			||||||
  // Checks if the value can be compared to an array (or is already an
 | 
					  // Checks if the value can be compared to an array (or is already an
 | 
				
			||||||
  // mlx array)
 | 
					  // mlx array)
 | 
				
			||||||
@@ -70,45 +66,12 @@ inline void throw_invalid_operation(
 | 
				
			|||||||
  throw std::invalid_argument(msg.str());
 | 
					  throw std::invalid_argument(msg.str());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
inline array to_array(
 | 
					array to_array(
 | 
				
			||||||
    const ScalarOrArray& v,
 | 
					    const ScalarOrArray& v,
 | 
				
			||||||
    std::optional<Dtype> dtype = std::nullopt) {
 | 
					    std::optional<Dtype> dtype = std::nullopt);
 | 
				
			||||||
  if (auto pv = std::get_if<nb::bool_>(&v); pv) {
 | 
					 | 
				
			||||||
    return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
 | 
					 | 
				
			||||||
  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
 | 
					 | 
				
			||||||
    auto out_t = dtype.value_or(int32);
 | 
					 | 
				
			||||||
    // bool_ is an exception and is always promoted
 | 
					 | 
				
			||||||
    return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
 | 
					 | 
				
			||||||
  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
 | 
					 | 
				
			||||||
    auto out_t = dtype.value_or(float32);
 | 
					 | 
				
			||||||
    return array(
 | 
					 | 
				
			||||||
        nb::cast<float>(*pv), issubdtype(out_t, floating) ? out_t : float32);
 | 
					 | 
				
			||||||
  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
 | 
					 | 
				
			||||||
    return array(static_cast<complex64_t>(*pv), complex64);
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    return to_array_with_accessor(std::get<nb::object>(v));
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
inline std::pair<array, array> to_arrays(
 | 
					std::pair<array, array> to_arrays(
 | 
				
			||||||
    const ScalarOrArray& a,
 | 
					    const ScalarOrArray& a,
 | 
				
			||||||
    const ScalarOrArray& b) {
 | 
					    const ScalarOrArray& b);
 | 
				
			||||||
  // Four cases:
 | 
					
 | 
				
			||||||
  // - If both a and b are arrays leave their types alone
 | 
					array to_array_with_accessor(nb::object obj);
 | 
				
			||||||
  // - If a is an array but b is not, treat b as a weak python type
 | 
					 | 
				
			||||||
  // - If b is an array but a is not, treat a as a weak python type
 | 
					 | 
				
			||||||
  // - If neither is an array convert to arrays but leave their types alone
 | 
					 | 
				
			||||||
  if (auto pa = std::get_if<nb::object>(&a); pa) {
 | 
					 | 
				
			||||||
    auto arr_a = to_array_with_accessor(*pa);
 | 
					 | 
				
			||||||
    if (auto pb = std::get_if<nb::object>(&b); pb) {
 | 
					 | 
				
			||||||
      auto arr_b = to_array_with_accessor(*pb);
 | 
					 | 
				
			||||||
      return {arr_a, arr_b};
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return {arr_a, to_array(b, arr_a.dtype())};
 | 
					 | 
				
			||||||
  } else if (auto pb = std::get_if<nb::object>(&b); pb) {
 | 
					 | 
				
			||||||
    auto arr_b = to_array_with_accessor(*pb);
 | 
					 | 
				
			||||||
    return {to_array(a, arr_b.dtype()), arr_b};
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    return {to_array(a), to_array(b)};
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1710,6 +1710,13 @@ class TestArray(mlx_tests.MLXTestCase):
 | 
				
			|||||||
        peak_2 = mx.metal.get_peak_memory()
 | 
					        peak_2 = mx.metal.get_peak_memory()
 | 
				
			||||||
        self.assertEqual(peak_1, peak_2)
 | 
					        self.assertEqual(peak_1, peak_2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_add_numpy(self):
 | 
				
			||||||
 | 
					        x = mx.array(1)
 | 
				
			||||||
 | 
					        y = np.array(2, dtype=np.int32)
 | 
				
			||||||
 | 
					        z = x + y
 | 
				
			||||||
 | 
					        self.assertEqual(z.dtype, mx.int32)
 | 
				
			||||||
 | 
					        self.assertEqual(z.item(), 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    unittest.main()
 | 
					    unittest.main()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user