mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix creating array from bf16 tensors in jax / torch (#1305)
This commit is contained in:
		| @@ -24,15 +24,6 @@ struct ndarray_traits<float16_t> { | ||||
|   static constexpr bool is_signed = true; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct ndarray_traits<bfloat16_t> { | ||||
|   static constexpr bool is_complex = false; | ||||
|   static constexpr bool is_float = true; | ||||
|   static constexpr bool is_bool = false; | ||||
|   static constexpr bool is_int = false; | ||||
|   static constexpr bool is_signed = true; | ||||
| }; | ||||
|  | ||||
| static constexpr dlpack::dtype bfloat16{4, 16, 1}; | ||||
| }; // namespace nanobind | ||||
|  | ||||
| @@ -88,7 +79,7 @@ array nd_array_to_mlx( | ||||
|   } else if (type == nb::dtype<float16_t>()) { | ||||
|     return nd_array_to_mlx_contiguous<float16_t>( | ||||
|         nd_array, shape, dtype.value_or(float16)); | ||||
|   } else if (type == nb::dtype<bfloat16_t>()) { | ||||
|   } else if (type == nb::bfloat16) { | ||||
|     return nd_array_to_mlx_contiguous<bfloat16_t>( | ||||
|         nd_array, shape, dtype.value_or(bfloat16)); | ||||
|   } else if (type == nb::dtype<float>()) { | ||||
|   | ||||
| @@ -183,6 +183,14 @@ class TestBF16(mlx_tests.MLXTestCase): | ||||
|         ]: | ||||
|             test_blas(shape_x, shape_y) | ||||
|  | ||||
|     @unittest.skipIf(not has_torch, "requires PyTorch") | ||||
|     def test_conversion(self): | ||||
|         a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16) | ||||
|         a_mx = mx.array(a_torch) | ||||
|         expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16) | ||||
|         self.assertEqual(a_mx.dtype, mx.bfloat16) | ||||
|         self.assertTrue(mx.array_equal(a_mx, expected)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun