mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
allow conversion to dlpack (#1120)
This commit is contained in:
@@ -100,8 +100,8 @@ array nd_array_to_mlx(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Lib, typename T>
|
||||
nb::ndarray<Lib> mlx_to_nd_array(
|
||||
template <typename T, typename... NDParams>
|
||||
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
||||
array a,
|
||||
std::optional<nb::dlpack::dtype> t = {}) {
|
||||
{
|
||||
@@ -110,47 +110,51 @@ nb::ndarray<Lib> mlx_to_nd_array(
|
||||
}
|
||||
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
|
||||
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
|
||||
return nb::ndarray<Lib>(
|
||||
return nb::ndarray<NDParams...>(
|
||||
a.data<T>(),
|
||||
a.ndim(),
|
||||
shape.data(),
|
||||
nb::handle(),
|
||||
nb::none(),
|
||||
strides.data(),
|
||||
t.value_or(nb::dtype<T>()));
|
||||
}
|
||||
|
||||
template <typename Lib>
|
||||
nb::ndarray<Lib> mlx_to_nd_array(const array& a) {
|
||||
template <typename... NDParams>
|
||||
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
return mlx_to_nd_array<Lib, bool>(a);
|
||||
return mlx_to_nd_array_impl<bool, NDParams...>(a);
|
||||
case uint8:
|
||||
return mlx_to_nd_array<Lib, uint8_t>(a);
|
||||
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
|
||||
case uint16:
|
||||
return mlx_to_nd_array<Lib, uint16_t>(a);
|
||||
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
|
||||
case uint32:
|
||||
return mlx_to_nd_array<Lib, uint32_t>(a);
|
||||
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
|
||||
case uint64:
|
||||
return mlx_to_nd_array<Lib, uint64_t>(a);
|
||||
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
|
||||
case int8:
|
||||
return mlx_to_nd_array<Lib, int8_t>(a);
|
||||
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
|
||||
case int16:
|
||||
return mlx_to_nd_array<Lib, int16_t>(a);
|
||||
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
|
||||
case int32:
|
||||
return mlx_to_nd_array<Lib, int32_t>(a);
|
||||
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
|
||||
case int64:
|
||||
return mlx_to_nd_array<Lib, int64_t>(a);
|
||||
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
|
||||
case float16:
|
||||
return mlx_to_nd_array<Lib, float16_t>(a);
|
||||
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
||||
case bfloat16:
|
||||
return mlx_to_nd_array<Lib, bfloat16_t>(a, nb::bfloat16);
|
||||
return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
|
||||
case float32:
|
||||
return mlx_to_nd_array<Lib, float>(a);
|
||||
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
||||
case complex64:
|
||||
return mlx_to_nd_array<Lib, std::complex<float>>(a);
|
||||
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
||||
}
|
||||
}
|
||||
|
||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
|
||||
return mlx_to_nd_array<nb::numpy>(a);
|
||||
}
|
||||
|
||||
nb::ndarray<> mlx_to_dlpack(const array& a) {
|
||||
return mlx_to_nd_array<>(a);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user