mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion * fix getattribute on Module base * remove unused function * fix tests * add topk to ops * fix doc
This commit is contained in:
parent
8e281c76c3
commit
bc06cb9ff6
@ -123,6 +123,7 @@ Operations
|
|||||||
tanh
|
tanh
|
||||||
tensordot
|
tensordot
|
||||||
tile
|
tile
|
||||||
|
topk
|
||||||
transpose
|
transpose
|
||||||
tri
|
tri
|
||||||
tril
|
tril
|
||||||
|
@ -46,10 +46,6 @@ struct Dtype {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool is_available(const Dtype& dtype) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
||||||
|
|
||||||
static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
|
static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
|
||||||
|
@ -134,7 +134,7 @@ class Module(dict):
|
|||||||
if key in self:
|
if key in self:
|
||||||
return self[key]
|
return self[key]
|
||||||
else:
|
else:
|
||||||
super(Module, self).__getattr__(key, val)
|
super(Module, self).__getattribute__(key)
|
||||||
|
|
||||||
def __setattr__(self, key: str, val: Any):
|
def __setattr__(self, key: str, val: Any):
|
||||||
if isinstance(val, (mx.array, dict, list, tuple)):
|
if isinstance(val, (mx.array, dict, list, tuple)):
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
@ -7,6 +7,7 @@
|
|||||||
#include <pybind11/numpy.h>
|
#include <pybind11/numpy.h>
|
||||||
|
|
||||||
#include "python/src/indexing.h"
|
#include "python/src/indexing.h"
|
||||||
|
#include "python/src/pybind11_numpy_fp16.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
|
|||||||
shape.push_back(np_array.shape(i));
|
shape.push_back(np_array.shape(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get dtype
|
|
||||||
auto type = np_array.dtype();
|
|
||||||
|
|
||||||
// Copy data and make array
|
// Copy data and make array
|
||||||
if (type.is(py::dtype::of<int>())) {
|
if (py::isinstance<py::array_t<int32_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<int32_t>(
|
return np_array_to_mlx_contiguous<int32_t>(
|
||||||
np_array, shape, dtype.value_or(int32));
|
np_array, shape, dtype.value_or(int32));
|
||||||
} else if (type.is(py::dtype::of<uint32_t>())) {
|
} else if (py::isinstance<py::array_t<uint32_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<uint32_t>(
|
return np_array_to_mlx_contiguous<uint32_t>(
|
||||||
np_array, shape, dtype.value_or(uint32));
|
np_array, shape, dtype.value_or(uint32));
|
||||||
} else if (type.is(py::dtype::of<bool>())) {
|
} else if (py::isinstance<py::array_t<bool>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<bool>(
|
return np_array_to_mlx_contiguous<bool>(
|
||||||
np_array, shape, dtype.value_or(bool_));
|
np_array, shape, dtype.value_or(bool_));
|
||||||
} else if (type.is(py::dtype::of<double>())) {
|
} else if (py::isinstance<py::array_t<double>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<double>(
|
return np_array_to_mlx_contiguous<double>(
|
||||||
np_array, shape, dtype.value_or(float32));
|
np_array, shape, dtype.value_or(float32));
|
||||||
} else if (type.is(py::dtype::of<float>())) {
|
} else if (py::isinstance<py::array_t<float>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<float>(
|
return np_array_to_mlx_contiguous<float>(
|
||||||
np_array, shape, dtype.value_or(float32));
|
np_array, shape, dtype.value_or(float32));
|
||||||
} else if (type.is(py::dtype("float16"))) {
|
} else if (py::isinstance<py::array_t<float16_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<float>(
|
return np_array_to_mlx_contiguous<float>(
|
||||||
np_array, shape, dtype.value_or(float16));
|
np_array, shape, dtype.value_or(float16));
|
||||||
} else if (type.is(py::dtype::of<uint8_t>())) {
|
} else if (py::isinstance<py::array_t<uint8_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<uint8_t>(
|
return np_array_to_mlx_contiguous<uint8_t>(
|
||||||
np_array, shape, dtype.value_or(uint8));
|
np_array, shape, dtype.value_or(uint8));
|
||||||
} else if (type.is(py::dtype::of<uint16_t>())) {
|
} else if (py::isinstance<py::array_t<uint16_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<uint16_t>(
|
return np_array_to_mlx_contiguous<uint16_t>(
|
||||||
np_array, shape, dtype.value_or(uint16));
|
np_array, shape, dtype.value_or(uint16));
|
||||||
} else if (type.is(py::dtype::of<uint64_t>())) {
|
} else if (py::isinstance<py::array_t<uint64_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<uint64_t>(
|
return np_array_to_mlx_contiguous<uint64_t>(
|
||||||
np_array, shape, dtype.value_or(uint64));
|
np_array, shape, dtype.value_or(uint64));
|
||||||
} else if (type.is(py::dtype::of<int8_t>())) {
|
} else if (py::isinstance<py::array_t<int8_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<int8_t>(
|
return np_array_to_mlx_contiguous<int8_t>(
|
||||||
np_array, shape, dtype.value_or(int8));
|
np_array, shape, dtype.value_or(int8));
|
||||||
} else if (type.is(py::dtype::of<int16_t>())) {
|
} else if (py::isinstance<py::array_t<int16_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<int16_t>(
|
return np_array_to_mlx_contiguous<int16_t>(
|
||||||
np_array, shape, dtype.value_or(int16));
|
np_array, shape, dtype.value_or(int16));
|
||||||
} else if (type.is(py::dtype::of<int64_t>())) {
|
} else if (py::isinstance<py::array_t<int64_t>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<int64_t>(
|
return np_array_to_mlx_contiguous<int64_t>(
|
||||||
np_array, shape, dtype.value_or(int64));
|
np_array, shape, dtype.value_or(int64));
|
||||||
} else if (type.is(py::dtype::of<std::complex<float>>())) {
|
} else if (py::isinstance<py::array_t<std::complex<float>>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<std::complex<float>>(
|
return np_array_to_mlx_contiguous<std::complex<float>>(
|
||||||
np_array, shape, dtype.value_or(complex64));
|
np_array, shape, dtype.value_or(complex64));
|
||||||
} else if (type.is(py::dtype::of<std::complex<double>>())) {
|
} else if (py::isinstance<py::array_t<std::complex<double>>>(np_array)) {
|
||||||
return np_array_to_mlx_contiguous<std::complex<float>>(
|
return np_array_to_mlx_contiguous<std::complex<float>>(
|
||||||
np_array, shape, dtype.value_or(complex64));
|
np_array, shape, dtype.value_or(complex64));
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Cannot convert numpy array of type " << type << " to mlx array.";
|
msg << "Cannot convert numpy array of type " << np_array.dtype()
|
||||||
|
<< " to mlx array.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
60
python/src/pybind11_numpy_fp16.h
Normal file
60
python/src/pybind11_numpy_fp16.h
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// A patch to get float16_t to work with pybind11 numpy arrays
|
||||||
|
// Derived from:
|
||||||
|
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679
|
||||||
|
|
||||||
|
#include <pybind11/numpy.h>
|
||||||
|
|
||||||
|
namespace pybind11::detail {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct npy_scalar_caster {
|
||||||
|
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
|
||||||
|
using Array = array_t<T>;
|
||||||
|
|
||||||
|
bool load(handle src, bool convert) {
|
||||||
|
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
|
||||||
|
handle type = dtype::of<T>().attr("type"); // Could make more efficient.
|
||||||
|
if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
|
||||||
|
return false;
|
||||||
|
Array tmp = Array::ensure(src);
|
||||||
|
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
|
||||||
|
this->value = *tmp.data();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static handle cast(T src, return_value_policy, handle) {
|
||||||
|
Array tmp({1});
|
||||||
|
tmp.mutable_at(0) = src;
|
||||||
|
tmp.resize({});
|
||||||
|
// You could also just return the array if you want a scalar array.
|
||||||
|
object scalar = tmp[tuple()];
|
||||||
|
return scalar.release();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
|
||||||
|
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
|
||||||
|
constexpr int NPY_FLOAT16 = 23;
|
||||||
|
|
||||||
|
// Kinda following:
|
||||||
|
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
|
||||||
|
template <>
|
||||||
|
struct npy_format_descriptor<float16_t> {
|
||||||
|
static constexpr auto name = _("float16");
|
||||||
|
static pybind11::dtype dtype() {
|
||||||
|
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
|
||||||
|
return reinterpret_borrow<pybind11::dtype>(ptr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
|
||||||
|
static constexpr auto name = _("float16");
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace pybind11::detail
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import operator
|
import operator
|
||||||
|
import pickle
|
||||||
import unittest
|
import unittest
|
||||||
import weakref
|
import weakref
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
@ -1389,6 +1390,15 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
b @= a
|
b @= a
|
||||||
self.assertTrue(mx.array_equal(a, b))
|
self.assertTrue(mx.array_equal(a, b))
|
||||||
|
|
||||||
|
def test_load_from_pickled_np(self):
|
||||||
|
a = np.array([1, 2, 3], dtype=np.int32)
|
||||||
|
b = pickle.loads(pickle.dumps(a))
|
||||||
|
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||||
|
|
||||||
|
a = np.array([1.0, 2.0, 3.0], dtype=np.float16)
|
||||||
|
b = pickle.loads(pickle.dumps(a))
|
||||||
|
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -248,11 +248,9 @@ TEST_CASE("test random uniform") {
|
|||||||
CHECK_EQ(x.size(), 1);
|
CHECK_EQ(x.size(), 1);
|
||||||
CHECK_EQ(x.dtype(), float32);
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
|
||||||
if (is_available(float16)) {
|
x = random::uniform({}, float16);
|
||||||
x = random::uniform({}, float16);
|
CHECK_EQ(x.size(), 1);
|
||||||
CHECK_EQ(x.size(), 1);
|
CHECK_EQ(x.dtype(), float16);
|
||||||
CHECK_EQ(x.dtype(), float16);
|
|
||||||
}
|
|
||||||
|
|
||||||
x = random::uniform({0});
|
x = random::uniform({0});
|
||||||
CHECK(array_equal(x, array({})).item<bool>());
|
CHECK(array_equal(x, array({})).item<bool>());
|
||||||
@ -467,11 +465,9 @@ TEST_CASE("test random bernoulli") {
|
|||||||
CHECK_EQ(x.dtype(), bool_);
|
CHECK_EQ(x.dtype(), bool_);
|
||||||
|
|
||||||
// Bernoulli parameter can have floating point type
|
// Bernoulli parameter can have floating point type
|
||||||
if (is_available(float16)) {
|
x = random::bernoulli(array(0.5, float16));
|
||||||
x = random::bernoulli(array(0.5, float16));
|
CHECK_EQ(x.size(), 1);
|
||||||
CHECK_EQ(x.size(), 1);
|
CHECK_EQ(x.dtype(), bool_);
|
||||||
CHECK_EQ(x.dtype(), bool_);
|
|
||||||
}
|
|
||||||
|
|
||||||
CHECK_THROWS(random::bernoulli(array(1, int32)));
|
CHECK_THROWS(random::bernoulli(array(1, int32)));
|
||||||
|
|
||||||
@ -513,11 +509,9 @@ TEST_CASE("Test truncated normal") {
|
|||||||
CHECK_EQ(x.size(), 1);
|
CHECK_EQ(x.size(), 1);
|
||||||
CHECK_EQ(x.dtype(), float32);
|
CHECK_EQ(x.dtype(), float32);
|
||||||
|
|
||||||
if (is_available(float16)) {
|
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
|
||||||
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
|
CHECK_EQ(x.size(), 1);
|
||||||
CHECK_EQ(x.size(), 1);
|
CHECK_EQ(x.dtype(), float16);
|
||||||
CHECK_EQ(x.dtype(), float16);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Requested shape
|
// Requested shape
|
||||||
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
|
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
|
||||||
|
Loading…
Reference in New Issue
Block a user