mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion * fix getattribute on Module base * remove unused function * fix tests * add topk to ops * fix doc
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
#include "python/src/indexing.h"
|
||||
#include "python/src/pybind11_numpy_fp16.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
@@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
|
||||
shape.push_back(np_array.shape(i));
|
||||
}
|
||||
|
||||
// Get dtype
|
||||
auto type = np_array.dtype();
|
||||
|
||||
// Copy data and make array
|
||||
if (type.is(py::dtype::of<int>())) {
|
||||
if (py::isinstance<py::array_t<int32_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<int32_t>(
|
||||
np_array, shape, dtype.value_or(int32));
|
||||
} else if (type.is(py::dtype::of<uint32_t>())) {
|
||||
} else if (py::isinstance<py::array_t<uint32_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<uint32_t>(
|
||||
np_array, shape, dtype.value_or(uint32));
|
||||
} else if (type.is(py::dtype::of<bool>())) {
|
||||
} else if (py::isinstance<py::array_t<bool>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<bool>(
|
||||
np_array, shape, dtype.value_or(bool_));
|
||||
} else if (type.is(py::dtype::of<double>())) {
|
||||
} else if (py::isinstance<py::array_t<double>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<double>(
|
||||
np_array, shape, dtype.value_or(float32));
|
||||
} else if (type.is(py::dtype::of<float>())) {
|
||||
} else if (py::isinstance<py::array_t<float>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<float>(
|
||||
np_array, shape, dtype.value_or(float32));
|
||||
} else if (type.is(py::dtype("float16"))) {
|
||||
} else if (py::isinstance<py::array_t<float16_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<float>(
|
||||
np_array, shape, dtype.value_or(float16));
|
||||
} else if (type.is(py::dtype::of<uint8_t>())) {
|
||||
} else if (py::isinstance<py::array_t<uint8_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<uint8_t>(
|
||||
np_array, shape, dtype.value_or(uint8));
|
||||
} else if (type.is(py::dtype::of<uint16_t>())) {
|
||||
} else if (py::isinstance<py::array_t<uint16_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<uint16_t>(
|
||||
np_array, shape, dtype.value_or(uint16));
|
||||
} else if (type.is(py::dtype::of<uint64_t>())) {
|
||||
} else if (py::isinstance<py::array_t<uint64_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<uint64_t>(
|
||||
np_array, shape, dtype.value_or(uint64));
|
||||
} else if (type.is(py::dtype::of<int8_t>())) {
|
||||
} else if (py::isinstance<py::array_t<int8_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<int8_t>(
|
||||
np_array, shape, dtype.value_or(int8));
|
||||
} else if (type.is(py::dtype::of<int16_t>())) {
|
||||
} else if (py::isinstance<py::array_t<int16_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<int16_t>(
|
||||
np_array, shape, dtype.value_or(int16));
|
||||
} else if (type.is(py::dtype::of<int64_t>())) {
|
||||
} else if (py::isinstance<py::array_t<int64_t>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<int64_t>(
|
||||
np_array, shape, dtype.value_or(int64));
|
||||
} else if (type.is(py::dtype::of<std::complex<float>>())) {
|
||||
} else if (py::isinstance<py::array_t<std::complex<float>>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<std::complex<float>>(
|
||||
np_array, shape, dtype.value_or(complex64));
|
||||
} else if (type.is(py::dtype::of<std::complex<double>>())) {
|
||||
} else if (py::isinstance<py::array_t<std::complex<double>>>(np_array)) {
|
||||
return np_array_to_mlx_contiguous<std::complex<float>>(
|
||||
np_array, shape, dtype.value_or(complex64));
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Cannot convert numpy array of type " << type << " to mlx array.";
|
||||
msg << "Cannot convert numpy array of type " << np_array.dtype()
|
||||
<< " to mlx array.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user