diff --git a/python/src/array.cpp b/python/src/array.cpp index 9d6c93f13..86253f3f9 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -436,10 +436,14 @@ void init_array(py::module_& m) { "__repr__", [](const Dtype& t) { std::ostringstream os; + os << "mlx.core."; os << t; return os.str(); }) - .def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; }); + .def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; }) + .def("__hash__", [](const Dtype& t) { + return static_cast(t.val); + }); m.attr("bool_") = py::cast(bool_); m.attr("uint8") = py::cast(uint8); m.attr("uint16") = py::cast(uint16); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 3c38a7ef5..addec3493 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -34,19 +34,19 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertEqual(mx.bfloat16.size, 2) self.assertEqual(mx.complex64.size, 8) - self.assertEqual(str(mx.bool_), "bool") - self.assertEqual(str(mx.uint8), "uint8") - self.assertEqual(str(mx.uint16), "uint16") - self.assertEqual(str(mx.uint32), "uint32") - self.assertEqual(str(mx.uint64), "uint64") - self.assertEqual(str(mx.int8), "int8") - self.assertEqual(str(mx.int16), "int16") - self.assertEqual(str(mx.int32), "int32") - self.assertEqual(str(mx.int64), "int64") - self.assertEqual(str(mx.float16), "float16") - self.assertEqual(str(mx.float32), "float32") - self.assertEqual(str(mx.bfloat16), "bfloat16") - self.assertEqual(str(mx.complex64), "complex64") + self.assertEqual(str(mx.bool_), "mlx.core.bool") + self.assertEqual(str(mx.uint8), "mlx.core.uint8") + self.assertEqual(str(mx.uint16), "mlx.core.uint16") + self.assertEqual(str(mx.uint32), "mlx.core.uint32") + self.assertEqual(str(mx.uint64), "mlx.core.uint64") + self.assertEqual(str(mx.int8), "mlx.core.int8") + self.assertEqual(str(mx.int16), "mlx.core.int16") + self.assertEqual(str(mx.int32), "mlx.core.int32") + self.assertEqual(str(mx.int64), "mlx.core.int64") + self.assertEqual(str(mx.float16), "mlx.core.float16") + self.assertEqual(str(mx.float32), "mlx.core.float32") + self.assertEqual(str(mx.bfloat16), "mlx.core.bfloat16") + self.assertEqual(str(mx.complex64), "mlx.core.complex64") def test_scalar_conversion(self): dtypes = [