diff --git a/python/src/convert.cpp b/python/src/convert.cpp
index 9c4d71b1b..45b295dab 100644
--- a/python/src/convert.cpp
+++ b/python/src/convert.cpp
@@ -24,15 +24,6 @@ struct ndarray_traits<float16_t> {
   static constexpr bool is_signed = true;
 };

-template <>
-struct ndarray_traits<bfloat16_t> {
-  static constexpr bool is_complex = false;
-  static constexpr bool is_float = true;
-  static constexpr bool is_bool = false;
-  static constexpr bool is_int = false;
-  static constexpr bool is_signed = true;
-};
-
 static constexpr dlpack::dtype bfloat16{4, 16, 1};

 }; // namespace nanobind
@@ -88,7 +79,7 @@ array nd_array_to_mlx(
   } else if (type == nb::dtype<float16_t>()) {
     return nd_array_to_mlx_contiguous<float16_t>(
         nd_array, shape, dtype.value_or(float16));
-  } else if (type == nb::dtype<bfloat16_t>()) {
+  } else if (type == nb::bfloat16) {
     return nd_array_to_mlx_contiguous<bfloat16_t>(
         nd_array, shape, dtype.value_or(bfloat16));
   } else if (type == nb::dtype<float>()) {
diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py
index 50afa73fc..0b4b49919 100644
--- a/python/tests/test_bf16.py
+++ b/python/tests/test_bf16.py
@@ -183,6 +183,14 @@ class TestBF16(mlx_tests.MLXTestCase):
         ]:
             test_blas(shape_x, shape_y)

+    @unittest.skipIf(not has_torch, "requires PyTorch")
+    def test_conversion(self):
+        a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
+        a_mx = mx.array(a_torch)
+        expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)
+        self.assertEqual(a_mx.dtype, mx.bfloat16)
+        self.assertTrue(mx.array_equal(a_mx, expected))
+

 if __name__ == "__main__":
     unittest.main()
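
Note on the fix: the removed ndarray_traits specialization only exposes is_float/is_int/... flags, with no way to express DLPack's separate bfloat type code, so nb::dtype<bfloat16_t>() maps to the generic float code and can never equal the {4, 16, 1} descriptor (code 4 = kDLBfloat, 16 bits, 1 lane) that frameworks such as torch send over DLPack. The patched branch therefore compares the incoming descriptor directly against the nb::bfloat16 constant defined above it. Below is a minimal sketch of the conversion path this enables, assuming mlx and torch are installed; variable names are illustrative, not from the patch:

    import mlx.core as mx
    import torch

    # torch exports bfloat16 over DLPack with dtype {4, 16, 1}:
    # type code 4 (kDLBfloat), 16 bits per element, 1 lane.
    t = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)

    # mx.array() goes through nd_array_to_mlx, where the descriptor
    # now matches the `type == nb::bfloat16` branch.
    a = mx.array(t)
    assert a.dtype == mx.bfloat16

    # Upcast both sides to float32 to compare element values.
    assert mx.array_equal(a.astype(mx.float32), mx.array(t.float().numpy()))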