mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	int() and float() for mx.array (#1360)
This commit is contained in:
		| @@ -840,6 +840,8 @@ void init_array(nb::module_& m) { | |||||||
|           }, |           }, | ||||||
|           "other"_a, |           "other"_a, | ||||||
|           nb::rv_policy::none) |           nb::rv_policy::none) | ||||||
|  |       .def("__int__", [](array& a) { return nb::int_(to_scalar(a)); }) | ||||||
|  |       .def("__float__", [](array& a) { return nb::float_(to_scalar(a)); }) | ||||||
|       .def( |       .def( | ||||||
|           "flatten", |           "flatten", | ||||||
|           [](const array& a, |           [](const array& a, | ||||||
|   | |||||||
| @@ -160,6 +160,10 @@ nb::ndarray<> mlx_to_dlpack(const array& a) { | |||||||
| } | } | ||||||
|  |  | ||||||
| nb::object to_scalar(array& a) { | nb::object to_scalar(array& a) { | ||||||
|  |   if (a.size() != 1) { | ||||||
|  |     throw std::invalid_argument( | ||||||
|  |         "[convert] Only length-1 arrays can be converted to Python scalars."); | ||||||
|  |   } | ||||||
|   { |   { | ||||||
|     nb::gil_scoped_release nogil; |     nb::gil_scoped_release nogil; | ||||||
|     a.eval(); |     a.eval(); | ||||||
|   | |||||||
| @@ -1834,6 +1834,21 @@ class TestArray(mlx_tests.MLXTestCase): | |||||||
|         self.assertTrue(hasattr(api, "array")) |         self.assertTrue(hasattr(api, "array")) | ||||||
|         self.assertTrue(hasattr(api, "add")) |         self.assertTrue(hasattr(api, "add")) | ||||||
|  |  | ||||||
|  |     def test_to_scalar(self): | ||||||
|  |         a = mx.array(1) | ||||||
|  |         self.assertEqual(int(a), 1) | ||||||
|  |         self.assertEqual(float(a), 1) | ||||||
|  |  | ||||||
|  |         a = mx.array(1.5) | ||||||
|  |         self.assertEqual(float(a), 1.5) | ||||||
|  |         self.assertEqual(int(a), 1) | ||||||
|  |  | ||||||
|  |         a = mx.zeros((2, 1)) | ||||||
|  |         with self.assertRaises(ValueError): | ||||||
|  |             float(a) | ||||||
|  |         with self.assertRaises(ValueError): | ||||||
|  |             int(a) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron