Double for lapack (#1904)

* double for lapack ops

* add double support for lapack ops
This commit is contained in:
Awni Hannun
2025-02-25 11:39:36 -08:00
committed by GitHub
parent 28b8079e30
commit 7d042f17fe
11 changed files with 338 additions and 225 deletions

View File

@@ -44,6 +44,8 @@ std::string buffer_format(const mx::array& a) {
return "f";
case mx::bfloat16:
return "B";
case mx::float64:
return "d";
case mx::complex64:
return "Zf\0";
default: {

View File

@@ -152,6 +152,8 @@ nb::ndarray<NDParams...> mlx_to_nd_array(const mx::array& a) {
throw nb::type_error("bfloat16 arrays cannot be converted to NumPy.");
case mx::float32:
return mlx_to_nd_array_impl<float, NDParams...>(a);
case mx::float64:
return mlx_to_nd_array_impl<double, NDParams...>(a);
case mx::complex64:
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
default: