Hashable dtype and mlx.core prefixed repr (#89)

* Make dtype hashable
* Add mlx.core prefix to our dtypes' repr
* Update the dtype test
This commit is contained in:
Angelos Katharopoulos
2023-12-09 09:35:28 -08:00
committed by GitHub
parent 976e8babbe
commit fd836d891b
2 changed files with 18 additions and 14 deletions

View File

@@ -34,19 +34,19 @@ class TestDtypes(mlx_tests.MLXTestCase):
self.assertEqual(mx.bfloat16.size, 2)
self.assertEqual(mx.complex64.size, 8)
self.assertEqual(str(mx.bool_), "bool")
self.assertEqual(str(mx.uint8), "uint8")
self.assertEqual(str(mx.uint16), "uint16")
self.assertEqual(str(mx.uint32), "uint32")
self.assertEqual(str(mx.uint64), "uint64")
self.assertEqual(str(mx.int8), "int8")
self.assertEqual(str(mx.int16), "int16")
self.assertEqual(str(mx.int32), "int32")
self.assertEqual(str(mx.int64), "int64")
self.assertEqual(str(mx.float16), "float16")
self.assertEqual(str(mx.float32), "float32")
self.assertEqual(str(mx.bfloat16), "bfloat16")
self.assertEqual(str(mx.complex64), "complex64")
self.assertEqual(str(mx.bool_), "mlx.core.bool")
self.assertEqual(str(mx.uint8), "mlx.core.uint8")
self.assertEqual(str(mx.uint16), "mlx.core.uint16")
self.assertEqual(str(mx.uint32), "mlx.core.uint32")
self.assertEqual(str(mx.uint64), "mlx.core.uint64")
self.assertEqual(str(mx.int8), "mlx.core.int8")
self.assertEqual(str(mx.int16), "mlx.core.int16")
self.assertEqual(str(mx.int32), "mlx.core.int32")
self.assertEqual(str(mx.int64), "mlx.core.int64")
self.assertEqual(str(mx.float16), "mlx.core.float16")
self.assertEqual(str(mx.float32), "mlx.core.float32")
self.assertEqual(str(mx.bfloat16), "mlx.core.bfloat16")
self.assertEqual(str(mx.complex64), "mlx.core.complex64")
def test_scalar_conversion(self):
dtypes = [