diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 1fbc67c8e..eece43717 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -7,6 +7,46 @@ namespace mlx::core { +void PrintFormatter::print(std::ostream& os, bool val) { + if (capitalize_bool) { + os << (val ? "True" : "False"); + } else { + os << val; + } +} +inline void PrintFormatter::print(std::ostream& os, int16_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, uint16_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, int32_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, uint32_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, int64_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, uint64_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, float16_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, float val) { + os << val; +} +inline void PrintFormatter::print(std::ostream& os, complex64_t val) { + os << val; +} + +PrintFormatter global_formatter; + Dtype result_type(const std::vector& arrays) { std::vector dtypes(1, bool_); for (auto& arr : arrays) { @@ -136,7 +176,7 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { i = n - num_print - 1; index += s * (n - 2 * num_print - 1); } else if (is_last) { - os << a.data()[index]; + global_formatter.print(os, a.data()[index]); } else { print_subarray(os, a, index, dim + 1); } @@ -153,7 +193,7 @@ void print_array(std::ostream& os, const array& a) { os << "array("; if (a.ndim() == 0) { auto data = a.data(); - os << data[0]; + global_formatter.print(os, data[0]); } else { print_subarray(os, a, 0, 0); } diff --git a/mlx/utils.h b/mlx/utils.h index 823b4c872..f28970369 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -9,6 +9,24 @@ namespace mlx::core { +struct PrintFormatter { + inline void print(std::ostream& os, bool val); + inline void print(std::ostream& os, int16_t val); + inline void print(std::ostream& os, uint16_t val); + inline void print(std::ostream& os, int32_t val); + inline void print(std::ostream& os, uint32_t val); + inline void print(std::ostream& os, int64_t val); + inline void print(std::ostream& os, uint64_t val); + inline void print(std::ostream& os, float16_t val); + inline void print(std::ostream& os, bfloat16_t val); + inline void print(std::ostream& os, float val); + inline void print(std::ostream& os, complex64_t val); + + bool capitalize_bool{false}; +}; + +extern PrintFormatter global_formatter; + /** The type from promoting the arrays' types with one another. */ Dtype result_type(const std::vector& arrays); diff --git a/python/src/array.cpp b/python/src/array.cpp index 407142fda..392eb34a2 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -520,6 +520,9 @@ class ArrayPythonIterator { }; void init_array(py::module_& m) { + // Set Python print formatting options + mlx::core::global_formatter.capitalize_bool = true; + // Types py::class_( m, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 776181e4f..94a9396b6 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -304,7 +304,7 @@ class TestArray(mlx_tests.MLXTestCase): def test_array_repr(self): x = mx.array(True) - self.assertEqual(str(x), "array(true, dtype=bool)") + self.assertEqual(str(x), "array(True, dtype=bool)") x = mx.array(1) self.assertEqual(str(x), "array(1, dtype=int32)") x = mx.array(1.0)