From d1183821a76d60568845c7978cf663f2f645be10 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Sun, 25 Aug 2024 20:41:44 -0700 Subject: [PATCH] int() and float() for mx.array (#1360) --- python/src/array.cpp | 2 ++ python/src/convert.cpp | 4 ++++ python/tests/test_array.py | 15 +++++++++++++++ 3 files changed, 21 insertions(+) diff --git a/python/src/array.cpp b/python/src/array.cpp index 22b8f69c1..9213a9988 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -840,6 +840,8 @@ void init_array(nb::module_& m) { }, "other"_a, nb::rv_policy::none) + .def("__int__", [](array& a) { return nb::int_(to_scalar(a)); }) + .def("__float__", [](array& a) { return nb::float_(to_scalar(a)); }) .def( "flatten", [](const array& a, diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 45b295dab..b3554eef8 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -160,6 +160,10 @@ nb::ndarray<> mlx_to_dlpack(const array& a) { } nb::object to_scalar(array& a) { + if (a.size() != 1) { + throw std::invalid_argument( + "[convert] Only length-1 arrays can be converted to Python scalars."); + } { nb::gil_scoped_release nogil; a.eval(); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 0144c34c5..0508d3362 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1834,6 +1834,21 @@ class TestArray(mlx_tests.MLXTestCase): self.assertTrue(hasattr(api, "array")) self.assertTrue(hasattr(api, "add")) + def test_to_scalar(self): + a = mx.array(1) + self.assertEqual(int(a), 1) + self.assertEqual(float(a), 1) + + a = mx.array(1.5) + self.assertEqual(float(a), 1.5) + self.assertEqual(int(a), 1) + + a = mx.zeros((2, 1)) + with self.assertRaises(ValueError): + float(a) + with self.assertRaises(ValueError): + int(a) + if __name__ == "__main__": unittest.main()