Hashable dtype and mlx.core prefixed repr (#89)

* Make dtype hashable
* Add mlx.core prefix to our dtypes' repr
* Update the dtype test
This commit is contained in:
Angelos Katharopoulos
2023-12-09 09:35:28 -08:00
committed by GitHub
parent 976e8babbe
commit fd836d891b
2 changed files with 18 additions and 14 deletions

View File

@@ -436,10 +436,14 @@ void init_array(py::module_& m) {
"__repr__",
[](const Dtype& t) {
std::ostringstream os;
os << "mlx.core.";
os << t;
return os.str();
})
.def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; });
.def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; })
.def("__hash__", [](const Dtype& t) {
return static_cast<int64_t>(t.val);
});
m.attr("bool_") = py::cast(bool_);
m.attr("uint8") = py::cast(uint8);
m.attr("uint16") = py::cast(uint16);