Pickle + dtype fix for numpy conversion (#763)

* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
This commit is contained in:
Awni Hannun
2024-03-02 06:09:29 -08:00
committed by GitHub
parent 8e281c76c3
commit bc06cb9ff6
7 changed files with 99 additions and 39 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <cstring>
@@ -7,6 +7,7 @@
#include <pybind11/numpy.h>
#include "python/src/indexing.h"
#include "python/src/pybind11_numpy_fp16.h"
#include "python/src/utils.h"
#include "mlx/ops.h"
@@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
shape.push_back(np_array.shape(i));
}
// Get dtype
auto type = np_array.dtype();
// Copy data and make array
if (type.is(py::dtype::of<int>())) {
if (py::isinstance<py::array_t<int32_t>>(np_array)) {
return np_array_to_mlx_contiguous<int32_t>(
np_array, shape, dtype.value_or(int32));
} else if (type.is(py::dtype::of<uint32_t>())) {
} else if (py::isinstance<py::array_t<uint32_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint32_t>(
np_array, shape, dtype.value_or(uint32));
} else if (type.is(py::dtype::of<bool>())) {
} else if (py::isinstance<py::array_t<bool>>(np_array)) {
return np_array_to_mlx_contiguous<bool>(
np_array, shape, dtype.value_or(bool_));
} else if (type.is(py::dtype::of<double>())) {
} else if (py::isinstance<py::array_t<double>>(np_array)) {
return np_array_to_mlx_contiguous<double>(
np_array, shape, dtype.value_or(float32));
} else if (type.is(py::dtype::of<float>())) {
} else if (py::isinstance<py::array_t<float>>(np_array)) {
return np_array_to_mlx_contiguous<float>(
np_array, shape, dtype.value_or(float32));
} else if (type.is(py::dtype("float16"))) {
} else if (py::isinstance<py::array_t<float16_t>>(np_array)) {
return np_array_to_mlx_contiguous<float>(
np_array, shape, dtype.value_or(float16));
} else if (type.is(py::dtype::of<uint8_t>())) {
} else if (py::isinstance<py::array_t<uint8_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint8_t>(
np_array, shape, dtype.value_or(uint8));
} else if (type.is(py::dtype::of<uint16_t>())) {
} else if (py::isinstance<py::array_t<uint16_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint16_t>(
np_array, shape, dtype.value_or(uint16));
} else if (type.is(py::dtype::of<uint64_t>())) {
} else if (py::isinstance<py::array_t<uint64_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint64_t>(
np_array, shape, dtype.value_or(uint64));
} else if (type.is(py::dtype::of<int8_t>())) {
} else if (py::isinstance<py::array_t<int8_t>>(np_array)) {
return np_array_to_mlx_contiguous<int8_t>(
np_array, shape, dtype.value_or(int8));
} else if (type.is(py::dtype::of<int16_t>())) {
} else if (py::isinstance<py::array_t<int16_t>>(np_array)) {
return np_array_to_mlx_contiguous<int16_t>(
np_array, shape, dtype.value_or(int16));
} else if (type.is(py::dtype::of<int64_t>())) {
} else if (py::isinstance<py::array_t<int64_t>>(np_array)) {
return np_array_to_mlx_contiguous<int64_t>(
np_array, shape, dtype.value_or(int64));
} else if (type.is(py::dtype::of<std::complex<float>>())) {
} else if (py::isinstance<py::array_t<std::complex<float>>>(np_array)) {
return np_array_to_mlx_contiguous<std::complex<float>>(
np_array, shape, dtype.value_or(complex64));
} else if (type.is(py::dtype::of<std::complex<double>>())) {
} else if (py::isinstance<py::array_t<std::complex<double>>>(np_array)) {
return np_array_to_mlx_contiguous<std::complex<float>>(
np_array, shape, dtype.value_or(complex64));
} else {
std::ostringstream msg;
msg << "Cannot convert numpy array of type " << type << " to mlx array.";
msg << "Cannot convert numpy array of type " << np_array.dtype()
<< " to mlx array.";
throw std::invalid_argument(msg.str());
}
}

View File

@@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
// A patch to get float16_t to work with pybind11 numpy arrays
// Derived from:
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679
#include <pybind11/numpy.h>
namespace pybind11::detail {
template <typename T>
struct npy_scalar_caster {
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
using Array = array_t<T>;
bool load(handle src, bool convert) {
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
handle type = dtype::of<T>().attr("type"); // Could make more efficient.
if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
return false;
Array tmp = Array::ensure(src);
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
this->value = *tmp.data();
return true;
}
return false;
}
static handle cast(T src, return_value_policy, handle) {
Array tmp({1});
tmp.mutable_at(0) = src;
tmp.resize({});
// You could also just return the array if you want a scalar array.
object scalar = tmp[tuple()];
return scalar.release();
}
};
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
// Kinda following:
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
template <>
struct npy_format_descriptor<float16_t> {
static constexpr auto name = _("float16");
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
};
template <>
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
static constexpr auto name = _("float16");
};
} // namespace pybind11::detail