int() and float() for mx.array (#1360)

This commit is contained in:
Alex Barron 2024-08-25 20:41:44 -07:00 committed by GitHub
parent 8081df79be
commit d1183821a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 0 deletions

View File

@ -840,6 +840,8 @@ void init_array(nb::module_& m) {
},
"other"_a,
nb::rv_policy::none)
.def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
.def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
.def(
"flatten",
[](const array& a,

View File

@ -160,6 +160,10 @@ nb::ndarray<> mlx_to_dlpack(const array& a) {
}
nb::object to_scalar(array& a) {
if (a.size() != 1) {
throw std::invalid_argument(
"[convert] Only length-1 arrays can be converted to Python scalars.");
}
{
nb::gil_scoped_release nogil;
a.eval();

View File

@ -1834,6 +1834,21 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertTrue(hasattr(api, "array"))
self.assertTrue(hasattr(api, "add"))
def test_to_scalar(self):
a = mx.array(1)
self.assertEqual(int(a), 1)
self.assertEqual(float(a), 1)
a = mx.array(1.5)
self.assertEqual(float(a), 1.5)
self.assertEqual(int(a), 1)
a = mx.zeros((2, 1))
with self.assertRaises(ValueError):
float(a)
with self.assertRaises(ValueError):
int(a)
if __name__ == "__main__":
unittest.main()