diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index a7809ead2..6396bb3c6 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -123,6 +123,7 @@ Operations tanh tensordot tile + topk transpose tri tril diff --git a/mlx/dtype.h b/mlx/dtype.h index d52830485..fec1725f3 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -46,10 +46,6 @@ struct Dtype { }; }; -inline bool is_available(const Dtype& dtype) { - return true; -} - static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)}; diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index a770a5d95..447002594 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -134,7 +134,7 @@ class Module(dict): if key in self: return self[key] else: - super(Module, self).__getattr__(key, val) + super(Module, self).__getattribute__(key) def __setattr__(self, key: str, val: Any): if isinstance(val, (mx.array, dict, list, tuple)): diff --git a/python/src/array.cpp b/python/src/array.cpp index 6dd2f290b..8e0cf605b 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -7,6 +7,7 @@ #include #include "python/src/indexing.h" +#include "python/src/pybind11_numpy_fp16.h" #include "python/src/utils.h" #include "mlx/ops.h" @@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional dtype) { shape.push_back(np_array.shape(i)); } - // Get dtype - auto type = np_array.dtype(); - // Copy data and make array - if (type.is(py::dtype::of())) { + if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(int32)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(uint32)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(bool_)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(float32)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(float32)); - } else if (type.is(py::dtype("float16"))) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(float16)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(uint8)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(uint16)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(uint64)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(int8)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(int16)); - } else if (type.is(py::dtype::of())) { + } else if (py::isinstance>(np_array)) { return np_array_to_mlx_contiguous( np_array, shape, dtype.value_or(int64)); - } else if (type.is(py::dtype::of>())) { + } else if (py::isinstance>>(np_array)) { return np_array_to_mlx_contiguous>( np_array, shape, dtype.value_or(complex64)); - } else if (type.is(py::dtype::of>())) { + } else if (py::isinstance>>(np_array)) { return np_array_to_mlx_contiguous>( np_array, shape, dtype.value_or(complex64)); } else { std::ostringstream msg; - msg << "Cannot convert numpy array of type " << type << " to mlx array."; + msg << "Cannot convert numpy array of type " << np_array.dtype() + << " to mlx array."; throw std::invalid_argument(msg.str()); } } diff --git a/python/src/pybind11_numpy_fp16.h b/python/src/pybind11_numpy_fp16.h new file mode 100644 index 000000000..ed496524f --- /dev/null +++ b/python/src/pybind11_numpy_fp16.h @@ -0,0 +1,60 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once + +// A patch to get float16_t to work with pybind11 numpy arrays +// Derived from: +// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679 + +#include + +namespace pybind11::detail { + +template +struct npy_scalar_caster { + PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); + using Array = array_t; + + bool load(handle src, bool convert) { + // Taken from Eigen casters. Permits either scalar dtype or scalar array. + handle type = dtype::of().attr("type"); // Could make more efficient. + if (!convert && !isinstance(src) && !isinstance(src, type)) + return false; + Array tmp = Array::ensure(src); + if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { + this->value = *tmp.data(); + return true; + } + return false; + } + + static handle cast(T src, return_value_policy, handle) { + Array tmp({1}); + tmp.mutable_at(0) = src; + tmp.resize({}); + // You could also just return the array if you want a scalar array. + object scalar = tmp[tuple()]; + return scalar.release(); + } +}; + +// Similar to enums in `pybind11/numpy.h`. Determined by doing: +// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' +constexpr int NPY_FLOAT16 = 23; + +// Kinda following: +// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000 +template <> +struct npy_format_descriptor { + static constexpr auto name = _("float16"); + static pybind11::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); + return reinterpret_borrow(ptr); + } +}; + +template <> +struct type_caster : npy_scalar_caster { + static constexpr auto name = _("float16"); +}; + +} // namespace pybind11::detail diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 7812642d3..828c2c0b4 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import operator +import pickle import unittest import weakref from itertools import permutations @@ -1389,6 +1390,15 @@ class TestArray(mlx_tests.MLXTestCase): b @= a self.assertTrue(mx.array_equal(a, b)) + def test_load_from_pickled_np(self): + a = np.array([1, 2, 3], dtype=np.int32) + b = pickle.loads(pickle.dumps(a)) + self.assertTrue(mx.array_equal(mx.array(a), mx.array(b))) + + a = np.array([1.0, 2.0, 3.0], dtype=np.float16) + b = pickle.loads(pickle.dumps(a)) + self.assertTrue(mx.array_equal(mx.array(a), mx.array(b))) + if __name__ == "__main__": unittest.main() diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index b7793e41c..7ce057319 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -248,11 +248,9 @@ TEST_CASE("test random uniform") { CHECK_EQ(x.size(), 1); CHECK_EQ(x.dtype(), float32); - if (is_available(float16)) { - x = random::uniform({}, float16); - CHECK_EQ(x.size(), 1); - CHECK_EQ(x.dtype(), float16); - } + x = random::uniform({}, float16); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), float16); x = random::uniform({0}); CHECK(array_equal(x, array({})).item()); @@ -467,11 +465,9 @@ TEST_CASE("test random bernoulli") { CHECK_EQ(x.dtype(), bool_); // Bernoulli parameter can have floating point type - if (is_available(float16)) { - x = random::bernoulli(array(0.5, float16)); - CHECK_EQ(x.size(), 1); - CHECK_EQ(x.dtype(), bool_); - } + x = random::bernoulli(array(0.5, float16)); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), bool_); CHECK_THROWS(random::bernoulli(array(1, int32))); @@ -513,11 +509,9 @@ TEST_CASE("Test truncated normal") { CHECK_EQ(x.size(), 1); CHECK_EQ(x.dtype(), float32); - if (is_available(float16)) { - x = random::truncated_normal(array(-2.0), array(2.0), {}, float16); - CHECK_EQ(x.size(), 1); - CHECK_EQ(x.dtype(), float16); - } + x = random::truncated_normal(array(-2.0), array(2.0), {}, float16); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), float16); // Requested shape x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});