mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 13:54:44 +08:00
@@ -142,12 +142,13 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
||||
case float16:
|
||||
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
||||
case bfloat16:
|
||||
throw nb::type_error(
|
||||
"bfloat16 arrays cannot be converted directly to NumPy.");
|
||||
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
|
||||
case float32:
|
||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||
case complex64:
|
||||
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||
default:
|
||||
throw nb::type_error("type cannot be converted to NumPy.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,6 +196,8 @@ nb::object to_scalar(array& a) {
|
||||
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||
case complex64:
|
||||
return nb::cast(a.item<std::complex<float>>());
|
||||
default:
|
||||
throw nb::type_error("type cannot be converted to Python scalar.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,6 +251,8 @@ nb::object tolist(array& a) {
|
||||
return to_list<bfloat16_t, float>(a, 0, 0);
|
||||
case complex64:
|
||||
return to_list<std::complex<float>>(a, 0, 0);
|
||||
default:
|
||||
throw nb::type_error("data type cannot be converted to Python list.");
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1308,8 +1308,8 @@ void init_ops(nb::module_& m) {
|
||||
"start"_a,
|
||||
"stop"_a,
|
||||
"step"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
@@ -1356,8 +1356,8 @@ void init_ops(nb::module_& m) {
|
||||
},
|
||||
"stop"_a,
|
||||
"step"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"dtype"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
|
||||
|
Reference in New Issue
Block a user