mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Hashable dtype and mlx.core prefixed repr (#89)
* Make dtype hashable * Add mlx.core prefix to our dtypes' repr * Update the dtype test
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							976e8babbe
						
					
				
				
					commit
					fd836d891b
				
			| @@ -436,10 +436,14 @@ void init_array(py::module_& m) { | ||||
|           "__repr__", | ||||
|           [](const Dtype& t) { | ||||
|             std::ostringstream os; | ||||
|             os << "mlx.core."; | ||||
|             os << t; | ||||
|             return os.str(); | ||||
|           }) | ||||
|       .def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; }); | ||||
|       .def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; }) | ||||
|       .def("__hash__", [](const Dtype& t) { | ||||
|         return static_cast<int64_t>(t.val); | ||||
|       }); | ||||
|   m.attr("bool_") = py::cast(bool_); | ||||
|   m.attr("uint8") = py::cast(uint8); | ||||
|   m.attr("uint16") = py::cast(uint16); | ||||
|   | ||||
| @@ -34,19 +34,19 @@ class TestDtypes(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(mx.bfloat16.size, 2) | ||||
|         self.assertEqual(mx.complex64.size, 8) | ||||
|  | ||||
|         self.assertEqual(str(mx.bool_), "bool") | ||||
|         self.assertEqual(str(mx.uint8), "uint8") | ||||
|         self.assertEqual(str(mx.uint16), "uint16") | ||||
|         self.assertEqual(str(mx.uint32), "uint32") | ||||
|         self.assertEqual(str(mx.uint64), "uint64") | ||||
|         self.assertEqual(str(mx.int8), "int8") | ||||
|         self.assertEqual(str(mx.int16), "int16") | ||||
|         self.assertEqual(str(mx.int32), "int32") | ||||
|         self.assertEqual(str(mx.int64), "int64") | ||||
|         self.assertEqual(str(mx.float16), "float16") | ||||
|         self.assertEqual(str(mx.float32), "float32") | ||||
|         self.assertEqual(str(mx.bfloat16), "bfloat16") | ||||
|         self.assertEqual(str(mx.complex64), "complex64") | ||||
|         self.assertEqual(str(mx.bool_), "mlx.core.bool") | ||||
|         self.assertEqual(str(mx.uint8), "mlx.core.uint8") | ||||
|         self.assertEqual(str(mx.uint16), "mlx.core.uint16") | ||||
|         self.assertEqual(str(mx.uint32), "mlx.core.uint32") | ||||
|         self.assertEqual(str(mx.uint64), "mlx.core.uint64") | ||||
|         self.assertEqual(str(mx.int8), "mlx.core.int8") | ||||
|         self.assertEqual(str(mx.int16), "mlx.core.int16") | ||||
|         self.assertEqual(str(mx.int32), "mlx.core.int32") | ||||
|         self.assertEqual(str(mx.int64), "mlx.core.int64") | ||||
|         self.assertEqual(str(mx.float16), "mlx.core.float16") | ||||
|         self.assertEqual(str(mx.float32), "mlx.core.float32") | ||||
|         self.assertEqual(str(mx.bfloat16), "mlx.core.bfloat16") | ||||
|         self.assertEqual(str(mx.complex64), "mlx.core.complex64") | ||||
|  | ||||
|     def test_scalar_conversion(self): | ||||
|         dtypes = [ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user