mirror of https://github.com/ml-explore/mlx.git
fix creating array from bf16 tensors in jax / torch (#1305)

commit 10b5835501
parent 6c8dd307eb
@@ -24,15 +24,6 @@ struct ndarray_traits<float16_t> {
   static constexpr bool is_signed = true;
 };
 
-template <>
-struct ndarray_traits<bfloat16_t> {
-  static constexpr bool is_complex = false;
-  static constexpr bool is_float = true;
-  static constexpr bool is_bool = false;
-  static constexpr bool is_int = false;
-  static constexpr bool is_signed = true;
-};
-
 static constexpr dlpack::dtype bfloat16{4, 16, 1};
 }; // namespace nanobind
 
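The {4, 16, 1} triple kept above is a DLPack dtype descriptor: type code 4 (kDLBfloat), 16 bits, 1 lane. A plausible reading of this hunk: nanobind derives nb::dtype<T>() from ndarray_traits<T>, and because the deleted specialization reported is_float = true, nb::dtype<bfloat16_t>() carried the kDLFloat code (2) and was bit-for-bit identical to the float16 descriptor, so bf16 tensors coming from torch or jax (which report code 4) never matched it. Dropping the specialization and comparing against the explicit nb::bfloat16 constant fixes the dispatch. A minimal Python sketch of the encoding, illustrative only (the constant names here are made up, not nanobind or DLPack API):

    # DLPack describes a dtype as (type_code, bits, lanes).
    # Spec type codes: 0 = kDLInt, 1 = kDLUInt, 2 = kDLFloat, 4 = kDLBfloat.
    KDL_FLOAT, KDL_BFLOAT = 2, 4

    float16 = (KDL_FLOAT, 16, 1)
    bfloat16 = (KDL_BFLOAT, 16, 1)  # mirrors dlpack::dtype bfloat16{4, 16, 1}

    # Both are 16-bit, one-lane floats; only the type code tells them apart,
    # so a descriptor built from "is_float" traits collides with float16.
    assert float16 != bfloat16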
@@ -88,7 +79,7 @@ array nd_array_to_mlx(
   } else if (type == nb::dtype<float16_t>()) {
     return nd_array_to_mlx_contiguous<float16_t>(
         nd_array, shape, dtype.value_or(float16));
-  } else if (type == nb::dtype<bfloat16_t>()) {
+  } else if (type == nb::bfloat16) {
    return nd_array_to_mlx_contiguous<bfloat16_t>(
         nd_array, shape, dtype.value_or(bfloat16));
   } else if (type == nb::dtype<float>()) {
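The new branch dispatches on the DLPack descriptor that torch and jax actually emit for bf16. A hedged usage sketch of the user-visible effect on the jax side (the new test below covers torch); it assumes a jax build with bfloat16 support and that mx.array accepts jax arrays, which the commit title implies:

    import jax.numpy as jnp
    import mlx.core as mx

    # With the fix, a bf16 jax array should convert with its dtype intact
    # instead of falling through the type dispatch above.
    a_jax = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
    a_mx = mx.array(a_jax)

    assert a_mx.dtype == mx.bfloat16
    assert mx.array_equal(a_mx, mx.array([1.0, 2.0, 3.0], mx.bfloat16))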
@@ -183,6 +183,14 @@ class TestBF16(mlx_tests.MLXTestCase):
         ]:
             test_blas(shape_x, shape_y)
 
+    @unittest.skipIf(not has_torch, "requires PyTorch")
+    def test_conversion(self):
+        a_torch = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
+        a_mx = mx.array(a_torch)
+        expected = mx.array([1.0, 2.0, 3.0], mx.bfloat16)
+        self.assertEqual(a_mx.dtype, mx.bfloat16)
+        self.assertTrue(mx.array_equal(a_mx, expected))
+
 
 if __name__ == "__main__":
     unittest.main()
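The added test references has_torch and torch, which the test module presumably sets up with the usual optional-import guard; a sketch of that pattern (the exact code in the file may differ):

    try:
        import torch

        has_torch = True
    except ImportError:
        has_torch = False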