mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Hashable dtype and mlx.core prefixed repr (#89)
* Make dtype hashable * Add mlx.core prefix to our dtypes' repr * Update the dtype test
This commit is contained in:
parent
976e8babbe
commit
fd836d891b
@ -436,10 +436,14 @@ void init_array(py::module_& m) {
|
||||
"__repr__",
|
||||
[](const Dtype& t) {
|
||||
std::ostringstream os;
|
||||
os << "mlx.core.";
|
||||
os << t;
|
||||
return os.str();
|
||||
})
|
||||
.def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; });
|
||||
.def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; })
|
||||
.def("__hash__", [](const Dtype& t) {
|
||||
return static_cast<int64_t>(t.val);
|
||||
});
|
||||
m.attr("bool_") = py::cast(bool_);
|
||||
m.attr("uint8") = py::cast(uint8);
|
||||
m.attr("uint16") = py::cast(uint16);
|
||||
|
@ -34,19 +34,19 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.bfloat16.size, 2)
|
||||
self.assertEqual(mx.complex64.size, 8)
|
||||
|
||||
self.assertEqual(str(mx.bool_), "bool")
|
||||
self.assertEqual(str(mx.uint8), "uint8")
|
||||
self.assertEqual(str(mx.uint16), "uint16")
|
||||
self.assertEqual(str(mx.uint32), "uint32")
|
||||
self.assertEqual(str(mx.uint64), "uint64")
|
||||
self.assertEqual(str(mx.int8), "int8")
|
||||
self.assertEqual(str(mx.int16), "int16")
|
||||
self.assertEqual(str(mx.int32), "int32")
|
||||
self.assertEqual(str(mx.int64), "int64")
|
||||
self.assertEqual(str(mx.float16), "float16")
|
||||
self.assertEqual(str(mx.float32), "float32")
|
||||
self.assertEqual(str(mx.bfloat16), "bfloat16")
|
||||
self.assertEqual(str(mx.complex64), "complex64")
|
||||
self.assertEqual(str(mx.bool_), "mlx.core.bool")
|
||||
self.assertEqual(str(mx.uint8), "mlx.core.uint8")
|
||||
self.assertEqual(str(mx.uint16), "mlx.core.uint16")
|
||||
self.assertEqual(str(mx.uint32), "mlx.core.uint32")
|
||||
self.assertEqual(str(mx.uint64), "mlx.core.uint64")
|
||||
self.assertEqual(str(mx.int8), "mlx.core.int8")
|
||||
self.assertEqual(str(mx.int16), "mlx.core.int16")
|
||||
self.assertEqual(str(mx.int32), "mlx.core.int32")
|
||||
self.assertEqual(str(mx.int64), "mlx.core.int64")
|
||||
self.assertEqual(str(mx.float16), "mlx.core.float16")
|
||||
self.assertEqual(str(mx.float32), "mlx.core.float32")
|
||||
self.assertEqual(str(mx.bfloat16), "mlx.core.bfloat16")
|
||||
self.assertEqual(str(mx.complex64), "mlx.core.complex64")
|
||||
|
||||
def test_scalar_conversion(self):
|
||||
dtypes = [
|
||||
|
Loading…
Reference in New Issue
Block a user