From 350095ce6ea222df7c39ed26891d433977f90a13 Mon Sep 17 00:00:00 2001 From: mutexuan Date: Tue, 2 Jan 2024 11:02:04 +0800 Subject: [PATCH] fix type cast error in item() for bfloat16 (#339) Co-authored-by: xuan --- python/src/array.cpp | 2 +- python/tests/test_array.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 2580de0ea..f8a1a27cd 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -64,7 +64,7 @@ auto to_scalar(array& a) { case float32: return py::cast(a.item(retain_graph)); case bfloat16: - return py::cast(static_cast(a.item(retain_graph))); + return py::cast(static_cast(a.item(retain_graph))); case complex64: return py::cast(a.item>(retain_graph)); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 847d6d142..b6471cdbd 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -102,6 +102,9 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual(x.item(), 1) self.assertTrue(isinstance(x.item(), int)) + x = mx.array(1, mx.bfloat16) + self.assertEqual(x.item(), 1.0) + x = mx.array(1.0) self.assertEqual(x.size, 1) self.assertEqual(x.ndim, 0)