mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	int() and float() for mx.array (#1360)
This commit is contained in:
		@@ -840,6 +840,8 @@ void init_array(nb::module_& m) {
 | 
			
		||||
          },
 | 
			
		||||
          "other"_a,
 | 
			
		||||
          nb::rv_policy::none)
 | 
			
		||||
      .def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
 | 
			
		||||
      .def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
 | 
			
		||||
      .def(
 | 
			
		||||
          "flatten",
 | 
			
		||||
          [](const array& a,
 | 
			
		||||
 
 | 
			
		||||
@@ -160,6 +160,10 @@ nb::ndarray<> mlx_to_dlpack(const array& a) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
nb::object to_scalar(array& a) {
 | 
			
		||||
  if (a.size() != 1) {
 | 
			
		||||
    throw std::invalid_argument(
 | 
			
		||||
        "[convert] Only length-1 arrays can be converted to Python scalars.");
 | 
			
		||||
  }
 | 
			
		||||
  {
 | 
			
		||||
    nb::gil_scoped_release nogil;
 | 
			
		||||
    a.eval();
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user