2024-03-19 11:12:25 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
|
|
|
|
#include <nanobind/stl/complex.h>
|
|
|
|
|
|
|
|
#include "python/src/convert.h"
|
|
|
|
|
2024-03-26 04:29:45 +08:00
|
|
|
#include "mlx/utils.h"
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
namespace nanobind {
|
|
|
|
template <>
|
|
|
|
struct ndarray_traits<float16_t> {
|
|
|
|
static constexpr bool is_complex = false;
|
|
|
|
static constexpr bool is_float = true;
|
|
|
|
static constexpr bool is_bool = false;
|
|
|
|
static constexpr bool is_int = false;
|
|
|
|
static constexpr bool is_signed = true;
|
|
|
|
};
|
|
|
|
|
|
|
|
template <>
|
|
|
|
struct ndarray_traits<bfloat16_t> {
|
|
|
|
static constexpr bool is_complex = false;
|
|
|
|
static constexpr bool is_float = true;
|
|
|
|
static constexpr bool is_bool = false;
|
|
|
|
static constexpr bool is_int = false;
|
|
|
|
static constexpr bool is_signed = true;
|
|
|
|
};
|
|
|
|
|
|
|
|
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
|
|
|
}; // namespace nanobind
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
array nd_array_to_mlx_contiguous(
|
|
|
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
|
|
|
const std::vector<int>& shape,
|
|
|
|
Dtype dtype) {
|
|
|
|
// Make a copy of the numpy buffer
|
|
|
|
// Get buffer ptr pass to array constructor
|
|
|
|
auto data_ptr = nd_array.data();
|
|
|
|
return array(static_cast<const T*>(data_ptr), shape, dtype);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Convert a read-only, C-contiguous, CPU-resident ndarray into an mlx array.
//
// The mlx shape is built by passing every ndarray dimension through
// check_shape_dim. The element type of the source buffer is then matched
// against the supported types below; `dtype`, when provided, overrides the
// default mlx dtype paired with the matched source type. Note that double
// input defaults to float32 and complex<double> input defaults to complex64.
//
// Throws std::invalid_argument when the ndarray's dtype matches none of the
// supported element types.
array nd_array_to_mlx(
    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
    std::optional<Dtype> dtype) {
  // Compute the shape. ndim() is unsigned, so the index is size_t as well to
  // avoid a signed/unsigned comparison; reserve avoids regrowth while filling.
  const size_t ndim = nd_array.ndim();
  std::vector<int> shape;
  shape.reserve(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    shape.push_back(check_shape_dim(nd_array.shape(i)));
  }

  const auto type = nd_array.dtype();

  // Copy data and make array. Every branch returns, so the checks are written
  // as guard clauses rather than an else-if ladder.
  if (type == nb::dtype<bool>()) {
    return nd_array_to_mlx_contiguous<bool>(
        nd_array, shape, dtype.value_or(bool_));
  }
  if (type == nb::dtype<uint8_t>()) {
    return nd_array_to_mlx_contiguous<uint8_t>(
        nd_array, shape, dtype.value_or(uint8));
  }
  if (type == nb::dtype<uint16_t>()) {
    return nd_array_to_mlx_contiguous<uint16_t>(
        nd_array, shape, dtype.value_or(uint16));
  }
  if (type == nb::dtype<uint32_t>()) {
    return nd_array_to_mlx_contiguous<uint32_t>(
        nd_array, shape, dtype.value_or(uint32));
  }
  if (type == nb::dtype<uint64_t>()) {
    return nd_array_to_mlx_contiguous<uint64_t>(
        nd_array, shape, dtype.value_or(uint64));
  }
  if (type == nb::dtype<int8_t>()) {
    return nd_array_to_mlx_contiguous<int8_t>(
        nd_array, shape, dtype.value_or(int8));
  }
  if (type == nb::dtype<int16_t>()) {
    return nd_array_to_mlx_contiguous<int16_t>(
        nd_array, shape, dtype.value_or(int16));
  }
  if (type == nb::dtype<int32_t>()) {
    return nd_array_to_mlx_contiguous<int32_t>(
        nd_array, shape, dtype.value_or(int32));
  }
  if (type == nb::dtype<int64_t>()) {
    return nd_array_to_mlx_contiguous<int64_t>(
        nd_array, shape, dtype.value_or(int64));
  }
  if (type == nb::dtype<float16_t>()) {
    return nd_array_to_mlx_contiguous<float16_t>(
        nd_array, shape, dtype.value_or(float16));
  }
  if (type == nb::dtype<bfloat16_t>()) {
    return nd_array_to_mlx_contiguous<bfloat16_t>(
        nd_array, shape, dtype.value_or(bfloat16));
  }
  if (type == nb::dtype<float>()) {
    return nd_array_to_mlx_contiguous<float>(
        nd_array, shape, dtype.value_or(float32));
  }
  if (type == nb::dtype<double>()) {
    // Double-precision input is narrowed to float32 unless a dtype is given.
    return nd_array_to_mlx_contiguous<double>(
        nd_array, shape, dtype.value_or(float32));
  }
  if (type == nb::dtype<std::complex<float>>()) {
    return nd_array_to_mlx_contiguous<complex64_t>(
        nd_array, shape, dtype.value_or(complex64));
  }
  if (type == nb::dtype<std::complex<double>>()) {
    // Double-precision complex input is narrowed to complex64 by default.
    return nd_array_to_mlx_contiguous<complex128_t>(
        nd_array, shape, dtype.value_or(complex64));
  }

  throw std::invalid_argument("Cannot convert numpy array to mlx array.");
}
|
|
|
|
|
2024-05-17 07:11:37 +08:00
|
|
|
template <typename T, typename... NDParams>
|
|
|
|
nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
2024-03-19 11:12:25 +08:00
|
|
|
array a,
|
|
|
|
std::optional<nb::dlpack::dtype> t = {}) {
|
2024-04-17 21:16:02 +08:00
|
|
|
{
|
2024-03-19 11:12:25 +08:00
|
|
|
nb::gil_scoped_release nogil;
|
|
|
|
a.eval();
|
|
|
|
}
|
|
|
|
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
|
|
|
|
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
|
2024-05-17 07:11:37 +08:00
|
|
|
return nb::ndarray<NDParams...>(
|
2024-03-19 11:12:25 +08:00
|
|
|
a.data<T>(),
|
|
|
|
a.ndim(),
|
|
|
|
shape.data(),
|
2024-05-17 07:11:37 +08:00
|
|
|
nb::none(),
|
2024-03-19 11:12:25 +08:00
|
|
|
strides.data(),
|
|
|
|
t.value_or(nb::dtype<T>()));
|
|
|
|
}
|
|
|
|
|
2024-05-17 07:11:37 +08:00
|
|
|
template <typename... NDParams>
|
|
|
|
nb::ndarray<NDParams...> mlx_to_nd_array(const array& a) {
|
2024-03-19 11:12:25 +08:00
|
|
|
switch (a.dtype()) {
|
|
|
|
case bool_:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<bool, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case uint8:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<uint8_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case uint16:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<uint16_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case uint32:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<uint32_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case uint64:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<uint64_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case int8:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<int8_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case int16:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<int16_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case int32:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<int32_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case int64:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<int64_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case float16:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<float16_t, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case bfloat16:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<bfloat16_t, NDParams...>(a, nb::bfloat16);
|
2024-03-19 11:12:25 +08:00
|
|
|
case float32:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<float, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
case complex64:
|
2024-05-17 07:11:37 +08:00
|
|
|
return mlx_to_nd_array_impl<std::complex<float>, NDParams...>(a);
|
2024-03-19 11:12:25 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Export an mlx array as an ndarray tagged with the nb::numpy framework.
// All of the work is delegated to mlx_to_nd_array.
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
  return mlx_to_nd_array<nb::numpy>(a);
}
|
2024-05-17 07:11:37 +08:00
|
|
|
|
|
|
|
// Export an mlx array as a framework-agnostic ndarray (no ndarray template
// parameters). All of the work is delegated to mlx_to_nd_array.
nb::ndarray<> mlx_to_dlpack(const array& a) {
  return mlx_to_nd_array<>(a);
}
|