fix creating array from bf16 tensors in jax / torch (#1305)

This commit is contained in:
Awni Hannun
2024-08-01 16:20:51 -07:00
committed by GitHub
parent 6c8dd307eb
commit 10b5835501
2 changed files with 9 additions and 10 deletions

View File

@@ -24,15 +24,6 @@ struct ndarray_traits<float16_t> {
static constexpr bool is_signed = true;
};
template <>
struct ndarray_traits<bfloat16_t> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
@@ -88,7 +79,7 @@ array nd_array_to_mlx(
} else if (type == nb::dtype<float16_t>()) {
return nd_array_to_mlx_contiguous<float16_t>(
nd_array, shape, dtype.value_or(float16));
} else if (type == nb::dtype<bfloat16_t>()) {
} else if (type == nb::bfloat16) {
return nd_array_to_mlx_contiguous<bfloat16_t>(
nd_array, shape, dtype.value_or(bfloat16));
} else if (type == nb::dtype<float>()) {