fix creating array from bf16 tensors in jax / torch (#1305)

This commit is contained in:
Awni Hannun
2024-08-01 16:20:51 -07:00
committed by GitHub
parent 6c8dd307eb
commit 10b5835501
2 changed files with 9 additions and 10 deletions

View File

@@ -183,6 +183,14 @@ class TestBF16(mlx_tests.MLXTestCase):
]:
test_blas(shape_x, shape_y)
@unittest.skipIf(not has_torch, "requires PyTorch")
def test_conversion(self):
a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
a_mx = mx.array(a_torch)
expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)
self.assertEqual(a_mx.dtype, mx.bfloat16)
self.assertTrue(mx.array_equal(a_mx, expected))
if __name__ == "__main__":
unittest.main()