A few updates for CPU (#1482)

* some updates

* format

* fix

* nit
This commit is contained in:
Awni Hannun
2024-10-14 12:45:49 -07:00
committed by GitHub
parent 881615b072
commit 020f048cd0
6 changed files with 50 additions and 25 deletions

View File

@@ -142,12 +142,13 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
case float16:
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
case bfloat16:
throw nb::type_error(
"bfloat16 arrays cannot be converted directly to NumPy.");
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
case float32:
return mlx_to_nd_array_impl<float, NDParams...>(a);
case complex64:
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
default:
throw nb::type_error("type cannot be converted to NumPy.");
}
}
@@ -195,6 +196,8 @@ nb::object to_scalar(array& a) {
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
case complex64:
return nb::cast(a.item<std::complex<float>>());
default:
throw nb::type_error("type cannot be converted to Python scalar.");
}
}
@@ -248,6 +251,8 @@ nb::object tolist(array& a) {
return to_list<bfloat16_t, float>(a, 0, 0);
case complex64:
return to_list<std::complex<float>>(a, 0, 0);
default:
throw nb::type_error("data type cannot be converted to Python list.");
}
}

View File

@@ -1308,8 +1308,8 @@ void init_ops(nb::module_& m) {
"start"_a,
"stop"_a,
"step"_a = nb::none(),
nb::kw_only(),
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
@@ -1356,8 +1356,8 @@ void init_ops(nb::module_& m) {
},
"stop"_a,
"step"_a = nb::none(),
nb::kw_only(),
"dtype"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));