diff --git a/.circleci/config.yml b/.circleci/config.yml index cd466fb72..6f8e613f8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -31,8 +31,7 @@ jobs: name: Install dependencies command: | pip install --upgrade cmake - pip install --upgrade pybind11[global] - pip install pybind11-stubgen + pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f pip install numpy sudo apt-get update sudo apt-get install libblas-dev liblapack-dev liblapacke-dev @@ -44,7 +43,8 @@ jobs: - run: name: Generate package stubs command: | - python3 setup.py generate_stubs + echo "stubs" + python -m nanobind.stubgen -m mlx.core -r -O python - run: name: Run Python tests command: | @@ -80,8 +80,7 @@ jobs: source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install --upgrade pybind11[global] - pip install pybind11-stubgen + pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f pip install numpy pip install torch pip install tensorflow @@ -95,7 +94,7 @@ jobs: name: Generate package stubs command: | source env/bin/activate - python setup.py generate_stubs + python -m nanobind.stubgen -m mlx.core -r -O python - run: name: Run Python tests command: | @@ -144,9 +143,8 @@ jobs: source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install --upgrade pybind11[global] + pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f pip install --upgrade setuptools - pip install pybind11-stubgen pip install numpy pip install twine pip install build @@ -161,7 +159,7 @@ jobs: name: Generate package stubs command: | source env/bin/activate - python setup.py generate_stubs + python -m nanobind.stubgen -m mlx.core -r -O python - run: name: Build Python package command: | @@ -209,9 +207,8 @@ jobs: source env/bin/activate pip install --upgrade pip pip install --upgrade cmake - pip install --upgrade pybind11[global] + pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f pip install --upgrade setuptools - pip install pybind11-stubgen pip install numpy pip install auditwheel pip install patchelf @@ -219,7 +216,7 @@ jobs: << parameters.extra_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ pip install . 
-v - python setup.py generate_stubs + python -m nanobind.stubgen -m mlx.core -r -O python << parameters.extra_env >> \ CMAKE_BUILD_PARALLEL_LEVEL="" \ python -m build --wheel diff --git a/CMakeLists.txt b/CMakeLists.txt index 6aecd26f2..6d7e177a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -146,8 +146,12 @@ target_include_directories( if (MLX_BUILD_PYTHON_BINDINGS) message(STATUS "Building Python bindings.") - find_package(Python COMPONENTS Interpreter Development) - find_package(pybind11 CONFIG REQUIRED) + find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) + list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") + find_package(nanobind CONFIG REQUIRED) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) endif() diff --git a/docs/src/conf.py b/docs/src/conf.py index 603bfa847..c85ee9510 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -29,8 +29,8 @@ autosummary_generate = True autosummary_filename_map = {"mlx.core.Stream": "stream_class"} intersphinx_mapping = { - "https://docs.python.org/3": None, - "https://numpy.org/doc/stable/": None, + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), } templates_path = ["_templates"] @@ -59,3 +59,14 @@ html_theme_options = { # -- Options for HTMLHelp output --------------------------------------------- htmlhelp_basename = "mlx_doc" + + +def setup(app): + wrapped = app.registry.documenters["function"].can_document_member + + def nanobind_function_patch(member: Any, *args, **kwargs) -> bool: + return "nanobind.nb_func" in str(type(member)) or wrapped( + member, *args, **kwargs + ) + + app.registry.documenters["function"].can_document_member = nanobind_function_patch diff --git a/docs/src/install.rst b/docs/src/install.rst index ce549247f..43571f95d 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from git clone git@github.com:ml-explore/mlx.git mlx && cd mlx -Make sure that you have `pybind11 `_ -installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows: +Install `nanobind `_ with: .. code-block:: shell - pip install "pybind11[global]" - conda install pybind11 - brew install pybind11 + pip install git+https://github.com/wjakob/nanobind.git -Then simply build and install it using pip: +Then simply build and install MLX using pip: .. code-block:: shell diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 9d88f87fe..3d27ab04e 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -1,5 +1,4 @@ -// Copyright © 2023 Apple Inc. - +// Copyright © 2023-2024 Apple Inc. 
#include #include #include @@ -122,8 +121,6 @@ void save(std::shared_ptr out_stream, array a) { out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length()); out_stream->write(header.str().c_str(), header.str().length()); out_stream->write(a.data(), a.nbytes()); - - return; } /** Save array to file in .npy format */ diff --git a/mlx/types/complex.h b/mlx/types/complex.h index f8a607766..539974d33 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -7,6 +7,25 @@ namespace mlx::core { struct complex64_t; +struct complex128_t; + +template +static constexpr bool can_convert_to_complex128 = + !std::is_same_v && std::is_convertible_v; + +struct complex128_t : public std::complex { + complex128_t(double v, double u) : std::complex(v, u){}; + complex128_t(std::complex v) : std::complex(v){}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex128_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; template static constexpr bool can_convert_to_complex64 = diff --git a/pyproject.toml b/pyproject.toml index 09ead06ee..b9511111e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ [build-system] -requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"] +requires = [ + "setuptools>=42", + "nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f", + "cmake>=3.24", +] build-backend = "setuptools.build_meta" diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 7a3729436..83acef1e2 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -1,7 +1,10 @@ -pybind11_add_module( +nanobind_add_module( core + NB_STATIC STABLE_ABI LTO NOMINSIZE + NB_DOMAIN mlx ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp @@ -15,7 +18,6 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/array.cpp b/python/src/array.cpp index dd8ad89e5..196561c16 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1,21 +1,27 @@ // Copyright © 2023-2024 Apple Inc. 
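Note: the complex128_t shim added in mlx/types/complex.h above exists so that double-precision values coming from Python (every Python complex is a complex128) convert cleanly into MLX's single-precision complex type. A minimal sketch of the user-visible behavior, assuming an MLX build from this branch; the values are illustrative only:

    import mlx.core as mx

    # Python complex scalars are double precision (complex128).
    x = mx.array(3 + 4j)

    # MLX stores complex values as complex64, so the scalar is narrowed on entry.
    print(x.dtype)   # mlx.core.complex64
    print(x.item())  # (3+4j), returned as a Python complex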
- #include #include #include -#include +#include +#include +#include +#include +#include +#include +#include "python/src/buffer.h" +#include "python/src/convert.h" #include "python/src/indexing.h" -#include "python/src/pybind11_numpy_fp16.h" #include "python/src/utils.h" #include "mlx/ops.h" #include "mlx/transforms.h" #include "mlx/utils.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; +using namespace mlx::core; enum PyScalarT { pybool = 0, @@ -25,8 +31,8 @@ enum PyScalarT { }; template -py::list to_list(array& a, size_t index, int dim) { - py::list pl; +nb::list to_list(array& a, size_t index, int dim) { + nb::list pl; auto stride = a.strides()[dim]; for (int i = 0; i < a.shape(dim); ++i) { if (dim == a.ndim() - 1) { @@ -41,48 +47,47 @@ py::list to_list(array& a, size_t index, int dim) { auto to_scalar(array& a) { { - py::gil_scoped_release nogil; + nb::gil_scoped_release nogil; a.eval(); } switch (a.dtype()) { case bool_: - return py::cast(a.item()); + return nb::cast(a.item()); case uint8: - return py::cast(a.item()); + return nb::cast(a.item()); case uint16: - return py::cast(a.item()); + return nb::cast(a.item()); case uint32: - return py::cast(a.item()); + return nb::cast(a.item()); case uint64: - return py::cast(a.item()); + return nb::cast(a.item()); case int8: - return py::cast(a.item()); + return nb::cast(a.item()); case int16: - return py::cast(a.item()); + return nb::cast(a.item()); case int32: - return py::cast(a.item()); + return nb::cast(a.item()); case int64: - return py::cast(a.item()); + return nb::cast(a.item()); case float16: - return py::cast(static_cast(a.item())); + return nb::cast(static_cast(a.item())); case float32: - return py::cast(a.item()); + return nb::cast(a.item()); case bfloat16: - return py::cast(static_cast(a.item())); + return nb::cast(static_cast(a.item())); case complex64: - return py::cast(a.item>()); + return nb::cast(a.item>()); } } -py::object tolist(array& a) { +nb::object tolist(array& a) { if (a.ndim() == 0) { return to_scalar(a); } { - py::gil_scoped_release nogil; + nb::gil_scoped_release nogil; a.eval(); } - py::object pl; switch (a.dtype()) { case bool_: return to_list(a, 0, 0); @@ -116,12 +121,12 @@ py::object tolist(array& a) { template void fill_vector(T list, std::vector& vals) { for (auto l : list) { - if (py::isinstance(l)) { - fill_vector(l.template cast(), vals); - } else if (py::isinstance(*list.begin())) { - fill_vector(l.template cast(), vals); + if (nb::isinstance(l)) { + fill_vector(nb::cast(l), vals); + } else if (nb::isinstance(*list.begin())) { + fill_vector(nb::cast(l), vals); } else { - vals.push_back(l.template cast()); + vals.push_back(nb::cast(l)); } } } @@ -136,7 +141,7 @@ PyScalarT validate_shape( throw std::invalid_argument("Initialization encountered extra dimension."); } auto s = shape[idx]; - if (py::len(list) != s) { + if (nb::len(list) != s) { throw std::invalid_argument( "Initialization encountered non-uniform length."); } @@ -148,29 +153,26 @@ PyScalarT validate_shape( PyScalarT type = pybool; for (auto l : list) { PyScalarT t; - if (py::isinstance(l)) { + if (nb::isinstance(l)) { t = validate_shape( - l.template cast(), + nb::cast(l), shape, idx + 1, all_python_primitive_elements); + } else if (nb::isinstance(*list.begin())) { + t = validate_shape( + nb::cast(l), shape, idx + 1, all_python_primitive_elements); - } else if (py::isinstance(*list.begin())) { - t = validate_shape( - l.template cast(), - shape, - idx + 1, - 
all_python_primitive_elements); - } else if (py::isinstance(l)) { + } else if (nb::isinstance(l)) { t = pybool; - } else if (py::isinstance(l)) { + } else if (nb::isinstance(l)) { t = pyint; - } else if (py::isinstance(l)) { + } else if (nb::isinstance(l)) { t = pyfloat; } else if (PyComplex_Check(l.ptr())) { t = pycomplex; - } else if (py::isinstance(l)) { + } else if (nb::isinstance(l)) { all_python_primitive_elements = false; - auto arr = py::cast(l); + auto arr = nb::cast(l); if (arr.ndim() + idx + 1 == shape.size() && std::equal( arr.shape().cbegin(), @@ -183,7 +185,8 @@ PyScalarT validate_shape( } } else { std::ostringstream msg; - msg << "Invalid type in array initialization" << l.get_type() << "."; + msg << "Invalid type " << nb::type_name(l.type()).c_str() + << " received in array initialization."; throw std::invalid_argument(msg.str()); } type = std::max(type, t); @@ -193,35 +196,36 @@ PyScalarT validate_shape( template void get_shape(T list, std::vector& shape) { - shape.push_back(py::len(list)); + shape.push_back(nb::len(list)); if (shape.back() > 0) { - auto& l = *list.begin(); - if (py::isinstance(l)) { - return get_shape(l.template cast(), shape); - } else if (py::isinstance(l)) { - return get_shape(l.template cast(), shape); - } else if (py::isinstance(l)) { - auto arr = py::cast(l); + auto l = list.begin(); + if (nb::isinstance(*l)) { + return get_shape(nb::cast(*l), shape); + } else if (nb::isinstance(*l)) { + return get_shape(nb::cast(*l), shape); + } else if (nb::isinstance(*l)) { + auto arr = nb::cast(*l); shape.insert(shape.end(), arr.shape().begin(), arr.shape().end()); return; } } } -using array_init_type = std::variant< - py::bool_, - py::int_, - py::float_, - std::complex, - py::list, - py::tuple, +using ArrayInitType = std::variant< + nb::bool_, + nb::int_, + nb::float_, + // Must be above ndarray array, - py::array, - py::buffer, - py::object>; + // Must be above complex + nb::ndarray, + std::complex, + nb::list, + nb::tuple, + nb::object>; // Forward declaration -array create_array(array_init_type v, std::optional t); +array create_array(ArrayInitType v, std::optional t); template array array_from_list( @@ -300,204 +304,36 @@ array array_from_list(T pl, std::optional dtype) { // `pl` contains mlx arrays std::vector arrays; for (auto l : pl) { - arrays.push_back(create_array(py::cast(l), dtype)); + arrays.push_back(create_array(nb::cast(l), dtype)); } return stack(arrays); } -/////////////////////////////////////////////////////////////////////////////// -// Numpy -> MLX -/////////////////////////////////////////////////////////////////////////////// - -template -array np_array_to_mlx_contiguous( - py::array_t np_array, - const std::vector& shape, - Dtype dtype) { - // Make a copy of the numpy buffer - // Get buffer ptr pass to array constructor - py::buffer_info buf = np_array.request(); - const T* data_ptr = static_cast(buf.ptr); - return array(data_ptr, shape, dtype); - - // Note: Leaving the following memoryless copy from np to mx commented - // out for the time being since it is unsafe given that the incoming - // numpy array may change the underlying data - - // // Share underlying numpy buffer - // // Copy to increase ref count - // auto deleter = [np_array](void*) {}; - // void* data_ptr = np_array.mutable_data(); - // // Use buffer from numpy - // return array(data_ptr, deleter, shape, dtype); -} - -template <> -array np_array_to_mlx_contiguous( - py::array_t, py::array::c_style | py::array::forcecast> - np_array, - const std::vector& shape, - Dtype 
dtype) { - // Get buffer ptr pass to array constructor - py::buffer_info buf = np_array.request(); - auto data_ptr = static_cast*>(buf.ptr); - return array(reinterpret_cast(data_ptr), shape, dtype); -} - -array np_array_to_mlx(py::array np_array, std::optional dtype) { - // Compute the shape and size - std::vector shape; - for (int i = 0; i < np_array.ndim(); i++) { - shape.push_back(np_array.shape(i)); - } - - // Copy data and make array - if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(int32)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(uint32)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(bool_)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(float32)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(float32)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(float16)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(uint8)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(uint16)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(uint64)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(int8)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(int16)); - } else if (py::isinstance>(np_array)) { - return np_array_to_mlx_contiguous( - np_array, shape, dtype.value_or(int64)); - } else if (py::isinstance>>(np_array)) { - return np_array_to_mlx_contiguous>( - np_array, shape, dtype.value_or(complex64)); - } else if (py::isinstance>>(np_array)) { - return np_array_to_mlx_contiguous>( - np_array, shape, dtype.value_or(complex64)); - } else { - std::ostringstream msg; - msg << "Cannot convert numpy array of type " << np_array.dtype() - << " to mlx array."; - throw std::invalid_argument(msg.str()); - } -} - -/////////////////////////////////////////////////////////////////////////////// -// Python Buffer Protocol (MLX -> Numpy) -/////////////////////////////////////////////////////////////////////////////// - -std::optional buffer_format(const array& a) { - // https://docs.python.org/3.10/library/struct.html#format-characters - switch (a.dtype()) { - case bool_: - return pybind11::format_descriptor::format(); - case uint8: - return pybind11::format_descriptor::format(); - case uint16: - return pybind11::format_descriptor::format(); - case uint32: - return pybind11::format_descriptor::format(); - case uint64: - return pybind11::format_descriptor::format(); - case int8: - return pybind11::format_descriptor::format(); - case int16: - return pybind11::format_descriptor::format(); - case int32: - return pybind11::format_descriptor::format(); - case int64: - return pybind11::format_descriptor::format(); - case float16: - // https://github.com/pybind/pybind11/issues/4998 - return "e"; - case float32: { - return pybind11::format_descriptor::format(); - } - case bfloat16: - // not supported by python buffer protocol or numpy. 
- // must be null according to - // https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT - // which implies 'B'. - return {}; - case complex64: - return pybind11::format_descriptor>::format(); - default: { - std::ostringstream os; - os << "bad dtype: " << a.dtype(); - throw std::runtime_error(os.str()); - } - } -} - -std::vector buffer_strides(const array& a) { - std::vector py_strides; - py_strides.reserve(a.strides().size()); - for (const size_t stride : a.strides()) { - py_strides.push_back(stride * a.itemsize()); - } - return py_strides; -} - -py::buffer_info buffer_info(array& a) { - // Eval if not already evaled - if (!a.is_evaled()) { - py::gil_scoped_release nogil; - a.eval(); - } - return pybind11::buffer_info( - a.data(), - a.itemsize(), - buffer_format(a).value_or("B"), // we use "B" because pybind uses a - // std::string which can't be null - a.ndim(), - a.shape(), - buffer_strides(a)); -} /////////////////////////////////////////////////////////////////////////////// // Module /////////////////////////////////////////////////////////////////////////////// -array create_array(array_init_type v, std::optional t) { - if (auto pv = std::get_if(&v); pv) { - return array(py::cast(*pv), t.value_or(bool_)); - } else if (auto pv = std::get_if(&v); pv) { - return array(py::cast(*pv), t.value_or(int32)); - } else if (auto pv = std::get_if(&v); pv) { - return array(py::cast(*pv), t.value_or(float32)); +array create_array(ArrayInitType v, std::optional t) { + if (auto pv = std::get_if(&v); pv) { + return array(nb::cast(*pv), t.value_or(bool_)); + } else if (auto pv = std::get_if(&v); pv) { + return array(nb::cast(*pv), t.value_or(int32)); + } else if (auto pv = std::get_if(&v); pv) { + return array(nb::cast(*pv), t.value_or(float32)); } else if (auto pv = std::get_if>(&v); pv) { return array(static_cast(*pv), t.value_or(complex64)); - } else if (auto pv = std::get_if(&v); pv) { + } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); - } else if (auto pv = std::get_if(&v); pv) { + } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); + } else if (auto pv = std::get_if< + nb::ndarray>(&v); + pv) { + return nd_array_to_mlx(*pv, t); } else if (auto pv = std::get_if(&v); pv) { return astype(*pv, t.value_or((*pv).dtype())); - } else if (auto pv = std::get_if(&v); pv) { - return np_array_to_mlx(*pv, t); - } else if (auto pv = std::get_if(&v); pv) { - return np_array_to_mlx(*pv, t); } else { - auto arr = to_array_with_accessor(std::get(v)); + auto arr = to_array_with_accessor(std::get(v)); return astype(arr, t.value_or(arr.dtype())); } } @@ -505,7 +341,7 @@ array create_array(array_init_type v, std::optional t) { class ArrayAt { public: ArrayAt(array x) : x_(std::move(x)) {} - ArrayAt& set_indices(py::object indices) { + ArrayAt& set_indices(nb::object indices) { indices_ = indices; return *this; } @@ -530,7 +366,7 @@ class ArrayAt { private: array x_; - py::object indices_; + nb::object indices_; }; class ArrayPythonIterator { @@ -543,7 +379,7 @@ class ArrayPythonIterator { array next() { if (idx_ >= x_.shape(0)) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } if (idx_ >= 0 && idx_ < splits_.size()) { @@ -559,12 +395,12 @@ class ArrayPythonIterator { std::vector splits_; }; -void init_array(py::module_& m) { +void init_array(nb::module_& m) { // Set Python print formatting options mlx::core::global_formatter.capitalize_bool = true; // Types - py::class_( + nb::class_( m, "Dtype", R"pbdoc( @@ -573,8 +409,7 @@ void 
init_array(py::module_& m) { See the :ref:`list of types ` for more details on available data types. )pbdoc") - .def_readonly( - "size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") + .def_ro("size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") .def( "__repr__", [](const Dtype& t) { @@ -583,67 +418,39 @@ void init_array(py::module_& m) { os << t; return os.str(); }) - .def("__eq__", [](const Dtype& t1, const Dtype& t2) { return t1 == t2; }) + .def( + "__eq__", + [](const Dtype& t, const nb::object& other) { + return nb::isinstance(other) && t == nb::cast(other); + }) .def("__hash__", [](const Dtype& t) { return static_cast(t.val); }); - m.attr("bool_") = py::cast(bool_); - m.attr("uint8") = py::cast(uint8); - m.attr("uint16") = py::cast(uint16); - m.attr("uint32") = py::cast(uint32); - m.attr("uint64") = py::cast(uint64); - m.attr("int8") = py::cast(int8); - m.attr("int16") = py::cast(int16); - m.attr("int32") = py::cast(int32); - m.attr("int64") = py::cast(int64); - m.attr("float16") = py::cast(float16); - m.attr("float32") = py::cast(float32); - m.attr("bfloat16") = py::cast(bfloat16); - m.attr("complex64") = py::cast(complex64); + m.attr("bool_") = nb::cast(bool_); + m.attr("uint8") = nb::cast(uint8); + m.attr("uint16") = nb::cast(uint16); + m.attr("uint32") = nb::cast(uint32); + m.attr("uint64") = nb::cast(uint64); + m.attr("int8") = nb::cast(int8); + m.attr("int16") = nb::cast(int16); + m.attr("int32") = nb::cast(int32); + m.attr("int64") = nb::cast(int64); + m.attr("float16") = nb::cast(float16); + m.attr("float32") = nb::cast(float32); + m.attr("bfloat16") = nb::cast(bfloat16); + m.attr("complex64") = nb::cast(complex64); - auto array_at_class = py::class_( + nb::class_( m, "_ArrayAt", R"pbdoc( A helper object to apply updates at specific indices. - )pbdoc"); - - auto array_iterator_class = py::class_( - m, - "_ArrayIterator", - R"pbdoc( - A helper object to iterate over the 1st dimension of an array. - )pbdoc"); - - auto array_class = py::class_( - m, - "array", - R"pbdoc(An N-dimensional array object.)pbdoc", - py::buffer_protocol()); - - { - py::options options; - options.disable_function_signatures(); - - array_class.def( - py::init([](array_init_type v, std::optional t) { - return create_array(v, t); - }), - "val"_a, - "dtype"_a = std::nullopt, - R"pbdoc( - __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None) - )pbdoc"); - } - - array_at_class + )pbdoc") .def( - py::init([](const array& x) { return ArrayAt(x); }), + nb::init(), "x"_a, - R"pbdoc( - __init__(self, x: array) - )pbdoc") - .def("__getitem__", &ArrayAt::set_indices, "indices"_a) + nb::sig("def __init__(self, x: array)")) + .def("__getitem__", &ArrayAt::set_indices, "indices"_a.none()) .def("add", &ArrayAt::add, "value"_a) .def("subtract", &ArrayAt::subtract, "value"_a) .def("multiply", &ArrayAt::multiply, "value"_a) @@ -651,40 +458,61 @@ void init_array(py::module_& m) { .def("maximum", &ArrayAt::maximum, "value"_a) .def("minimum", &ArrayAt::minimum, "value"_a); - array_iterator_class + nb::class_( + m, + "_ArrayIterator", + R"pbdoc( + A helper object to iterate over the 1st dimension of an array. 
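Note: the _ArrayIterator helper registered above backs Python-level iteration over an array's first dimension. A small usage sketch, with behavior assumed unchanged from the previous pybind11 bindings:

    import mlx.core as mx

    a = mx.array([[1, 2], [3, 4], [5, 6]])

    # __iter__ returns the _ArrayIterator helper; each step yields one row.
    for row in a:
        print(row)  # array([1, 2], dtype=int32), then [3, 4], then [5, 6]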
+ )pbdoc") .def( - py::init([](const array& x) { return ArrayPythonIterator(x); }), + nb::init(), "x"_a, - R"pbdoc( - __init__(self, x: array) - )pbdoc") + nb::sig("def __init__(self, x: array)")) .def("__next__", &ArrayPythonIterator::next) .def("__iter__", [](const ArrayPythonIterator& it) { return it; }); - array_class - .def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); }) - .def_property_readonly( + // Install buffer protocol functions + PyType_Slot array_slots[] = { + {Py_bf_getbuffer, (void*)getbuffer}, + {Py_bf_releasebuffer, (void*)releasebuffer}, + {0, nullptr}}; + + nb::class_( + m, + "array", + R"pbdoc(An N-dimensional array object.)pbdoc", + nb::type_slots(array_slots), + nb::is_weak_referenceable()) + .def( + "__init__", + [](array* aptr, ArrayInitType v, std::optional t) { + new (aptr) array(create_array(v, t)); + }, + "val"_a, + "dtype"_a = nb::none(), + nb::sig( + "def __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)")) + .def_prop_ro( "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc") - .def_property_readonly( - "ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc") - .def_property_readonly( + .def_prop_ro("ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc") + .def_prop_ro( "itemsize", &array::itemsize, R"pbdoc(The size of the array's datatype in bytes.)pbdoc") - .def_property_readonly( + .def_prop_ro( "nbytes", &array::nbytes, R"pbdoc(The number of bytes in the array.)pbdoc") - .def_property_readonly( + .def_prop_ro( "shape", - [](const array& a) { return py::tuple(py::cast(a.shape())); }, + [](const array& a) { return nb::tuple(nb::cast(a.shape())); }, R"pbdoc( The shape of the array as a Python tuple. Returns: tuple(int): A tuple containing the sizes of each dimension. )pbdoc") - .def_property_readonly( + .def_prop_ro( "dtype", &array::dtype, R"pbdoc( @@ -720,7 +548,7 @@ void init_array(py::module_& m) { "astype", &astype, "dtype"_a, - "stream"_a = none, + "stream"_a = nb::none(), R"pbdoc( Cast the array to a specified type. @@ -731,9 +559,9 @@ void init_array(py::module_& m) { Returns: array: The array with type ``dtype``. 
)pbdoc") - .def("__getitem__", mlx_get_item) - .def("__setitem__", mlx_set_item) - .def_property_readonly( + .def("__getitem__", mlx_get_item, nb::arg().none()) + .def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg()) + .def_prop_ro( "at", [](const array& a) { return ArrayAt(a); }, R"pbdoc( @@ -769,35 +597,34 @@ void init_array(py::module_& m) { "__len__", [](const array& a) { if (a.ndim() == 0) { - throw py::type_error("len() 0-dimensional array."); + throw nb::type_error("len() 0-dimensional array."); } return a.shape(0); }) .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) - .def(py::pickle( - [](array& a) { // __getstate__ + .def( + "__getstate__", + [](const array& a) { if (a.dtype() == bfloat16) { - throw std::runtime_error( - "[array.__getstate__] Not supported for bfloat16."); } - return py::array(buffer_info(a)); - }, - [](py::array npa) { // __setstate__ - if (not py::isinstance(npa)) { - throw std::runtime_error( - "[array.__setstate__] Received invalid state."); - } - return np_array_to_mlx(npa, std::nullopt); - })) + return mlx_to_np_array(a); + }) + .def( + "__setstate__", + [](array& arr, + const nb::ndarray& state) { + new (&arr) array(nd_array_to_mlx(state, std::nullopt)); + }) .def("__copy__", [](const array& self) { return array(self); }) .def( "__deepcopy__", - [](const array& self, py::dict) { return array(self); }, + [](const array& self, nb::dict) { return array(self); }, "memo"_a) .def( "__add__", [](const array& a, const ScalarOrArray v) { - return add(a, to_array(v, a.dtype())); + auto b = to_array(v, a.dtype()); + return add(a, b); }, "other"_a) .def( @@ -806,7 +633,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(add(a, to_array(v, a.dtype()))); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__radd__", [](const array& a, const ScalarOrArray v) { @@ -825,7 +653,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(subtract(a, to_array(v, a.dtype()))); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__rsub__", [](const array& a, const ScalarOrArray v) { @@ -844,7 +673,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(multiply(a, to_array(v, a.dtype()))); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__rmul__", [](const array& a, const ScalarOrArray v) { @@ -867,7 +697,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(divide(a, to_array(v, a.dtype()))); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__rtruediv__", [](const array& a, const ScalarOrArray v) { @@ -898,7 +729,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype()))); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__rfloordiv__", [](const array& a, const ScalarOrArray v) { @@ -918,7 +750,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(remainder(a, to_array(v, a.dtype()))); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__rmod__", [](const array& a, const ScalarOrArray v) { @@ -962,12 +795,12 @@ void init_array(py::module_& m) { }, "other"_a) .def("__neg__", [](const array& a) { return -a; }) - .def("__bool__", [](array& a) { return py::bool_(to_scalar(a)); }) + .def("__bool__", [](array& a) { return nb::bool_(to_scalar(a)); }) .def( "__repr__", [](array& a) { if (!a.is_evaled()) { - py::gil_scoped_release nogil; + nb::gil_scoped_release nogil; a.eval(); } std::ostringstream os; @@ -984,7 +817,8 @@ void init_array(py::module_& m) { 
a.overwrite_descriptor(matmul(a, other)); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__pow__", [](const array& a, const ScalarOrArray v) { @@ -1003,7 +837,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(power(a, to_array(v, a.dtype()))); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__invert__", [](const array& a) { @@ -1047,7 +882,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(logical_and(a, b)); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "__or__", [](const array& a, const ScalarOrArray v) { @@ -1078,7 +914,8 @@ void init_array(py::module_& m) { a.overwrite_descriptor(logical_or(a, b)); return a; }, - "other"_a) + "other"_a, + nb::rv_policy::none) .def( "flatten", [](const array& a, @@ -1089,27 +926,27 @@ void init_array(py::module_& m) { }, "start_axis"_a = 0, "end_axis"_a = -1, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), R"pbdoc( See :func:`flatten`. )pbdoc") .def( "reshape", - [](const array& a, py::args shape, StreamOrDevice s) { - if (shape.size() == 1) { - py::object arg = shape[0]; - if (!py::isinstance(arg)) { - return reshape(a, py::cast>(arg), s); - } + [](const array& a, nb::args shape_, StreamOrDevice s) { + std::vector shape; + if (!nb::isinstance(shape_[0])) { + shape = nb::cast>(shape_[0]); + } else { + shape = nb::cast>(shape_); } - return reshape(a, py::cast>(shape), s); + return reshape(a, shape, s); }, - py::kw_only(), - "stream"_a = none, + "shape"_a, + "stream"_a = nb::none(), R"pbdoc( Equivalent to :func:`reshape` but the shape can be passed either as a - tuple or as separate arguments. + :obj:`tuple` or as separate arguments. See :func:`reshape` for full documentation. )pbdoc") @@ -1124,85 +961,85 @@ void init_array(py::module_& m) { return squeeze(a, std::get>(v), s); } }, - "axis"_a = none, - py::kw_only(), - "stream"_a = none, + "axis"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), R"pbdoc( See :func:`squeeze`. 
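Note: the reshape rebinding above accepts the target shape either as a single tuple or as separate integer arguments. A quick sketch (mx.arange is assumed from the existing mlx.core API):

    import mlx.core as mx

    a = mx.arange(6)
    print(a.reshape(2, 3).shape)    # (2, 3) -- shape passed as separate arguments
    print(a.reshape((3, 2)).shape)  # (3, 2) -- a single tuple works as well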
)pbdoc") .def( "abs", &mlx::core::abs, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`abs`.") .def( "__abs__", [](const array& a) { return abs(a); }, "See :func:`abs`.") .def( "square", &square, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`square`.") .def( "sqrt", &mlx::core::sqrt, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`sqrt`.") .def( "rsqrt", &rsqrt, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`rsqrt`.") .def( "reciprocal", &reciprocal, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`reciprocal`.") .def( "exp", &mlx::core::exp, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`exp`.") .def( "log", &mlx::core::log, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`log`.") .def( "log2", &mlx::core::log2, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`log2`.") .def( "log10", &mlx::core::log10, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`log10`.") .def( "sin", &mlx::core::sin, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`sin`.") .def( "cos", &mlx::core::cos, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`cos`.") .def( "log1p", &mlx::core::log1p, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`log1p`.") .def( "all", @@ -1212,10 +1049,10 @@ void init_array(py::module_& m) { StreamOrDevice s) { return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`all`.") .def( "any", @@ -1225,51 +1062,50 @@ void init_array(py::module_& m) { StreamOrDevice s) { return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`any`.") .def( "moveaxis", &moveaxis, "source"_a, "destination"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`moveaxis`.") .def( "swapaxes", &swapaxes, "axis1"_a, "axis2"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`swapaxes`.") .def( "transpose", - [](const array& a, py::args axes, StreamOrDevice s) { - if (axes.size() > 0) { - if (axes.size() == 1) { - py::object arg = axes[0]; - if (!py::isinstance(arg)) { - return transpose(a, py::cast>(arg), s); - } - } - return transpose(a, py::cast>(axes), s); - } else { + [](const array& a, nb::args axes_, StreamOrDevice s) { + if (axes_.size() == 0) { return transpose(a, s); } + std::vector axes; + if (!nb::isinstance(axes_[0])) { + axes = nb::cast>(axes_[0]); + } else { + axes = nb::cast>(axes_); + } + return transpose(a, axes, s); }, - py::kw_only(), - "stream"_a = none, + "axes"_a, + "stream"_a = nb::none(), R"pbdoc( Equivalent to :func:`transpose` but the axes can be passed either as a tuple or as separate arguments. See :func:`transpose` for full documentation. 
)pbdoc") - .def_property_readonly( + .def_prop_ro( "T", [](const array& a) { return transpose(a); }, "Equivalent to calling ``self.transpose()`` with no arguments.") @@ -1281,10 +1117,10 @@ void init_array(py::module_& m) { StreamOrDevice s) { return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`sum`.") .def( "prod", @@ -1294,10 +1130,10 @@ void init_array(py::module_& m) { StreamOrDevice s) { return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`prod`.") .def( "min", @@ -1307,10 +1143,10 @@ void init_array(py::module_& m) { StreamOrDevice s) { return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`min`.") .def( "max", @@ -1320,10 +1156,10 @@ void init_array(py::module_& m) { StreamOrDevice s) { return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`max`.") .def( "logsumexp", @@ -1333,10 +1169,10 @@ void init_array(py::module_& m) { StreamOrDevice s) { return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`logsumexp`.") .def( "mean", @@ -1346,10 +1182,10 @@ void init_array(py::module_& m) { StreamOrDevice s) { return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`mean`.") .def( "var", @@ -1360,11 +1196,11 @@ void init_array(py::module_& m) { StreamOrDevice s) { return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, "ddof"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`var`.") .def( "split", @@ -1381,8 +1217,8 @@ void init_array(py::module_& m) { }, "indices_or_sections"_a, "axis"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`split`.") .def( "argmin", @@ -1398,8 +1234,8 @@ void init_array(py::module_& m) { }, "axis"_a = std::nullopt, "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`argmin`.") .def( "argmax", @@ -1413,10 +1249,10 @@ void init_array(py::module_& m) { return argmax(a, keepdims, s); } }, - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`argmax`.") .def( "cumsum", @@ -1433,11 +1269,11 @@ void init_array(py::module_& m) { return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "axis"_a = none, - py::kw_only(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), "See :func:`cumsum`.") .def( "cumprod", @@ -1454,11 +1290,11 
@@ void init_array(py::module_& m) { return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "axis"_a = none, - py::kw_only(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), "See :func:`cumprod`.") .def( "cummax", @@ -1475,11 +1311,11 @@ void init_array(py::module_& m) { return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "axis"_a = none, - py::kw_only(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), "See :func:`cummax`.") .def( "cummin", @@ -1496,21 +1332,20 @@ void init_array(py::module_& m) { return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "axis"_a = none, - py::kw_only(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), "See :func:`cummin`.") .def( "round", [](const array& a, int decimals, StreamOrDevice s) { return round(a, decimals, s); }, - py::pos_only(), "decimals"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), "See :func:`round`.") .def( "diagonal", @@ -1522,14 +1357,14 @@ void init_array(py::module_& m) { "offset"_a = 0, "axis1"_a = 0, "axis2"_a = 1, - "stream"_a = none, + "stream"_a = nb::none(), "See :func:`diagonal`.") .def( "diag", [](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); }, "k"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), R"pbdoc( Extract a diagonal or construct a diagonal matrix. )pbdoc"); diff --git a/python/src/buffer.h b/python/src/buffer.h new file mode 100644 index 000000000..2118e7450 --- /dev/null +++ b/python/src/buffer.h @@ -0,0 +1,122 @@ +// Copyright © 2024 Apple Inc. 
+#pragma once +#include + +#include + +#include "mlx/array.h" +#include "mlx/utils.h" + +// Only defined in >= Python 3.9 +// https://github.com/python/cpython/blob/f6cdc6b4a191b75027de342aa8b5d344fb31313e/Include/typeslots.h#L2-L3 +#ifndef Py_bf_getbuffer +#define Py_bf_getbuffer 1 +#define Py_bf_releasebuffer 2 +#endif + +namespace nb = nanobind; +using namespace mlx::core; + +std::string buffer_format(const array& a) { + // https://docs.python.org/3.10/library/struct.html#format-characters + switch (a.dtype()) { + case bool_: + return "?"; + case uint8: + return "B"; + case uint16: + return "H"; + case uint32: + return "I"; + case uint64: + return "Q"; + case int8: + return "b"; + case int16: + return "h"; + case int32: + return "i"; + case int64: + return "q"; + case float16: + return "e"; + case float32: + return "f"; + case bfloat16: + return "B"; + case complex64: + return "Zf\0"; + default: { + std::ostringstream os; + os << "bad dtype: " << a.dtype(); + throw std::runtime_error(os.str()); + } + } +} + +struct buffer_info { + std::string format; + std::vector shape; + std::vector strides; + + buffer_info( + const std::string& format, + std::vector shape_in, + std::vector strides_in) + : format(format), + shape(std::move(shape_in)), + strides(std::move(strides_in)) {} + + buffer_info(const buffer_info&) = delete; + buffer_info& operator=(const buffer_info&) = delete; + + buffer_info(buffer_info&& other) noexcept { + (*this) = std::move(other); + } + + buffer_info& operator=(buffer_info&& rhs) noexcept { + format = std::move(rhs.format); + shape = std::move(rhs.shape); + strides = std::move(rhs.strides); + return *this; + } +}; + +extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { + std::memset(view, 0, sizeof(Py_buffer)); + auto a = nb::cast(nb::handle(obj)); + + if (!a.is_evaled()) { + nb::gil_scoped_release nogil; + a.eval(); + } + + std::vector shape(a.shape().begin(), a.shape().end()); + std::vector strides(a.strides().begin(), a.strides().end()); + for (auto& s : strides) { + s *= a.itemsize(); + } + buffer_info* info = + new buffer_info(buffer_format(a), std::move(shape), std::move(strides)); + + view->obj = obj; + view->ndim = a.ndim(); + view->internal = info; + view->buf = a.data(); + view->itemsize = a.itemsize(); + view->len = a.size(); + view->readonly = false; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(info->format.c_str()); + } + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->strides = info->strides.data(); + view->shape = info->shape.data(); + } + Py_INCREF(view->obj); + return 0; +} + +extern "C" inline void releasebuffer(PyObject*, Py_buffer* view) { + delete (buffer_info*)view->internal; +} diff --git a/python/src/constants.cpp b/python/src/constants.cpp index 94658b586..17e9ccd69 100644 --- a/python/src/constants.cpp +++ b/python/src/constants.cpp @@ -1,11 +1,11 @@ -// init_constants.cpp +// Copyright © 2023-2024 Apple Inc. 
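Note: the Py_bf_getbuffer/Py_bf_releasebuffer slots defined in python/src/buffer.h above expose mlx arrays through the standard Python buffer protocol, so memoryview and NumPy can read them directly. A sketch of the expected behavior, assuming a build from this branch:

    import numpy as np
    import mlx.core as mx

    a = mx.array([1.0, 2.0, 3.0])

    view = memoryview(a)            # served by the getbuffer slot above
    print(view.format, view.shape)  # 'f' (3,)

    b = np.array(a)                 # NumPy consumes the same buffer
    print(b.dtype, b)               # float32 [1. 2. 3.]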
-#include +#include #include -namespace py = pybind11; +namespace nb = nanobind; -void init_constants(py::module_& m) { +void init_constants(nb::module_& m) { m.attr("Inf") = std::numeric_limits::infinity(); m.attr("Infinity") = std::numeric_limits::infinity(); m.attr("NAN") = NAN; @@ -19,6 +19,6 @@ void init_constants(py::module_& m) { m.attr("inf") = std::numeric_limits::infinity(); m.attr("infty") = std::numeric_limits::infinity(); m.attr("nan") = NAN; - m.attr("newaxis") = pybind11::none(); + m.attr("newaxis") = nb::none(); m.attr("pi") = 3.1415926535897932384626433; -} \ No newline at end of file +} diff --git a/python/src/convert.cpp b/python/src/convert.cpp new file mode 100644 index 000000000..740dbb664 --- /dev/null +++ b/python/src/convert.cpp @@ -0,0 +1,155 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "python/src/convert.h" + +namespace nanobind { +template <> +struct ndarray_traits { + static constexpr bool is_complex = false; + static constexpr bool is_float = true; + static constexpr bool is_bool = false; + static constexpr bool is_int = false; + static constexpr bool is_signed = true; +}; + +template <> +struct ndarray_traits { + static constexpr bool is_complex = false; + static constexpr bool is_float = true; + static constexpr bool is_bool = false; + static constexpr bool is_int = false; + static constexpr bool is_signed = true; +}; + +static constexpr dlpack::dtype bfloat16{4, 16, 1}; +}; // namespace nanobind + +template +array nd_array_to_mlx_contiguous( + nb::ndarray nd_array, + const std::vector& shape, + Dtype dtype) { + // Make a copy of the numpy buffer + // Get buffer ptr pass to array constructor + auto data_ptr = nd_array.data(); + return array(static_cast(data_ptr), shape, dtype); +} + +array nd_array_to_mlx( + nb::ndarray nd_array, + std::optional dtype) { + // Compute the shape and size + std::vector shape; + for (int i = 0; i < nd_array.ndim(); i++) { + shape.push_back(nd_array.shape(i)); + } + auto type = nd_array.dtype(); + + // Copy data and make array + if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(bool_)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(uint8)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(uint16)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(uint32)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(uint64)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(int8)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(int16)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(int32)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(int64)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(float16)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(bfloat16)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(float32)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(float32)); + } 
else if (type == nb::dtype>()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(complex64)); + } else if (type == nb::dtype>()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(complex64)); + } else { + throw std::invalid_argument("Cannot convert numpy array to mlx array."); + } +} + +template +nb::ndarray mlx_to_nd_array( + array a, + std::optional t = {}) { + // Eval if not already evaled + if (!a.is_evaled()) { + nb::gil_scoped_release nogil; + a.eval(); + } + std::vector shape(a.shape().begin(), a.shape().end()); + std::vector strides(a.strides().begin(), a.strides().end()); + return nb::ndarray( + a.data(), + a.ndim(), + shape.data(), + nb::handle(), + strides.data(), + t.value_or(nb::dtype())); +} + +template +nb::ndarray mlx_to_nd_array(const array& a) { + switch (a.dtype()) { + case bool_: + return mlx_to_nd_array(a); + case uint8: + return mlx_to_nd_array(a); + case uint16: + return mlx_to_nd_array(a); + case uint32: + return mlx_to_nd_array(a); + case uint64: + return mlx_to_nd_array(a); + case int8: + return mlx_to_nd_array(a); + case int16: + return mlx_to_nd_array(a); + case int32: + return mlx_to_nd_array(a); + case int64: + return mlx_to_nd_array(a); + case float16: + return mlx_to_nd_array(a); + case bfloat16: + return mlx_to_nd_array(a, nb::bfloat16); + case float32: + return mlx_to_nd_array(a); + case complex64: + return mlx_to_nd_array>(a); + } +} + +nb::ndarray mlx_to_np_array(const array& a) { + return mlx_to_nd_array(a); +} diff --git a/python/src/convert.h b/python/src/convert.h new file mode 100644 index 000000000..36e868a7d --- /dev/null +++ b/python/src/convert.h @@ -0,0 +1,16 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include +#include + +#include "mlx/array.h" + +namespace nb = nanobind; +using namespace mlx::core; + +array nd_array_to_mlx( + nb::ndarray nd_array, + std::optional dtype); +nb::ndarray mlx_to_np_array(const array& a); diff --git a/python/src/device.cpp b/python/src/device.cpp index c88144520..1d0c38b74 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -1,32 +1,34 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
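Note: nd_array_to_mlx above maps each incoming ndarray dtype to the matching MLX dtype, narrowing float64 to float32 and complex128 to complex64. A small round-trip sketch with illustrative values:

    import numpy as np
    import mlx.core as mx

    x = np.ones((2, 2), dtype=np.float16)
    a = mx.array(x)                 # dispatched through nd_array_to_mlx
    print(a.dtype)                  # mlx.core.float16 -- the NumPy dtype is preserved

    y = np.ones(3, dtype=np.float64)
    print(mx.array(y).dtype)        # mlx.core.float32 -- float64 is narrowed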
#include -#include +#include +#include #include "mlx/device.h" #include "mlx/utils.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; -void init_device(py::module_& m) { - auto device_class = py::class_( +void init_device(nb::module_& m) { + auto device_class = nb::class_( m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); - py::enum_(m, "DeviceType") + nb::enum_(m, "DeviceType") .value("cpu", Device::DeviceType::cpu) .value("gpu", Device::DeviceType::gpu) .export_values() - .def( - "__eq__", - [](const Device::DeviceType& d1, const Device& d2) { - return d1 == d2; - }, - py::prepend()); + .def("__eq__", [](const Device::DeviceType& d, const nb::object& other) { + if (!nb::isinstance(other) && + !nb::isinstance(other)) { + return false; + } + return d == nb::cast(other); + }); - device_class.def(py::init(), "type"_a, "index"_a = 0) - .def_readonly("type", &Device::type) + device_class.def(nb::init(), "type"_a, "index"_a = 0) + .def_ro("type", &Device::type) .def( "__repr__", [](const Device& d) { @@ -34,11 +36,15 @@ void init_device(py::module_& m) { os << d; return os.str(); }) - .def("__eq__", [](const Device& d1, const Device& d2) { - return d1 == d2; + .def("__eq__", [](const Device& d, const nb::object& other) { + if (!nb::isinstance(other) && + !nb::isinstance(other)) { + return false; + } + return d == nb::cast(other); }); - py::implicitly_convertible(); + nb::implicitly_convertible(); m.def( "default_device", diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a667e8e4..20f0f7033 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -1,20 +1,17 @@ // Copyright © 2023-2024 Apple Inc. -#include -#include +#include +#include +#include #include "mlx/fast.h" #include "mlx/ops.h" -#include "python/src/utils.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; -void init_extensions(py::module_& parent_module) { - py::options options; - options.disable_function_signatures(); - +void init_fast(nb::module_& parent_module) { auto m = parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); @@ -31,15 +28,15 @@ void init_extensions(py::module_& parent_module) { }, "a"_a, "dims"_a, - py::kw_only(), + nb::kw_only(), "traditional"_a, "base"_a, "scale"_a, "offset"_a, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array - Apply rotary positional encoding to the input. Args: @@ -70,30 +67,34 @@ void init_extensions(py::module_& parent_module) { "q"_a, "k"_a, "v"_a, - py::kw_only(), + nb::kw_only(), "scale"_a, - "mask"_a = none, - "stream"_a = none, + "mask"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array + A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. 
- A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V. - Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). + Supports: + * [Multi-Head Attention](https://arxiv.org/abs/1706.03762) + * [Grouped Query Attention](https://arxiv.org/abs/2305.13245) + * [Multi-Query Attention](https://arxiv.org/abs/1911.02150). - This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. + Note: The softmax operation is performed in ``float32`` regardless of + input precision. - Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). - Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. + Note: For Grouped Query Attention and Multi-Query Attention, the ``k`` + and ``v`` inputs should not be pre-tiled to match ``q``. - Args: - q (array): Input query array. - k (array): Input keys array. - v (array): Input values array. - scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (array, optional): An additive mask to apply to the query-key scores. + Args: + q (array): Input query array. + k (array): Input keys array. + v (array): Input values array. + scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) + mask (array, optional): An additive mask to apply to the query-key scores. - Returns: - array: The output array. - - )pbdoc"); + Returns: + array: The output array. + )pbdoc"); } diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 42ad37633..3b1007fe2 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -1,19 +1,20 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. -#include -#include - -#include "python/src/utils.h" +#include +#include +#include +#include +#include #include "mlx/fft.h" #include "mlx/ops.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; -void init_fft(py::module_& parent_module) { +void init_fft(nb::module_& parent_module) { auto m = parent_module.def_submodule( "fft", "mlx.core.fft: Fast Fourier Transforms."); m.def( @@ -29,9 +30,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "n"_a = none, + "n"_a = nb::none(), "axis"_a = -1, - "stream"_a = none, + "stream"_a = nb::none(), R"pbdoc( One dimensional discrete Fourier Transform. @@ -59,9 +60,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "n"_a = none, + "n"_a = nb::none(), "axis"_a = -1, - "stream"_a = none, + "stream"_a = nb::none(), R"pbdoc( One dimensional inverse discrete Fourier Transform. @@ -95,9 +96,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = std::vector{-2, -1}, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a.none() = std::vector{-2, -1}, + "stream"_a = nb::none(), R"pbdoc( Two dimensional discrete Fourier Transform. @@ -132,9 +133,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = std::vector{-2, -1}, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a.none() = std::vector{-2, -1}, + "stream"_a = nb::none(), R"pbdoc( Two dimensional inverse discrete Fourier Transform. 
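Note: a usage sketch for the mlx.core.fast.scaled_dot_product_attention binding documented above. The scale follows the docstring (typically 1/sqrt(head_dim)); the shapes and the use of mx.random.normal are illustrative assumptions:

    import math
    import mlx.core as mx

    B, H, L, D = 1, 8, 128, 64                 # batch, heads, sequence, head_dim
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))

    o = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=1.0 / math.sqrt(D))     # softmax runs in float32 internally
    print(o.shape)                             # (1, 8, 128, 64)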
@@ -169,9 +170,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = none, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( n-dimensional discrete Fourier Transform. @@ -207,9 +208,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = none, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( n-dimensional inverse discrete Fourier Transform. @@ -239,9 +240,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "n"_a = none, + "n"_a = nb::none(), "axis"_a = -1, - "stream"_a = none, + "stream"_a = nb::none(), R"pbdoc( One dimensional discrete Fourier Transform on a real input. @@ -274,9 +275,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "n"_a = none, + "n"_a = nb::none(), "axis"_a = -1, - "stream"_a = none, + "stream"_a = nb::none(), R"pbdoc( The inverse of :func:`rfft`. @@ -314,9 +315,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = std::vector{-2, -1}, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a.none() = std::vector{-2, -1}, + "stream"_a = nb::none(), R"pbdoc( Two dimensional real discrete Fourier Transform. @@ -357,9 +358,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = std::vector{-2, -1}, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a.none() = std::vector{-2, -1}, + "stream"_a = nb::none(), R"pbdoc( The inverse of :func:`rfft2`. @@ -400,9 +401,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = none, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( n-dimensional real discrete Fourier Transform. @@ -443,9 +444,9 @@ void init_fft(py::module_& parent_module) { } }, "a"_a, - "s"_a = none, - "axes"_a = none, - "stream"_a = none, + "s"_a = nb::none(), + "axes"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( The inverse of :func:`rfftn`. diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 74fb6b695..a0682afed 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -7,19 +7,19 @@ #include "mlx/ops.h" -bool is_none_slice(const py::slice& in_slice) { +bool is_none_slice(const nb::slice& in_slice) { return ( - py::getattr(in_slice, "start").is_none() && - py::getattr(in_slice, "stop").is_none() && - py::getattr(in_slice, "step").is_none()); + nb::getattr(in_slice, "start").is_none() && + nb::getattr(in_slice, "stop").is_none() && + nb::getattr(in_slice, "step").is_none()); } -int get_slice_int(py::object obj, int default_val) { +int get_slice_int(nb::object obj, int default_val) { if (!obj.is_none()) { - if (!py::isinstance(obj)) { + if (!nb::isinstance(obj)) { throw std::invalid_argument("Slice indices must be integers or None."); } - return py::cast(py::cast(obj)); + return nb::cast(nb::cast(obj)); } return default_val; } @@ -28,7 +28,7 @@ void get_slice_params( int& starts, int& ends, int& strides, - const py::slice& in_slice, + const nb::slice& in_slice, int axis_size) { // Following numpy's convention // Assume n is the number of elements in the dimension being sliced. @@ -36,26 +36,26 @@ void get_slice_params( // k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for // k < 0 . 
If k is not given it defaults to 1 - strides = get_slice_int(py::getattr(in_slice, "step"), 1); + strides = get_slice_int(nb::getattr(in_slice, "step"), 1); starts = get_slice_int( - py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0); + nb::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0); ends = get_slice_int( - py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); + nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); } -array get_int_index(py::object idx, int axis_size) { - int idx_ = py::cast(idx); +array get_int_index(nb::object idx, int axis_size) { + int idx_ = nb::cast(idx); idx_ = (idx_ < 0) ? idx_ + axis_size : idx_; return array(idx_, uint32); } -bool is_valid_index_type(const py::object& obj) { - return py::isinstance(obj) || py::isinstance(obj) || - py::isinstance(obj) || obj.is_none() || py::ellipsis().is(obj); +bool is_valid_index_type(const nb::object& obj) { + return nb::isinstance(obj) || nb::isinstance(obj) || + nb::isinstance(obj) || obj.is_none() || nb::ellipsis().is(obj); } -array mlx_get_item_slice(const array& src, const py::slice& in_slice) { +array mlx_get_item_slice(const array& src, const nb::slice& in_slice) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( @@ -92,7 +92,7 @@ array mlx_get_item_array(const array& src, const array& indices) { return take(src, indices, 0); } -array mlx_get_item_int(const array& src, const py::int_& idx) { +array mlx_get_item_int(const array& src, const nb::int_& idx) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( @@ -106,7 +106,7 @@ array mlx_get_item_int(const array& src, const py::int_& idx) { array mlx_gather_nd( array src, - const std::vector& indices, + const std::vector& indices, bool gather_first, int& max_dims) { max_dims = 0; @@ -117,9 +117,10 @@ array mlx_gather_nd( for (int i = 0; i < indices.size(); i++) { auto& idx = indices[i]; - if (py::isinstance(idx)) { + if (nb::isinstance(idx)) { int start, end, stride; - get_slice_params(start, end, stride, idx, src.shape(i)); + get_slice_params( + start, end, stride, nb::cast(idx), src.shape(i)); // Handle negative indices start = (start < 0) ? start + src.shape(i) : start; @@ -128,10 +129,10 @@ array mlx_gather_nd( gather_indices.push_back(arange(start, end, stride, uint32)); num_slices++; is_slice[i] = true; - } else if (py::isinstance(idx)) { + } else if (nb::isinstance(idx)) { gather_indices.push_back(get_int_index(idx, src.shape(i))); - } else if (py::isinstance(idx)) { - auto arr = py::cast(idx); + } else if (nb::isinstance(idx)) { + auto arr = nb::cast(idx); max_dims = std::max(static_cast(arr.ndim()), max_dims); gather_indices.push_back(arr); } @@ -185,7 +186,7 @@ array mlx_gather_nd( return src; } -array mlx_get_item_nd(array src, const py::tuple& entries) { +array mlx_get_item_nd(array src, const nb::tuple& entries) { // No indices make this a noop if (entries.size() == 0) { return src; @@ -197,11 +198,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { // 3. 
Calculate the remaining slices and reshapes // Ellipsis handling - std::vector indices; + std::vector indices; { int non_none_indices_before = 0; int non_none_indices_after = 0; - std::vector r_indices; + std::vector r_indices; int i = 0; for (; i < entries.size(); i++) { auto idx = entries[i]; @@ -209,7 +210,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { throw std::invalid_argument( "Cannot index mlx array using the given type yet"); } - if (!py::ellipsis().is(idx)) { + if (!nb::ellipsis().is(idx)) { indices.push_back(idx); non_none_indices_before += !idx.is_none(); } else { @@ -222,7 +223,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { throw std::invalid_argument( "Cannot index mlx array using the given type yet"); } - if (py::ellipsis().is(idx)) { + if (nb::ellipsis().is(idx)) { throw std::invalid_argument( "An index can only have a single ellipsis (...)"); } @@ -232,7 +233,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { for (int axis = non_none_indices_before; axis < src.ndim() - non_none_indices_after; axis++) { - indices.push_back(py::slice(0, src.shape(axis), 1)); + indices.push_back(nb::slice(0, src.shape(axis), 1)); } indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend()); } @@ -256,7 +257,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { // // Check whether we have arrays or integer indices and delegate to gather_nd // after removing the slices at the end and all Nones. - std::vector remaining_indices; + std::vector remaining_indices; bool have_array = false; { // First check whether the results of gather are going to be 1st or @@ -264,7 +265,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { bool have_non_array = false; bool gather_first = false; for (auto& idx : indices) { - if (py::isinstance(idx) || py::isinstance(idx)) { + if (nb::isinstance(idx) || nb::isinstance(idx)) { if (have_array && have_non_array) { gather_first = true; break; @@ -280,12 +281,12 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { // Then find the last array for (last_array = indices.size() - 1; last_array >= 0; last_array--) { auto& idx = indices[last_array]; - if (py::isinstance(idx) || py::isinstance(idx)) { + if (nb::isinstance(idx) || nb::isinstance(idx)) { break; } } - std::vector gather_indices; + std::vector gather_indices; for (int i = 0; i <= last_array; i++) { auto& idx = indices[i]; if (!idx.is_none()) { @@ -299,15 +300,15 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { if (gather_first) { for (int i = 0; i < max_dims; i++) { remaining_indices.push_back( - py::slice(py::none(), py::none(), py::none())); + nb::slice(nb::none(), nb::none(), nb::none())); } for (int i = 0; i < last_array; i++) { auto& idx = indices[i]; if (idx.is_none()) { remaining_indices.push_back(indices[i]); - } else if (py::isinstance(idx)) { + } else if (nb::isinstance(idx)) { remaining_indices.push_back( - py::slice(py::none(), py::none(), py::none())); + nb::slice(nb::none(), nb::none(), nb::none())); } } for (int i = last_array + 1; i < indices.size(); i++) { @@ -316,18 +317,18 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { } else { for (int i = 0; i < indices.size(); i++) { auto& idx = indices[i]; - if (py::isinstance(idx) || py::isinstance(idx)) { + if (nb::isinstance(idx) || nb::isinstance(idx)) { break; } else if (idx.is_none()) { remaining_indices.push_back(idx); } else { remaining_indices.push_back( - py::slice(py::none(), py::none(), py::none())); + 
nb::slice(nb::none(), nb::none(), nb::none())); } } for (int i = 0; i < max_dims; i++) { remaining_indices.push_back( - py::slice(py::none(), py::none(), py::none())); + nb::slice(nb::none(), nb::none(), nb::none())); } for (int i = last_array + 1; i < indices.size(); i++) { remaining_indices.push_back(indices[i]); @@ -351,7 +352,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { for (auto& idx : remaining_indices) { if (!idx.is_none()) { get_slice_params( - starts[axis], ends[axis], strides[axis], idx, ends[axis]); + starts[axis], + ends[axis], + strides[axis], + nb::cast(idx), + ends[axis]); axis++; } } @@ -375,15 +380,17 @@ array mlx_get_item_nd(array src, const py::tuple& entries) { return src; } -array mlx_get_item(const array& src, const py::object& obj) { - if (py::isinstance(obj)) { - return mlx_get_item_slice(src, obj); - } else if (py::isinstance(obj)) { - return mlx_get_item_array(src, py::cast(obj)); - } else if (py::isinstance(obj)) { - return mlx_get_item_int(src, obj); - } else if (py::isinstance(obj)) { - return mlx_get_item_nd(src, obj); +array mlx_get_item(const array& src, const nb::object& obj) { + if (nb::isinstance(obj)) { + return mlx_get_item_slice(src, nb::cast(obj)); + } else if (nb::isinstance(obj)) { + return mlx_get_item_array(src, nb::cast(obj)); + } else if (nb::isinstance(obj)) { + return mlx_get_item_int(src, nb::cast(obj)); + } else if (nb::isinstance(obj)) { + return mlx_get_item_nd(src, nb::cast(obj)); + } else if (nb::isinstance(obj)) { + return src; } else if (obj.is_none()) { std::vector s(1, 1); s.insert(s.end(), src.shape().begin(), src.shape().end()); @@ -394,7 +401,7 @@ array mlx_get_item(const array& src, const py::object& obj) { std::tuple, array, std::vector> mlx_scatter_args_int( const array& src, - const py::int_& idx, + const nb::int_& idx, const array& update) { if (src.ndim() == 0) { throw std::invalid_argument( @@ -446,7 +453,7 @@ std::tuple, array, std::vector> mlx_scatter_args_array( std::tuple, array, std::vector> mlx_scatter_args_slice( const array& src, - const py::slice& in_slice, + const nb::slice& in_slice, const array& update) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { @@ -478,9 +485,9 @@ std::tuple, array, std::vector> mlx_scatter_args_slice( std::tuple, array, std::vector> mlx_scatter_args_nd( const array& src, - const py::tuple& entries, + const nb::tuple& entries, const array& update) { - std::vector indices; + std::vector indices; int non_none_indices = 0; // Expand ellipses into a series of ':' slices @@ -494,7 +501,7 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( if (!is_valid_index_type(idx)) { throw std::invalid_argument( "Cannot index mlx array using the given type yet"); - } else if (!py::ellipsis().is(idx)) { + } else if (!nb::ellipsis().is(idx)) { if (!has_ellipsis) { indices_before++; non_none_indices_before += !idx.is_none(); @@ -514,7 +521,7 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( axis < src.ndim() - non_none_indices_after; axis++) { indices.insert( - indices.begin() + indices_before, py::slice(0, src.shape(axis), 1)); + indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1)); } non_none_indices = src.ndim(); } else { @@ -549,15 +556,15 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( bool have_array = false; bool have_non_array = false; for (auto& idx : indices) { - if (py::isinstance(idx) || idx.is_none()) { + if (nb::isinstance(idx) || idx.is_none()) { have_non_array = have_array; num_slices++; - } else if 
(py::isinstance(idx)) { + } else if (nb::isinstance(idx)) { have_array = true; if (have_array && have_non_array) { arrays_first = true; } - max_dim = std::max(py::cast(idx).ndim(), max_dim); + max_dim = std::max(nb::cast(idx).ndim(), max_dim); num_arrays++; } } @@ -569,10 +576,11 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( int ax = 0; for (int i = 0; i < indices.size(); ++i) { auto& pyidx = indices[i]; - if (py::isinstance(pyidx)) { + if (nb::isinstance(pyidx)) { int start, end, stride; auto axis_size = src.shape(ax++); - get_slice_params(start, end, stride, pyidx, axis_size); + get_slice_params( + start, end, stride, nb::cast(pyidx), axis_size); // Handle negative indices start = (start < 0) ? start + axis_size : start; @@ -584,13 +592,13 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( slice_num++; idx_shape[loc] = idx.size(); arr_indices.push_back(reshape(idx, idx_shape)); - } else if (py::isinstance(pyidx)) { + } else if (nb::isinstance(pyidx)) { arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); } else if (pyidx.is_none()) { slice_num++; - } else if (py::isinstance(pyidx)) { + } else if (nb::isinstance(pyidx)) { ax++; - auto idx = py::cast(pyidx); + auto idx = nb::cast(pyidx); std::vector idx_shape; if (!arrays_first) { idx_shape.insert(idx_shape.end(), slice_num, 1); @@ -629,24 +637,24 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( std::tuple, array, std::vector> mlx_compute_scatter_args( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v) { auto vals = to_array(v, src.dtype()); - if (py::isinstance(obj)) { - return mlx_scatter_args_slice(src, obj, vals); - } else if (py::isinstance(obj)) { - return mlx_scatter_args_array(src, py::cast(obj), vals); - } else if (py::isinstance(obj)) { - return mlx_scatter_args_int(src, obj, vals); - } else if (py::isinstance(obj)) { - return mlx_scatter_args_nd(src, obj, vals); + if (nb::isinstance(obj)) { + return mlx_scatter_args_slice(src, nb::cast(obj), vals); + } else if (nb::isinstance(obj)) { + return mlx_scatter_args_array(src, nb::cast(obj), vals); + } else if (nb::isinstance(obj)) { + return mlx_scatter_args_int(src, nb::cast(obj), vals); + } else if (nb::isinstance(obj)) { + return mlx_scatter_args_nd(src, nb::cast(obj), vals); } else if (obj.is_none()) { return {{}, broadcast_to(vals, src.shape()), {}}; } throw std::invalid_argument("Cannot index mlx array using the given type."); } -void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) { +void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { auto out = scatter(src, indices, updates, axes); @@ -658,7 +666,7 @@ void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) { array mlx_add_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { @@ -670,7 +678,7 @@ array mlx_add_item( array mlx_subtract_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { @@ -682,7 +690,7 @@ array mlx_subtract_item( array mlx_multiply_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, 
obj, v); if (indices.size() > 0) { @@ -694,7 +702,7 @@ array mlx_multiply_item( array mlx_divide_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { @@ -706,7 +714,7 @@ array mlx_divide_item( array mlx_maximum_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { @@ -718,7 +726,7 @@ array mlx_maximum_item( array mlx_minimum_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); if (indices.size() > 0) { diff --git a/python/src/indexing.h b/python/src/indexing.h index 0ddea859e..91ea17233 100644 --- a/python/src/indexing.h +++ b/python/src/indexing.h @@ -1,38 +1,38 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once -#include +#include #include "mlx/array.h" #include "python/src/utils.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlx::core; -array mlx_get_item(const array& src, const py::object& obj); -void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v); +array mlx_get_item(const array& src, const nb::object& obj); +void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v); array mlx_add_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v); array mlx_subtract_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v); array mlx_multiply_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v); array mlx_divide_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v); array mlx_maximum_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v); array mlx_minimum_item( const array& src, - const py::object& obj, + const nb::object& obj, const ScalarOrArray& v); diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 92b80b9eb..a6a86e414 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -1,32 +1,29 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
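A recurring change in the indexing.cpp hunks above is that nanobind does not convert nb::object implicitly to nb::slice, nb::int_, or nb::tuple, so each dispatch branch now pairs nb::isinstance with an explicit nb::cast before the value is used. A small standalone sketch of that dispatch shape (toy function, not the MLX getitem/scatter code):

// Illustrative sketch only, not MLX source: explicit isinstance/cast dispatch
// over an arbitrary index object, mirroring the shape of mlx_get_item and
// mlx_compute_scatter_args in the hunks above.
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>

#include <stdexcept>
#include <string>

namespace nb = nanobind;
using namespace nb::literals;

std::string describe_index(const nb::object& obj) {
  if (nb::isinstance<nb::slice>(obj)) {
    auto s = nb::cast<nb::slice>(obj);  // explicit cast required by nanobind
    return nb::getattr(s, "step").is_none() ? "slice with default step"
                                            : "slice";
  } else if (nb::isinstance<nb::int_>(obj)) {
    return "integer index " + std::to_string(nb::cast<int>(obj));
  } else if (nb::isinstance<nb::tuple>(obj)) {
    return "tuple of " + std::to_string(nb::cast<nb::tuple>(obj).size()) +
        " indices";
  } else if (nb::ellipsis().is(obj)) {
    return "ellipsis";
  } else if (obj.is_none()) {
    return "new axis";
  }
  throw std::invalid_argument("Cannot index using the given type.");
}

NB_MODULE(indexing_sketch, m) {
  m.def("describe_index", &describe_index, "obj"_a);
}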
#include -#include -#include +#include +#include +#include +#include +#include #include "mlx/linalg.h" -#include "python/src/load.h" -#include "python/src/utils.h" - -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; using namespace mlx::core::linalg; namespace { -py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) { +nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) { const auto result = svd(a, s); - return py::make_tuple(result.at(0), result.at(1), result.at(2)); + return nb::make_tuple(result.at(0), result.at(1), result.at(2)); } } // namespace -void init_linalg(py::module_& parent_module) { - py::options options; - options.disable_function_signatures(); - +void init_linalg(nb::module_& parent_module) { auto m = parent_module.def_submodule( "linalg", "mlx.core.linalg: linear algebra routines."); @@ -59,16 +56,15 @@ void init_linalg(py::module_& parent_module) { return norm(a, ord, axis, keepdims, stream); } }, - "a"_a, - py::pos_only(), - "ord"_a = none, - "axis"_a = none, + nb::arg(), + "ord"_a = nb::none(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - Matrix or vector norm. This function computes vector or matrix norms depending on the value of @@ -188,11 +184,11 @@ void init_linalg(py::module_& parent_module) { "qr", &qr, "a"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"), R"pbdoc( - qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array) - The QR factorization of the input matrix. This function supports arrays with at least 2 dimensions. The matrices @@ -221,11 +217,11 @@ void init_linalg(py::module_& parent_module) { "svd", &svd_helper, "a"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"), R"pbdoc( - svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array) - The Singular Value Decomposition (SVD) of the input matrix. This function supports arrays with at least 2 dimensions. When the input @@ -245,11 +241,11 @@ void init_linalg(py::module_& parent_module) { "inv", &inv, "a"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array - Compute the inverse of a square matrix. This function supports arrays with at least 2 dimensions. When the input diff --git a/python/src/load.cpp b/python/src/load.cpp index 9b6a6861e..46eb5932e 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -1,8 +1,6 @@ -// Copyright © 2023 Apple Inc. - -#include -#include +// Copyright © 2023-2024 Apple Inc. 
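linalg.cpp's svd_helper illustrates another pattern used above: a C++ helper packs several results into a Python tuple with nb::make_tuple(), and the tuple return type is spelled out through nb::sig() because it cannot be inferred from nb::tuple. A toy sketch of the same idea, not MLX code:

// Illustrative sketch only, not MLX source: returning several values as a
// Python tuple with nb::make_tuple(), with the return type documented via
// nb::sig(), as linalg.cpp's svd_helper does above.
#include <nanobind/nanobind.h>

namespace nb = nanobind;
using namespace nb::literals;

namespace {

// Stand-in for a decomposition that produces two related results.
nb::tuple minmax_helper(int a, int b) {
  return nb::make_tuple(a < b ? a : b, a < b ? b : a);
}

} // namespace

NB_MODULE(linalg_sketch, m) {
  m.def(
      "minmax",
      &minmax_helper,
      "a"_a,
      "b"_a,
      nb::sig("def minmax(a: int, b: int) -> (int, int)"),
      "Return the smaller and larger of two integers as a tuple.");
}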
+#include #include #include #include @@ -16,39 +14,39 @@ #include "python/src/load.h" #include "python/src/utils.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; /////////////////////////////////////////////////////////////////////////////// // Helpers /////////////////////////////////////////////////////////////////////////////// -bool is_istream_object(const py::object& file) { - return py::hasattr(file, "readinto") && py::hasattr(file, "seek") && - py::hasattr(file, "tell") && py::hasattr(file, "closed"); +bool is_istream_object(const nb::object& file) { + return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") && + nb::hasattr(file, "tell") && nb::hasattr(file, "closed"); } -bool is_ostream_object(const py::object& file) { - return py::hasattr(file, "write") && py::hasattr(file, "seek") && - py::hasattr(file, "tell") && py::hasattr(file, "closed"); +bool is_ostream_object(const nb::object& file) { + return nb::hasattr(file, "write") && nb::hasattr(file, "seek") && + nb::hasattr(file, "tell") && nb::hasattr(file, "closed"); } -bool is_zip_file(const py::module_& zipfile, const py::object& file) { +bool is_zip_file(const nb::module_& zipfile, const nb::object& file) { if (is_istream_object(file)) { auto st_pos = file.attr("tell")(); - bool r = (zipfile.attr("is_zipfile")(file)).cast(); + bool r = nb::cast(zipfile.attr("is_zipfile")(file)); file.attr("seek")(st_pos, 0); return r; } - return zipfile.attr("is_zipfile")(file).cast(); + return nb::cast(zipfile.attr("is_zipfile")(file)); } class ZipFileWrapper { public: ZipFileWrapper( - const py::module_& zipfile, - const py::object& file, + const nb::module_& zipfile, + const nb::object& file, char mode = 'r', int compression = 0) : zipfile_module_(zipfile), @@ -63,10 +61,10 @@ class ZipFileWrapper { close_func_(zipfile_object_.attr("close")) {} std::vector namelist() const { - return files_list_.cast>(); + return nb::cast>(files_list_); } - py::object open(const std::string& key, char mode = 'r') { + nb::object open(const std::string& key, char mode = 'r') { // Following numpy : // https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47 if (mode == 'w') { @@ -76,12 +74,12 @@ class ZipFileWrapper { } private: - py::module_ zipfile_module_; - py::object zipfile_object_; - py::list files_list_; - py::object open_func_; - py::object read_func_; - py::object close_func_; + nb::module_ zipfile_module_; + nb::object zipfile_object_; + nb::list files_list_; + nb::object open_func_; + nb::object read_func_; + nb::object close_func_; }; /////////////////////////////////////////////////////////////////////////////// @@ -90,14 +88,14 @@ class ZipFileWrapper { class PyFileReader : public io::Reader { public: - PyFileReader(py::object file) + PyFileReader(nb::object file) : pyistream_(file), readinto_func_(file.attr("readinto")), seek_func_(file.attr("seek")), tell_func_(file.attr("tell")) {} ~PyFileReader() { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; pyistream_.release().dec_ref(); readinto_func_.release().dec_ref(); @@ -108,8 +106,8 @@ class PyFileReader : public io::Reader { bool is_open() const override { bool out; { - py::gil_scoped_acquire gil; - out = !pyistream_.attr("closed").cast(); + nb::gil_scoped_acquire gil; + out = !nb::cast(pyistream_.attr("closed")); } return out; } @@ -117,7 +115,7 @@ class PyFileReader : public io::Reader { bool good() const override { bool 
out; { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; out = !pyistream_.is_none(); } return out; @@ -126,25 +124,24 @@ class PyFileReader : public io::Reader { size_t tell() const override { size_t out; { - py::gil_scoped_acquire gil; - out = tell_func_().cast(); + nb::gil_scoped_acquire gil; + out = nb::cast(tell_func_()); } return out; } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; seek_func_(off, (int)way); } void read(char* data, size_t n) override { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; + auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE); + nb::object bytes_read = readinto_func_(nb::handle(memview)); - py::object bytes_read = - readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); - - if (bytes_read.is_none() || py::cast(bytes_read) < n) { + if (bytes_read.is_none() || nb::cast(bytes_read) < n) { throw std::runtime_error("[load] Failed to read from python stream"); } } @@ -154,23 +151,23 @@ class PyFileReader : public io::Reader { } private: - py::object pyistream_; - py::object readinto_func_; - py::object seek_func_; - py::object tell_func_; + nb::object pyistream_; + nb::object readinto_func_; + nb::object seek_func_; + nb::object tell_func_; }; std::pair< std::unordered_map, std::unordered_map> -mlx_load_safetensor_helper(py::object file, StreamOrDevice s) { - if (py::isinstance(file)) { // Assume .safetensors file path string - return load_safetensors(py::cast(file), s); +mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) { + if (nb::isinstance(file)) { // Assume .safetensors file path string + return load_safetensors(nb::cast(file), s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately auto res = load_safetensors(std::make_shared(file), s); { - py::gil_scoped_release gil; + nb::gil_scoped_release gil; for (auto& [key, arr] : std::get<0>(res)) { arr.eval(); } @@ -182,20 +179,20 @@ mlx_load_safetensor_helper(py::object file, StreamOrDevice s) { "[load_safetensors] Input must be a file-like object, or string"); } -GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) { - if (py::isinstance(file)) { // Assume .gguf file path string - return load_gguf(py::cast(file), s); +GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) { + if (nb::isinstance(file)) { // Assume .gguf file path string + return load_gguf(nb::cast(file), s); } throw std::invalid_argument("[load_gguf] Input must be a string"); } std::unordered_map mlx_load_npz_helper( - py::object file, + nb::object file, StreamOrDevice s) { - bool own_file = py::isinstance(file); + bool own_file = nb::isinstance(file); - py::module_ zipfile = py::module_::import("zipfile"); + nb::module_ zipfile = nb::module_::import_("zipfile"); if (!is_zip_file(zipfile, file)) { throw std::invalid_argument( "[load_npz] Input must be a zip file or a file-like object that can be " @@ -208,7 +205,7 @@ std::unordered_map mlx_load_npz_helper( ZipFileWrapper zipfile_object(zipfile, file); for (const std::string& st : zipfile_object.namelist()) { // Open zip file as a python file stream - py::object sub_file = zipfile_object.open(st); + nb::object sub_file = zipfile_object.open(st); // Create array from python fille stream auto arr = load(std::make_shared(sub_file), s); @@ -224,7 +221,7 @@ std::unordered_map mlx_load_npz_helper( // If we don't own the stream and it was passed to us, eval immediately if 
(!own_file) { - py::gil_scoped_release gil; + nb::gil_scoped_release gil; for (auto& [key, arr] : array_dict) { arr.eval(); } @@ -233,14 +230,14 @@ std::unordered_map mlx_load_npz_helper( return array_dict; } -array mlx_load_npy_helper(py::object file, StreamOrDevice s) { - if (py::isinstance(file)) { // Assume .npy file path string - return load(py::cast(file), s); +array mlx_load_npy_helper(nb::object file, StreamOrDevice s) { + if (nb::isinstance(file)) { // Assume .npy file path string + return load(nb::cast(file), s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately auto arr = load(std::make_shared(file), s); { - py::gil_scoped_release gil; + nb::gil_scoped_release gil; arr.eval(); } return arr; @@ -250,16 +247,16 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) { } LoadOutputTypes mlx_load_helper( - py::object file, + nb::object file, std::optional format, bool return_metadata, StreamOrDevice s) { if (!format.has_value()) { std::string fname; - if (py::isinstance(file)) { - fname = py::cast(file); + if (nb::isinstance(file)) { + fname = nb::cast(file); } else if (is_istream_object(file)) { - fname = file.attr("name").cast(); + fname = nb::cast(file.attr("name")); } else { throw std::invalid_argument( "[load] Input must be a file-like object opened in binary mode, or string"); @@ -304,14 +301,14 @@ LoadOutputTypes mlx_load_helper( class PyFileWriter : public io::Writer { public: - PyFileWriter(py::object file) + PyFileWriter(nb::object file) : pyostream_(file), write_func_(file.attr("write")), seek_func_(file.attr("seek")), tell_func_(file.attr("tell")) {} ~PyFileWriter() { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; pyostream_.release().dec_ref(); write_func_.release().dec_ref(); @@ -322,8 +319,8 @@ class PyFileWriter : public io::Writer { bool is_open() const override { bool out; { - py::gil_scoped_acquire gil; - out = !pyostream_.attr("closed").cast(); + nb::gil_scoped_acquire gil; + out = !nb::cast(pyostream_.attr("closed")); } return out; } @@ -331,7 +328,7 @@ class PyFileWriter : public io::Writer { bool good() const override { bool out; { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; out = !pyostream_.is_none(); } return out; @@ -340,25 +337,26 @@ class PyFileWriter : public io::Writer { size_t tell() const override { size_t out; { - py::gil_scoped_acquire gil; - out = tell_func_().cast(); + nb::gil_scoped_acquire gil; + out = nb::cast(tell_func_()); } return out; } void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) override { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; seek_func_(off, (int)way); } void write(const char* data, size_t n) override { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; - py::object bytes_written = - write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); + auto memview = + PyMemoryView_FromMemory(const_cast(data), n, PyBUF_READ); + nb::object bytes_written = write_func_(nb::handle(memview)); - if (bytes_written.is_none() || py::cast(bytes_written) < n) { + if (bytes_written.is_none() || nb::cast(bytes_written) < n) { throw std::runtime_error("[load] Failed to write to python stream"); } } @@ -368,20 +366,20 @@ class PyFileWriter : public io::Writer { } private: - py::object pyostream_; - py::object write_func_; - py::object seek_func_; - py::object tell_func_; + nb::object pyostream_; + nb::object write_func_; + nb::object seek_func_; + nb::object tell_func_; }; -void 
mlx_save_helper(py::object file, array a) { - if (py::isinstance(file)) { - save(py::cast(file), a); +void mlx_save_helper(nb::object file, array a) { + if (nb::isinstance(file)) { + save(nb::cast(file), a); return; } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { - py::gil_scoped_release gil; + nb::gil_scoped_release gil; save(writer, a); } @@ -393,26 +391,26 @@ void mlx_save_helper(py::object file, array a) { } void mlx_savez_helper( - py::object file_, - py::args args, - const py::kwargs& kwargs, + nb::object file_, + nb::args args, + const nb::kwargs& kwargs, bool compressed) { // Add .npz to the end of the filename if not already there - py::object file = file_; + nb::object file = file_; - if (py::isinstance(file_)) { - std::string fname = file_.cast(); + if (nb::isinstance(file_)) { + std::string fname = nb::cast(file_); // Add .npz to file name if it is not there if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") fname += ".npz"; - file = py::str(fname); + file = nb::cast(fname); } // Collect args and kwargs - auto arrays_dict = kwargs.cast>(); - auto arrays_list = args.cast>(); + auto arrays_dict = nb::cast>(kwargs); + auto arrays_list = nb::cast>(args); for (int i = 0; i < arrays_list.size(); i++) { std::string arr_name = "arr_" + std::to_string(i); @@ -426,9 +424,9 @@ void mlx_savez_helper( } // Create python ZipFile object depending on compression - py::module_ zipfile = py::module_::import("zipfile"); - int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast() - : zipfile.attr("ZIP_STORED").cast(); + nb::module_ zipfile = nb::module_::import_("zipfile"); + int compression = nb::cast( + compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED")); char mode = 'w'; ZipFileWrapper zipfile_object(zipfile, file, mode, compression); @@ -438,7 +436,7 @@ void mlx_savez_helper( auto py_ostream = zipfile_object.open(fname, 'w'); auto writer = std::make_shared(py_ostream); { - py::gil_scoped_release nogil; + nb::gil_scoped_release nogil; save(writer, a); } } @@ -447,31 +445,31 @@ void mlx_savez_helper( } void mlx_save_safetensor_helper( - py::object file, - py::dict d, - std::optional m) { + nb::object file, + nb::dict d, + std::optional m) { std::unordered_map metadata_map; if (m) { try { metadata_map = - m.value().cast>(); - } catch (const py::cast_error& e) { + nb::cast>(m.value()); + } catch (const nb::cast_error& e) { throw std::invalid_argument( "[save_safetensors] Metadata must be a dictionary with string keys and values"); } } else { metadata_map = std::unordered_map(); } - auto arrays_map = d.cast>(); - if (py::isinstance(file)) { + auto arrays_map = nb::cast>(d); + if (nb::isinstance(file)) { { - py::gil_scoped_release nogil; - save_safetensors(py::cast(file), arrays_map, metadata_map); + nb::gil_scoped_release nogil; + save_safetensors(nb::cast(file), arrays_map, metadata_map); } } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { - py::gil_scoped_release nogil; + nb::gil_scoped_release nogil; save_safetensors(writer, arrays_map, metadata_map); } } else { @@ -481,22 +479,22 @@ void mlx_save_safetensor_helper( } void mlx_save_gguf_helper( - py::object file, - py::dict a, - std::optional m) { - auto arrays_map = a.cast>(); - if (py::isinstance(file)) { + nb::object file, + nb::dict a, + std::optional m) { + auto arrays_map = nb::cast>(a); + if (nb::isinstance(file)) { if (m) { auto metadata_map = - m.value().cast>(); + nb::cast>(m.value()); { - py::gil_scoped_release nogil; - 
save_gguf(py::cast(file), arrays_map, metadata_map); + nb::gil_scoped_release nogil; + save_gguf(nb::cast(file), arrays_map, metadata_map); } } else { { - py::gil_scoped_release nogil; - save_gguf(py::cast(file), arrays_map); + nb::gil_scoped_release nogil; + save_gguf(nb::cast(file), arrays_map); } } } else { diff --git a/python/src/load.h b/python/src/load.h index 21f0cff32..90ddeb8b3 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -1,15 +1,20 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once -#include +#include +#include +#include +#include +#include + #include #include #include #include #include "mlx/io.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlx::core; using LoadOutputTypes = std::variant< @@ -18,27 +23,27 @@ using LoadOutputTypes = std::variant< SafetensorsLoad, GGUFLoad>; -SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s); +SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s); void mlx_save_safetensor_helper( - py::object file, - py::dict d, - std::optional m); + nb::object file, + nb::dict d, + std::optional m); -GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s); +GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s); void mlx_save_gguf_helper( - py::object file, - py::dict d, - std::optional m); + nb::object file, + nb::dict d, + std::optional m); LoadOutputTypes mlx_load_helper( - py::object file, + nb::object file, std::optional format, bool return_metadata, StreamOrDevice s); -void mlx_save_helper(py::object file, array a); +void mlx_save_helper(nb::object file, array a); void mlx_savez_helper( - py::object file, - py::args args, - const py::kwargs& kwargs, + nb::object file, + nb::args args, + const nb::kwargs& kwargs, bool compressed = false); diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 4263f6bff..bec29d3aa 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -1,16 +1,15 @@ -// Copyright © 2023 Apple Inc. - -#include +// Copyright © 2023-2024 Apple Inc. #include "mlx/backend/metal/metal.h" +#include -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; -void init_metal(py::module_& m) { - py::module_ metal = m.def_submodule("metal", "mlx.metal"); +void init_metal(nb::module_& m) { + nb::module_ metal = m.def_submodule("metal", "mlx.metal"); metal.def( "is_available", &metal::is_available, @@ -48,7 +47,7 @@ void init_metal(py::module_& m) { "set_memory_limit", &metal::set_memory_limit, "limit"_a, - py::kw_only(), + nb::kw_only(), "relaxed"_a = true, R"pbdoc( Set the memory limit. diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index c0f6e7dd4..30bf67fa5 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -1,30 +1,30 @@ -// Copyright © 2023 Apple Inc. +// Conbright © 2023-2024 Apple Inc. 
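In load.cpp the Python-stream readers and writers keep nb::object members, so the GIL is taken explicitly around every Python call (including in the destructor, where the references are dropped), and the buffer handed to read/write is wrapped with CPython's PyMemoryView_FromMemory instead of pybind11's memoryview helper. Below is a condensed, hypothetical writer in the same style; PyStreamWriter and its members are invented names, not the MLX classes.

// Illustrative sketch only, not MLX source: wrapping a Python file-like object
// in a C++ writer, taking the GIL only around Python calls and passing the
// buffer to `write` through PyMemoryView_FromMemory, as the load.cpp hunks do.
#include <nanobind/nanobind.h>

#include <cstddef>
#include <stdexcept>
#include <utility>

namespace nb = nanobind;

class PyStreamWriter {
 public:
  explicit PyStreamWriter(nb::object file)
      : file_(std::move(file)), write_func_(file_.attr("write")) {}

  ~PyStreamWriter() {
    // Python references must be released while holding the GIL.
    nb::gil_scoped_acquire gil;
    write_func_.release().dec_ref();
    file_.release().dec_ref();
  }

  void write(const char* data, std::size_t n) {
    nb::gil_scoped_acquire gil;
    // Zero-copy, read-only view over the C++ buffer. nb::steal takes
    // ownership of the new reference returned by the C API.
    auto view = nb::steal<nb::object>(PyMemoryView_FromMemory(
        const_cast<char*>(data), (Py_ssize_t)n, PyBUF_READ));
    nb::object written = write_func_(view);
    if (written.is_none() || nb::cast<std::size_t>(written) < n) {
      throw std::runtime_error("Failed to write to python stream");
    }
  }

 private:
  nb::object file_;
  nb::object write_func_;
};

Acquiring the GIL inside the destructor matters because, as in the hunks above, the writer can be destroyed after the calling code has released the GIL around the actual save.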
-#include +#include #define STRINGIFY(x) #x #define TOSTRING(x) STRINGIFY(x) -namespace py = pybind11; +namespace nb = nanobind; -void init_array(py::module_&); -void init_device(py::module_&); -void init_stream(py::module_&); -void init_metal(py::module_&); -void init_ops(py::module_&); -void init_transforms(py::module_&); -void init_random(py::module_&); -void init_fft(py::module_&); -void init_linalg(py::module_&); -void init_constants(py::module_&); -void init_extensions(py::module_&); -void init_utils(py::module_&); +void init_array(nb::module_&); +void init_device(nb::module_&); +void init_stream(nb::module_&); +void init_metal(nb::module_&); +void init_ops(nb::module_&); +void init_transforms(nb::module_&); +void init_random(nb::module_&); +void init_fft(nb::module_&); +void init_linalg(nb::module_&); +void init_constants(nb::module_&); +void init_fast(nb::module_&); -PYBIND11_MODULE(core, m) { +NB_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; - auto reprlib_fix = py::module_::import("mlx._reprlib_fix"); - py::module_::import("mlx._os_warning"); + auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix"); + nb::module_::import_("mlx._os_warning"); + nb::set_leak_warnings(false); init_device(m); init_stream(m); @@ -36,8 +36,7 @@ PYBIND11_MODULE(core, m) { init_fft(m); init_linalg(m); init_constants(m); - init_extensions(m); - init_utils(m); + init_fast(m); m.attr("__version__") = TOSTRING(_VERSION_); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 53eead872..977e5897f 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1,20 +1,24 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include #include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include #include "mlx/ops.h" #include "mlx/utils.h" #include "python/src/load.h" #include "python/src/utils.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; using Scalar = std::variant; @@ -35,21 +39,17 @@ double scalar_to_double(Scalar s) { } } -void init_ops(py::module_& m) { - py::options options; - options.disable_function_signatures(); - +void init_ops(nb::module_& m) { m.def( "reshape", &reshape, - "a"_a, - py::pos_only(), + nb::arg(), "shape"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig("def reshape(a: array, /, shape: List[int], *, stream: " + "Union[None, Stream, Device] = None) -> array"), R"pbdoc( - reshape(a: array, /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array - Reshape an array while preserving the size. Args: @@ -67,15 +67,14 @@ void init_ops(py::module_& m) { int start_axis, int end_axis, const StreamOrDevice& s) { return flatten(a, start_axis, end_axis); }, - "a"_a, - py::pos_only(), + nb::arg(), "start_axis"_a = 0, "end_axis"_a = -1, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig("def flatten(a: array, /, start_axis: int = 0, end_axis: int = " + "-1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array - Flatten an array. 
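mlx.cpp keeps the same overall layout after the port: one init_* function per component, each taking an nb::module_& and attaching its bindings directly or through def_submodule(), all called from a single NB_MODULE entry point (with init_extensions/init_utils folded into init_fast). A toy sketch of that layout, with invented module and function names, not MLX code:

// Illustrative sketch only, not MLX source: a single NB_MODULE entry point
// that delegates to per-component init functions, mirroring the structure of
// mlx.cpp in the hunk above.
#include <nanobind/nanobind.h>

namespace nb = nanobind;

// In a real project these would live in separate translation units.
void init_ops(nb::module_& m) {
  m.def("square", [](double x) { return x * x; });
}

void init_fast(nb::module_& parent_module) {
  auto m = parent_module.def_submodule("fast", "toy.fast: fast operations");
  m.def("twice", [](double x) { return 2.0 * x; });
}

NB_MODULE(toy_core, m) {
  m.doc() = "toy_core: sketch of a split-initialization nanobind module.";
  init_ops(m);
  init_fast(m);
  m.attr("__version__") = "0.0.0";
}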
The axes flattened will be between ``start_axis`` and ``end_axis``, @@ -112,14 +111,13 @@ void init_ops(py::module_& m) { return squeeze(a, std::get>(v), s); } }, - "a"_a, - py::pos_only(), - "axis"_a = none, - py::kw_only(), - "stream"_a = none, + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig("def squeeze(a: array, /, axis: Union[None, int, List[int]] = " + "None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - squeeze(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array - Remove length one axes from an array. Args: @@ -141,14 +139,13 @@ void init_ops(py::module_& m) { return expand_dims(a, std::get>(v), s); } }, - "a"_a, - py::pos_only(), + nb::arg(), "axis"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig("def expand_dims(a: array, /, axis: Union[int, List[int]], " + "*, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - expand_dims(a: array, /, axis: Union[int, List[int]], *, stream: Union[None, Stream, Device] = None) -> array - Add a size one dimension at the given axis. Args: @@ -161,13 +158,12 @@ void init_ops(py::module_& m) { m.def( "abs", &mlx::core::abs, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise absolute value. Args: @@ -179,13 +175,12 @@ void init_ops(py::module_& m) { m.def( "sign", &sign, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise sign. Args: @@ -197,13 +192,12 @@ void init_ops(py::module_& m) { m.def( "negative", &negative, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise negation. Args: @@ -218,14 +212,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return add(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise addition. Add two arrays with numpy-style broadcasting semantics. 
Either or both input arrays @@ -244,14 +237,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return subtract(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise subtraction. Subtract one array from another with numpy-style broadcasting semantics. Either or both @@ -270,14 +262,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return multiply(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise multiplication. Multiply two arrays with numpy-style broadcasting semantics. Either or both @@ -296,14 +287,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return divide(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise division. Divide two arrays with numpy-style broadcasting semantics. Either or both @@ -322,14 +312,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return divmod(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise quotient and remainder. The fuction ``divmod(a, b)`` is equivalent to but faster than @@ -349,14 +338,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return floor_divide(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise integer division. 
If either array is a floating point type then it is equivalent to @@ -375,14 +363,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return remainder(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise remainder of division. Computes the remainder of dividing a with b with numpy-style @@ -402,14 +389,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return equal(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise equality. Equality comparison on two arrays with numpy-style broadcasting semantics. @@ -428,14 +414,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return not_equal(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise not equal. Not equal comparison on two arrays with numpy-style broadcasting semantics. @@ -454,14 +439,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return less(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise less than. Strict less than on two arrays with numpy-style broadcasting semantics. @@ -480,14 +464,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return less_equal(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise less than or equal. Less than or equal on two arrays with numpy-style broadcasting semantics. 
@@ -506,14 +489,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return greater(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise greater than. Strict greater than on two arrays with numpy-style broadcasting semantics. @@ -532,14 +514,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return greater_equal(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array - Element-wise greater or equal. Greater than or equal on two arrays with numpy-style broadcasting semantics. @@ -561,15 +542,14 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return array_equal(a, b, equal_nan, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), + nb::arg(), + nb::arg(), + nb::kw_only(), "equal_nan"_a = false, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array - Array equality check. Compare two arrays for equality. Returns ``True`` if and only if the arrays @@ -588,14 +568,13 @@ void init_ops(py::module_& m) { m.def( "matmul", &matmul, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Matrix multiplication. Perform the (possibly batched) matrix multiplication of two arrays. This function supports @@ -621,13 +600,12 @@ void init_ops(py::module_& m) { m.def( "square", &square, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise square. Args: @@ -639,13 +617,12 @@ void init_ops(py::module_& m) { m.def( "sqrt", &mlx::core::sqrt, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise square root. 
Args: @@ -657,13 +634,12 @@ void init_ops(py::module_& m) { m.def( "rsqrt", &rsqrt, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise reciprocal and square root. Args: @@ -675,13 +651,12 @@ void init_ops(py::module_& m) { m.def( "reciprocal", &reciprocal, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise reciprocal. Args: @@ -695,13 +670,12 @@ void init_ops(py::module_& m) { [](const ScalarOrArray& a, StreamOrDevice s) { return logical_not(to_array(a), s); }, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise logical not. Args: @@ -715,14 +689,13 @@ void init_ops(py::module_& m) { [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) { return logical_and(to_array(a), to_array(b), s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise logical and. Args: @@ -738,14 +711,13 @@ void init_ops(py::module_& m) { [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) { return logical_or(to_array(a), to_array(b), s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise logical or. Args: @@ -761,14 +733,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return logaddexp(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise log-add-exp. This is a numerically stable log-add-exp of two arrays with numpy-style @@ -786,13 +757,12 @@ void init_ops(py::module_& m) { m.def( "exp", &mlx::core::exp, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise exponential. 
Args: @@ -804,13 +774,12 @@ void init_ops(py::module_& m) { m.def( "erf", &mlx::core::erf, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise error function. .. math:: @@ -825,13 +794,12 @@ void init_ops(py::module_& m) { m.def( "erfinv", &mlx::core::erfinv, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise inverse of :func:`erf`. Args: @@ -843,13 +811,12 @@ void init_ops(py::module_& m) { m.def( "sin", &mlx::core::sin, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise sine. Args: @@ -861,13 +828,12 @@ void init_ops(py::module_& m) { m.def( "cos", &mlx::core::cos, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise cosine. Args: @@ -879,13 +845,12 @@ void init_ops(py::module_& m) { m.def( "tan", &mlx::core::tan, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise tangent. Args: @@ -897,13 +862,12 @@ void init_ops(py::module_& m) { m.def( "arcsin", &mlx::core::arcsin, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise inverse sine. Args: @@ -915,13 +879,12 @@ void init_ops(py::module_& m) { m.def( "arccos", &mlx::core::arccos, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise inverse cosine. Args: @@ -933,13 +896,12 @@ void init_ops(py::module_& m) { m.def( "arctan", &mlx::core::arctan, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise inverse tangent. 
Args: @@ -951,13 +913,12 @@ void init_ops(py::module_& m) { m.def( "sinh", &mlx::core::sinh, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise hyperbolic sine. Args: @@ -969,13 +930,12 @@ void init_ops(py::module_& m) { m.def( "cosh", &mlx::core::cosh, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise hyperbolic cosine. Args: @@ -987,13 +947,12 @@ void init_ops(py::module_& m) { m.def( "tanh", &mlx::core::tanh, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise hyperbolic tangent. Args: @@ -1005,13 +964,12 @@ void init_ops(py::module_& m) { m.def( "arcsinh", &mlx::core::arcsinh, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise inverse hyperbolic sine. Args: @@ -1023,13 +981,12 @@ void init_ops(py::module_& m) { m.def( "arccosh", &mlx::core::arccosh, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise inverse hyperbolic cosine. Args: @@ -1041,13 +998,12 @@ void init_ops(py::module_& m) { m.def( "arctanh", &mlx::core::arctanh, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise inverse hyperbolic tangent. Args: @@ -1059,13 +1015,12 @@ void init_ops(py::module_& m) { m.def( "log", &mlx::core::log, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise natural logarithm. Args: @@ -1077,13 +1032,12 @@ void init_ops(py::module_& m) { m.def( "log2", &mlx::core::log2, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise base-2 logarithm. 
Args: @@ -1095,13 +1049,12 @@ void init_ops(py::module_& m) { m.def( "log10", &mlx::core::log10, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise base-10 logarithm. Args: @@ -1113,13 +1066,12 @@ void init_ops(py::module_& m) { m.def( "log1p", &mlx::core::log1p, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise natural log of one plus the array. Args: @@ -1131,13 +1083,12 @@ void init_ops(py::module_& m) { m.def( "stop_gradient", &stop_gradient, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Stop gradients from being computed. The operation is the identity but it prevents gradients from flowing @@ -1153,13 +1104,12 @@ void init_ops(py::module_& m) { m.def( "sigmoid", &sigmoid, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise logistic sigmoid. The logistic sigmoid function is: @@ -1179,14 +1129,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return power(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise power operation. 
Raise the elements of a to the powers in elements of b with numpy-style @@ -1208,8 +1157,8 @@ void init_ops(py::module_& m) { return arange(0.0, scalar_to_double(stop), 1.0, dtype, s); }, "stop"_a, - "dtype"_a = none, - "stream"_a = none); + "dtype"_a = nb::none(), + "stream"_a = nb::none()); m.def( "arange", [](Scalar start, @@ -1224,8 +1173,8 @@ void init_ops(py::module_& m) { }, "start"_a, "stop"_a, - "dtype"_a = none, - "stream"_a = none); + "dtype"_a = nb::none(), + "stream"_a = nb::none()); m.def( "arange", [](Scalar stop, @@ -1241,8 +1190,8 @@ void init_ops(py::module_& m) { }, "stop"_a, "step"_a, - "dtype"_a = none, - "stream"_a = none); + "dtype"_a = nb::none(), + "stream"_a = nb::none()); m.def( "arange", [](Scalar start, @@ -1267,11 +1216,11 @@ void init_ops(py::module_& m) { "start"_a, "stop"_a, "step"_a, - "dtype"_a = none, - "stream"_a = none, + "dtype"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array - Generates ranges of numbers. Generate numbers in the half-open interval ``[start, stop)`` in @@ -1312,22 +1261,22 @@ void init_ops(py::module_& m) { "start"_a, "stop"_a, "num"_a = 50, - "dtype"_a = std::optional{float32}, - "stream"_a = none, + "dtype"_a = float32, + "stream"_a = nb::none(), + nb::sig( + "def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array + Generate ``num`` evenly spaced numbers over interval ``[start, stop]``. - Generate ``num`` evenly spaced numbers over interval ``[start, stop]``. + Args: + start (scalar): Starting value. + stop (scalar): Stopping value. + num (int, optional): Number of samples, defaults to ``50``. + dtype (Dtype, optional): Specifies the data type of the output, + default to ``float32``. - Args: - start (scalar): Starting value. - stop (scalar): Stopping value. - num (int, optional): Number of samples, defaults to ``50``. - dtype (Dtype, optional): Specifies the data type of the output, - default to ``float32``. - - Returns: - array: The range of values. + Returns: + array: The range of values. )pbdoc"); m.def( "take", @@ -1341,15 +1290,14 @@ void init_ops(py::module_& m) { return take(a, indices, s); } }, - "a"_a, - py::pos_only(), + nb::arg(), "indices"_a, - "axis"_a = std::nullopt, - py::kw_only(), - "stream"_a = none, + "axis"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - take(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array - Take elements along an axis. The elements are taken from ``indices`` along the specified axis. 
@@ -1379,15 +1327,14 @@ void init_ops(py::module_& m) { return take_along_axis(reshape(a, {-1}, s), indices, 0, s); } }, - "a"_a, - py::pos_only(), + nb::arg(), "indices"_a, - "axis"_a, - py::kw_only(), - "stream"_a = none, + "axis"_a.none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array - Take values along an axis at the specified indices. Args: @@ -1416,12 +1363,12 @@ void init_ops(py::module_& m) { }, "shape"_a, "vals"_a, - "dtype"_a = std::nullopt, - py::kw_only(), - "stream"_a = none, + "dtype"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - full(shape: Union[int, List[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array - Construct an array with the given value. Constructs an array of size ``shape`` filled with ``vals``. If ``vals`` @@ -1449,12 +1396,12 @@ void init_ops(py::module_& m) { } }, "shape"_a, - "dtype"_a = std::optional{float32}, - py::kw_only(), - "stream"_a = none, + "dtype"_a = float32, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array - Construct an array of zeros. Args: @@ -1468,13 +1415,12 @@ void init_ops(py::module_& m) { m.def( "zeros_like", &zeros_like, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - An array of zeros like the input. Args: @@ -1496,12 +1442,12 @@ void init_ops(py::module_& m) { } }, "shape"_a, - "dtype"_a = std::optional{float32}, - py::kw_only(), - "stream"_a = none, + "dtype"_a = float32, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array - Construct an array of ones. Args: @@ -1515,13 +1461,12 @@ void init_ops(py::module_& m) { m.def( "ones_like", &ones_like, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - An array of ones like the input. 
Args: @@ -1540,25 +1485,25 @@ void init_ops(py::module_& m) { return eye(n, m.value_or(n), k, dtype.value_or(float32), s); }, "n"_a, - "m"_a = py::none(), + "m"_a = nb::none(), "k"_a = 0, - "dtype"_a = std::optional{float32}, - py::kw_only(), - "stream"_a = none, + "dtype"_a = float32, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array + Create an identity matrix or a general diagonal matrix. - Create an identity matrix or a general diagonal matrix. + Args: + n (int): The number of rows in the output. + m (int, optional): The number of columns in the output. Defaults to n. + k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal). + dtype (Dtype, optional): Data type of the output array. Defaults to float32. + stream (Stream, optional): Stream or device. Defaults to None. - Args: - n (int): The number of rows in the output. - m (int, optional): The number of columns in the output. Defaults to n. - k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal). - dtype (Dtype, optional): Data type of the output array. Defaults to float32. - stream (Stream, optional): Stream or device. Defaults to None. - - Returns: - array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one. + Returns: + array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one. )pbdoc"); m.def( "identity", @@ -1566,21 +1511,21 @@ void init_ops(py::module_& m) { return identity(n, dtype.value_or(float32), s); }, "n"_a, - "dtype"_a = std::optional{float32}, - py::kw_only(), - "stream"_a = none, + "dtype"_a = float32, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array + Create a square identity matrix. - Create a square identity matrix. + Args: + n (int): The number of rows and columns in the output. + dtype (Dtype, optional): Data type of the output array. Defaults to float32. + stream (Stream, optional): Stream or device. Defaults to None. - Args: - n (int): The number of rows and columns in the output. - dtype (Dtype, optional): Data type of the output array. Defaults to float32. - stream (Stream, optional): Stream or device. Defaults to None. - - Returns: - array: An identity matrix of size n x n. + Returns: + array: An identity matrix of size n x n. )pbdoc"); m.def( "tri", @@ -1592,14 +1537,14 @@ void init_ops(py::module_& m) { return tri(n, m.value_or(n), k, type.value_or(float32), s); }, "n"_a, - "m"_a = none, + "m"_a = nb::none(), "k"_a = 0, - "dtype"_a = std::optional{float32}, - py::kw_only(), - "stream"_a = none, + "dtype"_a = float32, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array - An array with ones at and below the given diagonal and zeros elsewhere. 
Args: @@ -1617,31 +1562,31 @@ void init_ops(py::module_& m) { &tril, "x"_a, "k"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array + Zeros the array above the given diagonal. - Zeros the array above the given diagonal. + Args: + x (array): input array. + k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. + stream (Stream, optional): Stream or device. Defaults to ``None``. - Args: - x (array): input array. - k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. - stream (Stream, optional): Stream or device. Defaults to ``None``. - - Returns: - array: Array zeroed above the given diagonal - )pbdoc"); + Returns: + array: Array zeroed above the given diagonal + )pbdoc"); m.def( "triu", &triu, "x"_a, "k"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array - Zeros the array below the given diagonal. Args: @@ -1655,17 +1600,16 @@ void init_ops(py::module_& m) { m.def( "allclose", &allclose, - "a"_a, - "b"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "rtol"_a = 1e-5, "atol"_a = 1e-8, - py::kw_only(), + nb::kw_only(), "equal_nan"_a = false, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array - Approximate comparison of two arrays. Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``. @@ -1693,17 +1637,16 @@ void init_ops(py::module_& m) { m.def( "isclose", &isclose, - "a"_a, - "b"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "rtol"_a = 1e-5, "atol"_a = 1e-8, - py::kw_only(), + nb::kw_only(), "equal_nan"_a = false, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array - Returns a boolean array where two arrays are element-wise equal within a tolerance. 
Infinite values are considered equal if they have the same sign, NaN values are @@ -1737,15 +1680,14 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - all(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - An `and` reduction over the given axes. Args: @@ -1767,15 +1709,14 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - any(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - An `or` reduction over the given axes. Args: @@ -1795,14 +1736,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return minimum(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise minimum. Take the element-wise min of two arrays with numpy-style broadcasting @@ -1821,14 +1761,13 @@ void init_ops(py::module_& m) { auto [a, b] = to_arrays(a_, b_); return maximum(a, b, s); }, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise maximum. Take the element-wise max of two arrays with numpy-style broadcasting @@ -1844,13 +1783,12 @@ void init_ops(py::module_& m) { m.def( "floor", &mlx::core::floor, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise floor. Args: @@ -1862,13 +1800,12 @@ void init_ops(py::module_& m) { m.def( "ceil", &mlx::core::ceil, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Element-wise ceil. 
Args: @@ -1880,13 +1817,12 @@ void init_ops(py::module_& m) { m.def( "isnan", &mlx::core::isnan, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def isnan(a: array, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - isnan(a: array, stream: Union[None, Stream, Device] = None) -> array - Return a boolean array indicating which elements are NaN. Args: @@ -1898,13 +1834,12 @@ void init_ops(py::module_& m) { m.def( "isinf", &mlx::core::isinf, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def isinf(a: array, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - isinf(a: array, stream: Union[None, Stream, Device] = None) -> array - Return a boolean array indicating which elements are +/- inifnity. Args: @@ -1916,13 +1851,12 @@ void init_ops(py::module_& m) { m.def( "isposinf", &isposinf, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array - Return a boolean array indicating which elements are positive infinity. Args: @@ -1935,13 +1869,12 @@ void init_ops(py::module_& m) { m.def( "isneginf", &isneginf, - "a"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array - Return a boolean array indicating which elements are negative infinity. Args: @@ -1954,15 +1887,14 @@ void init_ops(py::module_& m) { m.def( "moveaxis", &moveaxis, - "a"_a, - py::pos_only(), + nb::arg(), "source"_a, "destination"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array - Move an axis to a new position. Args: @@ -1976,15 +1908,14 @@ void init_ops(py::module_& m) { m.def( "swapaxes", &swapaxes, - "a"_a, - py::pos_only(), + nb::arg(), "axis1"_a, "axis2"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array - Swap two axes of an array. Args: @@ -2006,14 +1937,13 @@ void init_ops(py::module_& m) { return transpose(a, s); } }, - "a"_a, - py::pos_only(), - "axes"_a = std::nullopt, - py::kw_only(), - "stream"_a = none, + nb::arg(), + "axes"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - transpose(a: array, /, axes: Optional[List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array - Transpose the dimensions of the array. 
Args: @@ -2033,14 +1963,13 @@ void init_ops(py::module_& m) { return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "array"_a, - py::pos_only(), - "axis"_a = none, + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - sum(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - Sum reduce the array over the given axes. Args: @@ -2062,15 +1991,14 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - prod(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - An product reduction over the given axes. Args: @@ -2092,16 +2020,15 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - min(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - - An `min` reduction over the given axes. + A `min` reduction over the given axes. Args: a (array): Input array. @@ -2122,16 +2049,15 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - max(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - - An `max` reduction over the given axes. + A `max` reduction over the given axes. Args: a (array): Input array. 
@@ -2152,15 +2078,14 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - logsumexp(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - A `log-sum-exp` reduction over the given axes. The log-sum-exp reduction is a numerically stable version of: @@ -2188,15 +2113,14 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - mean(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - Compute the mean(s) over the given axes. Args: @@ -2219,16 +2143,15 @@ void init_ops(py::module_& m) { StreamOrDevice s) { return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, "ddof"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - var(a: array, /, axis: Union[None, int, List[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array - Compute the variance(s) over the given axes. Args: @@ -2257,15 +2180,14 @@ void init_ops(py::module_& m) { a, std::get>(indices_or_sections), axis, s); } }, - "a"_a, - py::pos_only(), + nb::arg(), "indices_or_sections"_a, "axis"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - split(a: array, /, indices_or_sections: Union[int, List[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> array - Split an array along a given axis. Args: @@ -2292,15 +2214,14 @@ void init_ops(py::module_& m) { return argmin(a, keepdims, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = std::nullopt, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - Indices of the minimum values along the axis. 
Args: @@ -2325,15 +2246,14 @@ void init_ops(py::module_& m) { return argmax(a, keepdims, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = std::nullopt, + nb::arg(), + "axis"_a = nb::none(), "keepdims"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - Indices of the maximum values along the axis. Args: @@ -2355,14 +2275,13 @@ void init_ops(py::module_& m) { return sort(a, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = -1, - py::kw_only(), - "stream"_a = none, + nb::arg(), + "axis"_a.none() = -1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array - Returns a sorted copy of the array. Args: @@ -2383,14 +2302,13 @@ void init_ops(py::module_& m) { return argsort(a, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = -1, - py::kw_only(), - "stream"_a = none, + nb::arg(), + "axis"_a.none() = -1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array - Returns the indices that sort the array. Args: @@ -2411,15 +2329,14 @@ void init_ops(py::module_& m) { return partition(a, kth, s); } }, - "a"_a, - py::pos_only(), + nb::arg(), "kth"_a, - "axis"_a = -1, - py::kw_only(), - "stream"_a = none, + "axis"_a.none() = -1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array - Returns a partitioned copy of the array such that the smaller ``kth`` elements are first. @@ -2447,15 +2364,14 @@ void init_ops(py::module_& m) { return argpartition(a, kth, s); } }, - "a"_a, - py::pos_only(), + nb::arg(), "kth"_a, "axis"_a = -1, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array - Returns the indices that partition the array. The ordering of the elements within a partition in given by the indices @@ -2484,15 +2400,14 @@ void init_ops(py::module_& m) { return topk(a, k, s); } }, - "a"_a, - py::pos_only(), + nb::arg(), "k"_a, - "axis"_a = -1, - py::kw_only(), - "stream"_a = none, + "axis"_a.none() = -1, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array - Returns the ``k`` largest elements from the input along a given axis. 
The elements will not necessarily be in sorted order. @@ -2512,14 +2427,13 @@ void init_ops(py::module_& m) { [](const ScalarOrArray& a, const std::vector& shape, StreamOrDevice s) { return broadcast_to(to_array(a), shape, s); }, - "a"_a, - py::pos_only(), + nb::arg(), "shape"_a, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - broadcast_to(a: Union[scalar, array], /, shape: List[int], *, stream: Union[None, Stream, Device] = None) -> array - Broadcast an array to the given shape. The broadcasting semantics are the same as Numpy. @@ -2536,14 +2450,13 @@ void init_ops(py::module_& m) { [](const array& a, const IntOrVec& axis, StreamOrDevice s) { return softmax(a, get_reduce_axes(axis, a.ndim()), s); }, - "a"_a, - py::pos_only(), - "axis"_a = none, - py::kw_only(), - "stream"_a = none, + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - softmax(a: array, /, axis: Union[None, int, List[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array - Perform the softmax along the given axis. This operation is a numerically stable version of: @@ -2572,14 +2485,13 @@ void init_ops(py::module_& m) { return concatenate(arrays, s); } }, - "arrays"_a, - py::pos_only(), - "axis"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::arg(), + "axis"_a.none() = 0, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array - Concatenate the arrays along the given axis. Args: @@ -2601,25 +2513,24 @@ void init_ops(py::module_& m) { return stack(arrays, s); } }, - "arrays"_a, - py::pos_only(), + nb::arg(), "axis"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array + Stacks the arrays along a new axis. - Stacks the arrays along a new axis. + Args: + arrays (list(array)): A list of arrays to stack. + axis (int, optional): The axis in the result array along which the + input arrays are stacked. Defaults to ``0``. + stream (Stream, optional): Stream or device. Defaults to ``None``. - Args: - arrays (list(array)): A list of arrays to stack. - axis (int, optional): The axis in the result array along which the - input arrays are stacked. Defaults to ``0``. - stream (Stream, optional): Stream or device. Defaults to ``None``. - - Returns: - array: The resulting stacked array. - )pbdoc"); + Returns: + array: The resulting stacked array. 
+ )pbdoc"); m.def( "repeat", [](const array& array, @@ -2632,28 +2543,27 @@ void init_ops(py::module_& m) { return repeat(array, repeats, s); } }, - "array"_a, - py::pos_only(), + nb::arg(), "repeats"_a, - "axis"_a = none, - py::kw_only(), - "stream"_a = none, + "axis"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array + Repeat an array along a specified axis. - Repeat an array along a specified axis. + Args: + array (array): Input array. + repeats (int): The number of repetitions for each element. + axis (int, optional): The axis in which to repeat the array along. If + unspecified it uses the flattened array of the input and repeats + along axis 0. + stream (Stream, optional): Stream or device. Defaults to ``None``. - Args: - array (array): Input array. - repeats (int): The number of repetitions for each element. - axis (int, optional): The axis in which to repeat the array along. If - unspecified it uses the flattened array of the input and repeats - along axis 0. - stream (Stream, optional): Stream or device. Defaults to ``None``. - - Returns: - array: The resulting repeated array. - )pbdoc"); + Returns: + array: The resulting repeated array. + )pbdoc"); m.def( "clip", [](const array& a, @@ -2670,29 +2580,28 @@ void init_ops(py::module_& m) { } return clip(a, min_, max_, s); }, - "a"_a, - py::pos_only(), - "a_min"_a, - "a_max"_a, - py::kw_only(), - "stream"_a = none, + nb::arg(), + "a_min"_a.none(), + "a_max"_a.none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array + Clip the values of the array between the given minimum and maximum. - Clip the values of the array between the given minimum and maximum. + If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge + is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``. + The input ``a`` and the limits must broadcast with one another. - If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge - is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``. - The input ``a`` and the limits must broadcast with one another. + Args: + a (array): Input array. + a_min (scalar or array or None): Minimum value to clip to. + a_max (scalar or array or None): Maximum value to clip to. - Args: - a (array): Input array. - a_min (scalar or array or None): Minimum value to clip to. - a_max (scalar or array or None): Maximum value to clip to. - - Returns: - array: The clipped array. - )pbdoc"); + Returns: + array: The clipped array. 
+ )pbdoc"); m.def( "pad", [](const array& a, @@ -2718,15 +2627,14 @@ void init_ops(py::module_& m) { } } }, - "a"_a, - py::pos_only(), + nb::arg(), "pad_width"_a, "constant_values"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - pad(a: array, pad_with: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array - Pad an array with a constant value Args: @@ -2762,16 +2670,15 @@ void init_ops(py::module_& m) { } return as_strided(a, a_shape, a_strides, offset, s); }, - "a"_a, - py::pos_only(), - "shape"_a = none, - "strides"_a = none, + nb::arg(), + "shape"_a = nb::none(), + "strides"_a = nb::none(), "offset"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - as_strided(a: array, /, shape: Optional[List[int]] = None, strides: Optional[List[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array - Create a view into the array with the given shape and strides. The resulting array will always be as if the provided array was row @@ -2810,16 +2717,15 @@ void init_ops(py::module_& m) { return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = std::nullopt, - py::kw_only(), + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array - Return the cumulative sum of the elements along the given axis. Args: @@ -2844,16 +2750,15 @@ void init_ops(py::module_& m) { return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = std::nullopt, - py::kw_only(), + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array - Return the cumulative product of the elements along the given axis. 
Args: @@ -2878,16 +2783,15 @@ void init_ops(py::module_& m) { return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = std::nullopt, - py::kw_only(), + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array - Return the cumulative maximum of the elements along the given axis. Args: @@ -2912,16 +2816,15 @@ void init_ops(py::module_& m) { return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, - "a"_a, - py::pos_only(), - "axis"_a = std::nullopt, - py::kw_only(), + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array - Return the cumulative minimum of the elements along the given axis. Args: @@ -2985,15 +2888,14 @@ void init_ops(py::module_& m) { return reshape(out, {-1}, s); }, - "a"_a, - "v"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "mode"_a = "full", - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + R"(def convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array)"), R"pbdoc( - convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array - The discrete convolution of 1D arrays. If ``v`` is longer than ``a``, then they are swapped. @@ -3010,18 +2912,17 @@ void init_ops(py::module_& m) { m.def( "conv1d", &conv1d, - "input"_a, - "weight"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, "groups"_a = 1, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array - 1D convolution over an input with several channels Note: Only the default ``groups=1`` is currently supported. 
@@ -3071,18 +2972,17 @@ void init_ops(py::module_& m) { return conv2d( input, weight, stride_pair, padding_pair, dilation_pair, groups, s); }, - "input"_a, - "weight"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, "groups"_a = 1, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array - 2D convolution over an input with several channels Note: Only the default ``groups=1`` is currently supported. @@ -3167,20 +3067,19 @@ void init_ops(py::module_& m) { /* bool flip = */ flip, s); }, - "input"_a, - "weight"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "stride"_a = 1, "padding"_a = 0, "kernel_dilation"_a = 1, "input_dilation"_a = 1, "groups"_a = 1, "flip"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array - General convolution over an input with several channels .. note:: @@ -3217,9 +3116,8 @@ void init_ops(py::module_& m) { &mlx_save_helper, "file"_a, "arr"_a, + nb::sig("def save(file: str, arr: array) -> None"), R"pbdoc( - save(file: str, arr: array) - Save the array to a binary file in ``.npy`` format. Args: @@ -3228,16 +3126,15 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "savez", - [](py::object file, py::args args, const py::kwargs& kwargs) { - mlx_savez_helper(file, args, kwargs, /*compressed=*/false); + [](nb::object file, nb::args args, const nb::kwargs& kwargs) { + mlx_savez_helper(file, args, kwargs, /* compressed= */ false); }, "file"_a, - py::pos_only(), - py::kw_only(), + "args"_a, + "kwargs"_a, R"pbdoc( - savez(file: str, *args, **kwargs) - - Save several arrays to a binary file in uncompressed ``.npz`` format. + Save several arrays to a binary file in uncompressed ``.npz`` + format. .. code-block:: python @@ -3258,19 +3155,17 @@ void init_ops(py::module_& m) { args (arrays): Arrays to be saved. kwargs (arrays): Arrays to be saved. Each array will be saved with the associated keyword as the output file name. 
- )pbdoc"); m.def( "savez_compressed", - [](py::object file, py::args args, const py::kwargs& kwargs) { + [](nb::object file, nb::args args, const nb::kwargs& kwargs) { mlx_savez_helper(file, args, kwargs, /*compressed=*/true); }, - "file"_a, - py::pos_only(), - py::kw_only(), + nb::arg(), + "args"_a, + "kwargs"_a, + nb::sig("def savez_compressed(file: str, *args, **kwargs)"), R"pbdoc( - savez_compressed(file: str, *args, **kwargs) - Save several arrays to a binary file in compressed ``.npz`` format. Args: @@ -3278,84 +3173,89 @@ void init_ops(py::module_& m) { args (arrays): Arrays to be saved. kwargs (arrays): Arrays to be saved. Each array will be saved with the associated keyword as the output file name. - )pbdoc"); m.def( "load", &mlx_load_helper, - "file"_a, - py::pos_only(), - "format"_a = none, + nb::arg(), + "format"_a = nb::none(), "return_metadata"_a = false, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]"), R"pbdoc( - load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] - Load array(s) from a binary file. - The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and ``.gguf``. + The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and + ``.gguf``. Args: file (file, str): File in which the array is saved. - format (str, optional): Format of the file. If ``None``, the format - is inferred from the file extension. Supported formats: ``npy``, + format (str, optional): Format of the file. If ``None``, the + format + is inferred from the file extension. Supported formats: + ``npy``, ``npz``, and ``safetensors``. Default: ``None``. - return_metadata (bool, optional): Load the metadata for formats which - support matadata. The metadata will be returned as an additional - dictionary. + return_metadata (bool, optional): Load the metadata for formats + which + support matadata. The metadata will be returned as an + additional dictionary. Returns: result (array, dict): - A single array if loading from a ``.npy`` file or a dict mapping - names to arrays if loading from a ``.npz`` or ``.safetensors`` file. - If ``return_metadata` is ``True`` an additional dictionary of metadata - will be returned. + A single array if loading from a ``.npy`` file or a dict + mapping names to arrays if loading from a ``.npz`` or + ``.safetensors`` file. If ``return_metadata` is ``True`` an + additional dictionary of metadata will be returned. Warning: - When loading unsupported quantization formats from GGUF, tensors will - automatically cast to ``mx.float16`` - + When loading unsupported quantization formats from GGUF, tensors + will automatically cast to ``mx.float16`` )pbdoc"); m.def( "save_safetensors", &mlx_save_safetensor_helper, "file"_a, "arrays"_a, - "metadata"_a = none, + "metadata"_a = nb::none(), + nb::sig( + "def save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)"), R"pbdoc( - save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None) - Save array(s) to a binary file in ``.safetensors`` format. - See the `Safetensors documentation `_ - for more information on the format. + See the `Safetensors documentation + `_ for more + information on the format. 
Args: file (file, str): File in which the array is saved. - arrays (dict(str, array)): The dictionary of names to arrays to be saved. - metadata (dict(str, str), optional): The dictionary of metadata to be saved. + arrays (dict(str, array)): The dictionary of names to arrays to + be saved. metadata (dict(str, str), optional): The dictionary of + metadata to be saved. )pbdoc"); m.def( "save_gguf", &mlx_save_gguf_helper, "file"_a, "arrays"_a, - "metadata"_a = none, + "metadata"_a = nb::none(), + nb::sig( + "def save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]])"), R"pbdoc( - save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]]) - Save array(s) to a binary file in ``.gguf`` format. - See the `GGUF documentation `_ for + See the `GGUF documentation + `_ for more information on the format. Args: file (file, str): File in which the array is saved. - arrays (dict(str, array)): The dictionary of names to arrays to be saved. - metadata (dict(str, Union[array, str, list(str)])): The dictionary of - metadata to be saved. The values can be a scalar or 1D obj:`array`, - a :obj:`str`, or a :obj:`list` of :obj:`str`. + arrays (dict(str, array)): The dictionary of names to arrays to + be saved. metadata (dict(str, Union[array, str, list(str)])): + The dictionary of + metadata to be saved. The values can be a scalar or 1D + obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`. )pbdoc"); m.def( "where", @@ -3367,18 +3267,17 @@ void init_ops(py::module_& m) { return where(to_array(condition), x, y, s); }, "condition"_a, - "x"_a, - "y"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array - Select from ``x`` or ``y`` according to ``condition``. - The condition and input arrays must be the same shape or broadcastable - with each another. + The condition and input arrays must be the same shape or + broadcastable with each another. Args: condition (array): The condition array. @@ -3386,21 +3285,21 @@ void init_ops(py::module_& m) { y (array): The input selected from where condition is ``False``. Returns: - result (array): The output containing elements selected from ``x`` and ``y``. + result (array): The output containing elements selected from + ``x`` and ``y``. )pbdoc"); m.def( "round", [](const array& a, int decimals, StreamOrDevice s) { return round(a, decimals, s); }, - "a"_a, - py::pos_only(), + nb::arg(), "decimals"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array - Round to the given number of decimals. Basically performs: @@ -3415,24 +3314,24 @@ void init_ops(py::module_& m) { decimals (int): Number of decimal places to round to. (default: 0) Returns: - result (array): An array of the same type as ``a`` rounded to the given number of decimals. + result (array): An array of the same type as ``a`` rounded to the + given number of decimals. 
)pbdoc"); m.def( "quantized_matmul", &quantized_matmul, - "x"_a, - "w"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "scales"_a, "biases"_a, "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array - Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of elements. Each element in ``w`` takes ``bits`` bits and is packed in an @@ -3457,15 +3356,14 @@ void init_ops(py::module_& m) { m.def( "quantize", &quantize, - "w"_a, - py::pos_only(), + nb::arg(), "group_size"_a = 64, "bits"_a = 4, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), R"pbdoc( - quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array] - Quantize the matrix ``w`` using ``bits`` bits per element. Note, every ``group_size`` elements in a row of ``w`` are quantized @@ -3517,17 +3415,16 @@ void init_ops(py::module_& m) { m.def( "dequantize", &dequantize, - "w"_a, - py::pos_only(), + nb::arg(), "scales"_a, "biases"_a, "group_size"_a = 64, "bits"_a = 4, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array - Dequantize the matrix ``w`` using the provided ``scales`` and ``biases`` and the ``group_size`` and ``bits`` configuration. @@ -3568,15 +3465,14 @@ void init_ops(py::module_& m) { return tensordot(a, b, x[0], x[1], s); } }, - "a"_a, - "b"_a, - py::pos_only(), + nb::arg(), + nb::arg(), "axes"_a = 2, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array - Compute the tensor dot product along the specified axes. Args: @@ -3594,14 +3490,13 @@ void init_ops(py::module_& m) { m.def( "inner", &inner, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes. 
Args: @@ -3614,14 +3509,13 @@ void init_ops(py::module_& m) { m.def( "outer", &outer, - "a"_a, - "b"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array - Compute the outer product of two 1-D arrays, if the array's passed are not 1-D a flatten op will be run beforehand. Args: @@ -3640,14 +3534,13 @@ void init_ops(py::module_& m) { return tile(a, std::get>(reps), s); } }, - "a"_a, - "reps"_a, - py::pos_only(), - py::kw_only(), - "stream"_a = none, + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array - Construct an array by repeating ``a`` the number of times given by ``reps``. Args: @@ -3660,17 +3553,16 @@ void init_ops(py::module_& m) { m.def( "addmm", &addmm, - "c"_a, - "a"_a, - "b"_a, - py::pos_only(), + nb::arg(), + nb::arg(), + nb::arg(), "alpha"_a = 1.0f, "beta"_a = 1.0f, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array - Matrix multiplication with addition and optional scaling. Perform the (possibly batched) matrix multiplication of two arrays and add to the result @@ -3694,10 +3586,10 @@ void init_ops(py::module_& m) { "offset"_a = 0, "axis1"_a = 0, "axis2"_a = 1, - "stream"_a = none, + "stream"_a = nb::none(), + nb::sig( + "def diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array - Return specified diagonals. If ``a`` is 2-D, then a 1-D array containing the diagonal at the given @@ -3723,14 +3615,13 @@ void init_ops(py::module_& m) { m.def( "diag", &diag, - "a"_a, - py::pos_only(), + nb::arg(), "k"_a = 0, - py::kw_only(), - "stream"_a = none, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array - Extract a diagonal or construct a diagonal matrix. If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the :math:`k`-th diagonal. 
If ``a`` is 2-D then the :math:`k`-th diagonal is @@ -3746,17 +3637,17 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_1d", - [](const py::args& arys, StreamOrDevice s) -> py::object { + [](const nb::args& arys, StreamOrDevice s) -> nb::object { if (arys.size() == 1) { - return py::cast(atleast_1d(arys[0].cast(), s)); + return nb::cast(atleast_1d(nb::cast(arys[0]), s)); } - return py::cast(atleast_1d(arys.cast>(), s)); + return nb::cast(atleast_1d(nb::cast>(arys), s)); }, - py::kw_only(), - "stream"_a = none, + "arys"_a, + "stream"_a = nb::none(), + nb::sig( + "def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"), R"pbdoc( - atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert all arrays to have at least one dimension. Args: @@ -3768,17 +3659,17 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_2d", - [](const py::args& arys, StreamOrDevice s) -> py::object { + [](const nb::args& arys, StreamOrDevice s) -> nb::object { if (arys.size() == 1) { - return py::cast(atleast_2d(arys[0].cast(), s)); + return nb::cast(atleast_2d(nb::cast(arys[0]), s)); } - return py::cast(atleast_2d(arys.cast>(), s)); + return nb::cast(atleast_2d(nb::cast>(arys), s)); }, - py::kw_only(), - "stream"_a = none, + "arys"_a, + "stream"_a = nb::none(), + nb::sig( + "def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"), R"pbdoc( - atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert all arrays to have at least two dimensions. Args: @@ -3788,20 +3679,19 @@ void init_ops(py::module_& m) { Returns: array or list(array): An array or list of arrays with at least two dimensions. )pbdoc"); - m.def( "atleast_3d", - [](const py::args& arys, StreamOrDevice s) -> py::object { + [](const nb::args& arys, StreamOrDevice s) -> nb::object { if (arys.size() == 1) { - return py::cast(atleast_3d(arys[0].cast(), s)); + return nb::cast(atleast_3d(nb::cast(arys[0]), s)); } - return py::cast(atleast_3d(arys.cast>(), s)); + return nb::cast(atleast_3d(nb::cast>(arys), s)); }, - py::kw_only(), - "stream"_a = none, + "arys"_a, + "stream"_a = nb::none(), + nb::sig( + "def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"), R"pbdoc( - atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert all arrays to have at least three dimensions. Args: diff --git a/python/src/pybind11_numpy_fp16.h b/python/src/pybind11_numpy_fp16.h deleted file mode 100644 index ed496524f..000000000 --- a/python/src/pybind11_numpy_fp16.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. -#pragma once - -// A patch to get float16_t to work with pybind11 numpy arrays -// Derived from: -// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679 - -#include - -namespace pybind11::detail { - -template -struct npy_scalar_caster { - PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); - using Array = array_t; - - bool load(handle src, bool convert) { - // Taken from Eigen casters. Permits either scalar dtype or scalar array. - handle type = dtype::of().attr("type"); // Could make more efficient. 
- if (!convert && !isinstance(src) && !isinstance(src, type)) - return false; - Array tmp = Array::ensure(src); - if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { - this->value = *tmp.data(); - return true; - } - return false; - } - - static handle cast(T src, return_value_policy, handle) { - Array tmp({1}); - tmp.mutable_at(0) = src; - tmp.resize({}); - // You could also just return the array if you want a scalar array. - object scalar = tmp[tuple()]; - return scalar.release(); - } -}; - -// Similar to enums in `pybind11/numpy.h`. Determined by doing: -// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' -constexpr int NPY_FLOAT16 = 23; - -// Kinda following: -// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000 -template <> -struct npy_format_descriptor { - static constexpr auto name = _("float16"); - static pybind11::dtype dtype() { - handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); - return reinterpret_borrow(ptr); - } -}; - -template <> -struct type_caster : npy_scalar_caster { - static constexpr auto name = _("float16"); -}; - -} // namespace pybind11::detail diff --git a/python/src/random.cpp b/python/src/random.cpp index 442d81fee..b3cb3aa2f 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -1,7 +1,10 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. + +#include +#include +#include +#include -#include -#include #include #include "python/src/utils.h" @@ -9,8 +12,8 @@ #include "mlx/ops.h" #include "mlx/random.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; using namespace mlx::core::random; @@ -25,22 +28,22 @@ class PyKeySequence { } array next() { - auto out = split(py::cast(state_[0])); + auto out = split(nb::cast(state_[0])); state_[0] = out.first; return out.second; } - py::list state() { + nb::list state() { return state_; } void release() { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; state_.release().dec_ref(); } private: - py::list state_; + nb::list state_; }; PyKeySequence& default_key() { @@ -54,7 +57,7 @@ PyKeySequence& default_key() { return ks; } -void init_random(py::module_& parent_module) { +void init_random(nb::module_& parent_module) { auto m = parent_module.def_submodule( "random", "mlx.core.random: functionality related to random number generation"); @@ -85,10 +88,10 @@ void init_random(py::module_& parent_module) { )pbdoc"); m.def( "split", - py::overload_cast(&random::split), + nb::overload_cast(&random::split), "key"_a, "num"_a = 2, - "stream"_a = none, + "stream"_a = nb::none(), R"pbdoc( Split a PRNG key into sub keys. @@ -119,9 +122,9 @@ void init_random(py::module_& parent_module) { "low"_a = 0, "high"_a = 1, "shape"_a = std::vector{}, - "dtype"_a = std::optional{float32}, - "key"_a = none, - "stream"_a = none, + "dtype"_a.none() = float32, + "key"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( Generate uniformly distributed random numbers. @@ -151,11 +154,11 @@ void init_random(py::module_& parent_module) { return normal(shape, type.value_or(float32), loc, scale, key, s); }, "shape"_a = std::vector{}, - "dtype"_a = std::optional{float32}, + "dtype"_a.none() = float32, "loc"_a = 0.0, "scale"_a = 1.0, - "key"_a = none, - "stream"_a = none, + "key"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( Generate normally distributed random numbers. 
@@ -184,9 +187,9 @@ void init_random(py::module_& parent_module) { "low"_a, "high"_a, "shape"_a = std::vector{}, - "dtype"_a = int32, - "key"_a = none, - "stream"_a = none, + "dtype"_a.none() = int32, + "key"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( Generate random integers from the given interval. @@ -219,9 +222,9 @@ void init_random(py::module_& parent_module) { } }, "p"_a = 0.5, - "shape"_a = none, - "key"_a = none, - "stream"_a = none, + "shape"_a = nb::none(), + "key"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( Generate Bernoulli random values. @@ -259,10 +262,10 @@ void init_random(py::module_& parent_module) { }, "lower"_a, "upper"_a, - "shape"_a = none, - "dtype"_a = std::optional{float32}, - "key"_a = none, - "stream"_a = none, + "shape"_a = nb::none(), + "dtype"_a.none() = float32, + "key"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( Generate values from a truncated normal distribution. @@ -292,9 +295,9 @@ void init_random(py::module_& parent_module) { return gumbel(shape, type.value_or(float32), key, s); }, "shape"_a = std::vector{}, - "dtype"_a = std::optional{float32}, - "stream"_a = none, - "key"_a = none, + "dtype"_a.none() = float32, + "stream"_a = nb::none(), + "key"_a = nb::none(), R"pbdoc( Sample from the standard Gumbel distribution. @@ -331,10 +334,10 @@ void init_random(py::module_& parent_module) { }, "logits"_a, "axis"_a = -1, - "shape"_a = none, - "num_samples"_a = none, - "key"_a = none, - "stream"_a = none, + "shape"_a = nb::none(), + "num_samples"_a = nb::none(), + "key"_a = nb::none(), + "stream"_a = nb::none(), R"pbdoc( Sample from a categorical distribution. @@ -359,6 +362,6 @@ void init_random(py::module_& parent_module) { array: The ``shape``-sized output array with type ``uint32``. )pbdoc"); // Register static Python object cleanup before the interpreter exits - auto atexit = py::module_::import("atexit"); - atexit.attr("register")(py::cpp_function([]() { default_key().release(); })); + auto atexit = nb::module_::import_("atexit"); + atexit.attr("register")(nb::cpp_function([]() { default_key().release(); })); } diff --git a/python/src/stream.cpp b/python/src/stream.cpp index 768795fc1..c83d9b447 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -1,25 +1,54 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include -#include +#include +#include +#include #include "mlx/stream.h" #include "mlx/utils.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; -void init_stream(py::module_& m) { - py::class_( +// Create the StreamContext on enter and delete on exit. +class PyStreamContext { + public: + PyStreamContext(StreamOrDevice s) : _inner(nullptr) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + _s = s; + } + + void enter() { + _inner = new StreamContext(_s); + } + + void exit() { + if (_inner != nullptr) { + delete _inner; + _inner = nullptr; + } + } + + private: + StreamOrDevice _s; + StreamContext* _inner; +}; + +void init_stream(nb::module_& m) { + nb::class_( m, "Stream", R"pbdoc( A stream for running operations on a given device. 
)pbdoc") - .def(py::init(), "index"_a, "device"_a) - .def_readonly("device", &Stream::device) + .def(nb::init(), "index"_a, "device"_a) + .def_ro("device", &Stream::device) .def( "__repr__", [](const Stream& s) { @@ -31,7 +60,7 @@ void init_stream(py::module_& m) { return s1 == s2; }); - py::implicitly_convertible(); + nb::implicitly_convertible(); m.def( "default_stream", @@ -56,4 +85,48 @@ void init_stream(py::module_& m) { &new_stream, "device"_a, R"pbdoc(Make a new stream on the given device.)pbdoc"); + + nb::class_(m, "StreamContext", R"pbdoc( + A context manager for setting the current device and stream. + + See :func:`stream` for usage. + + Args: + s: The stream or device to set as the default. + )pbdoc") + .def(nb::init(), "s"_a) + .def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) + .def( + "__exit__", + [](PyStreamContext& scm, + const std::optional& exc_type, + const std::optional& exc_value, + const std::optional& traceback) { scm.exit(); }, + "exc_type"_a = nb::none(), + "exc_value"_a = nb::none(), + "traceback"_a = nb::none()); + m.def( + "stream", + [](StreamOrDevice s) { return PyStreamContext(s); }, + "s"_a, + R"pbdoc( + Create a context manager to set the default device and stream. + + Args: + s: The :obj:`Stream` or :obj:`Device` to set as the default. + + Returns: + A context manager that sets the default device and stream. + + Example: + + .. code-block::python + + import mlx.core as mx + + # Create a context manager for the default device and stream. + with mx.stream(mx.cpu): + # Operations here will use mx.cpu by default. + pass + )pbdoc"); } diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 1612f774d..215c441de 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1,6 +1,11 @@ // Copyright © 2023-2024 Apple Inc. 
-#include -#include +#include +#include +#include +#include +#include +#include + #include #include #include @@ -13,13 +18,17 @@ #include "mlx/transforms_impl.h" #include "python/src/trees.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlx::core; using IntOrVec = std::variant>; using StrOrVec = std::variant>; +inline std::string type_name_str(const nb::handle& o) { + return nb::cast(nb::type_name(o.type())); +} + template std::vector to_vector(const std::variant>& v) { std::vector vals; @@ -49,7 +58,7 @@ auto validate_argnums_argnames( } auto py_value_and_grad( - const py::function& fun, + const nb::callable& fun, std::vector argnums, std::vector argnames, const std::string& error_msg_tag, @@ -71,7 +80,7 @@ auto py_value_and_grad( } return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( - const py::args& args, const py::kwargs& kwargs) { + const nb::args& args, const nb::kwargs& kwargs) { // Sanitize the input if (argnums.size() > 0 && argnums.back() >= args.size()) { std::ostringstream msg; @@ -89,7 +98,7 @@ auto py_value_and_grad( << "' because the function is called with the " << "following keyword arguments {"; for (auto item : kwargs) { - msg << item.first.cast() << ","; + msg << nb::cast(item.first) << ","; } msg << "}"; throw std::invalid_argument(msg.str()); @@ -115,7 +124,7 @@ auto py_value_and_grad( // value_out will hold the output of the python function in order to be // able to reconstruct the python tree of extra return values - py::object py_value_out; + nb::object py_value_out; auto value_and_grads = value_and_grad( [&fun, &args, @@ -127,15 +136,15 @@ auto py_value_and_grad( &error_msg_tag, scalar_func_only](const std::vector& a) { // Copy the arguments - py::args args_cpy = py::tuple(args.size()); - py::kwargs kwargs_cpy = py::kwargs(); + nb::list args_cpy; + nb::kwargs kwargs_cpy = nb::kwargs(); int j = 0; for (int i = 0; i < args.size(); ++i) { if (j < argnums.size() && i == argnums[j]) { - args_cpy[i] = tree_unflatten(args[i], a, counts[j]); + args_cpy.append(tree_unflatten(args[i], a, counts[j])); j++; } else { - args_cpy[i] = args[i]; + args_cpy.append(args[i]); } } for (auto& key : argnames) { @@ -154,25 +163,25 @@ auto py_value_and_grad( py_value_out = fun(*args_cpy, **kwargs_cpy); // Validate the return value of the python function - if (!py::isinstance(py_value_out)) { + if (!nb::isinstance(py_value_out)) { if (scalar_func_only) { std::ostringstream msg; msg << error_msg_tag << " The return value of the function " << "whose gradient we want to compute should be a " - << "scalar array; but " << py_value_out.get_type() + << "scalar array; but " << type_name_str(py_value_out) << " was returned."; throw std::invalid_argument(msg.str()); } - if (!py::isinstance(py_value_out)) { + if (!nb::isinstance(py_value_out)) { std::ostringstream msg; msg << error_msg_tag << " The return value of the function " << "whose gradient we want to compute should be either a " << "scalar array or a tuple with the first value being a " << "scalar array (Union[array, Tuple[array, Any, ...]]); but " - << py_value_out.get_type() << " was returned."; + << type_name_str(py_value_out) << " was returned."; throw std::invalid_argument(msg.str()); } - py::tuple ret = py::cast(py_value_out); + nb::tuple ret = nb::cast(py_value_out); if (ret.size() == 0) { std::ostringstream msg; msg << error_msg_tag << " The return value of the function " @@ -182,14 +191,14 @@ auto py_value_and_grad( << "we got an empty 
tuple."; throw std::invalid_argument(msg.str()); } - if (!py::isinstance(ret[0])) { + if (!nb::isinstance(ret[0])) { std::ostringstream msg; msg << error_msg_tag << " The return value of the function " << "whose gradient we want to compute should be either a " << "scalar array or a tuple with the first value being a " << "scalar array (Union[array, Tuple[array, Any, ...]]); but it " << "was a tuple with the first value being of type " - << ret[0].get_type() << " ."; + << type_name_str(ret[0]) << " ."; throw std::invalid_argument(msg.str()); } } @@ -212,61 +221,60 @@ auto py_value_and_grad( // In case 2 we return a tuple of the above. // In case 3 we return a tuple containing a tuple and dict (sth like // (tuple(), dict(x=mx.array(5))) ). - py::object positional_grads; - py::object keyword_grads; - py::object py_grads; + nb::object positional_grads; + nb::object keyword_grads; + nb::object py_grads; // Collect the gradients for the positional arguments if (argnums.size() == 1) { positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]); } else if (argnums.size() > 1) { - py::tuple grads_(argnums.size()); + nb::list grads_; for (int i = 0; i < argnums.size(); i++) { - grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]); + grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i])); } - positional_grads = py::cast(grads_); + positional_grads = nb::tuple(grads_); } else { - positional_grads = py::none(); + positional_grads = nb::none(); } // No keyword argument gradients so return the tuple of gradients if (argnames.size() == 0) { py_grads = positional_grads; } else { - py::dict grads_; + nb::dict grads_; for (int i = 0; i < argnames.size(); i++) { auto& k = argnames[i]; grads_[k.c_str()] = tree_unflatten( kwargs[k.c_str()], gradients, counts[i + argnums.size()]); } - keyword_grads = py::cast(grads_); + keyword_grads = grads_; - py_grads = - py::cast(py::make_tuple(positional_grads, keyword_grads)); + py_grads = nb::make_tuple(positional_grads, keyword_grads); } // Put the values back in the container - py::object return_value = tree_unflatten(py_value_out, value); + nb::object return_value = tree_unflatten(py_value_out, value); return std::make_pair(return_value, py_grads); }; } auto py_vmap( - const py::function& fun, - const py::object& in_axes, - const py::object& out_axes) { - return [fun, in_axes, out_axes](const py::args& args) { - auto axes_to_flat_tree = [](const py::object& tree, - const py::object& axes) { + const nb::callable& fun, + const nb::object& in_axes, + const nb::object& out_axes) { + return [fun, in_axes, out_axes](const nb::args& args) { + auto axes_to_flat_tree = [](const nb::object& tree, + const nb::object& axes) { auto tree_axes = tree_map( {tree, axes}, - [](const std::vector& inputs) { return inputs[1]; }); + [](const std::vector& inputs) { return inputs[1]; }); std::vector flat_axes; - tree_visit(tree_axes, [&flat_axes](py::handle obj) { + tree_visit(tree_axes, [&flat_axes](nb::handle obj) { if (obj.is_none()) { flat_axes.push_back(-1); - } else if (py::isinstance(obj)) { - flat_axes.push_back(py::cast(py::cast(obj))); + } else if (nb::isinstance(obj)) { + flat_axes.push_back(nb::cast(nb::cast(obj))); } else { throw std::invalid_argument("[vmap] axis must be int or None."); } @@ -280,7 +288,7 @@ auto py_vmap( // py_value_out will hold the output of the python function in order to be // able to reconstruct the python tree of extra return values - py::object py_outputs; + nb::object py_outputs; auto vmap_fn = [&fun, &args, 
&inputs, &py_outputs](const std::vector& a) { @@ -305,24 +313,24 @@ auto py_vmap( }; } -std::unordered_map& tree_cache() { +std::unordered_map& tree_cache() { // This map is used to Cache the tree structure of the outputs - static std::unordered_map tree_cache_; + static std::unordered_map tree_cache_; return tree_cache_; } struct PyCompiledFun { - py::function fun; + nb::callable fun; size_t fun_id; - py::object captured_inputs; - py::object captured_outputs; + nb::object captured_inputs; + nb::object captured_outputs; bool shapeless; - size_t num_outputs{0}; + mutable size_t num_outputs{0}; PyCompiledFun( - const py::function& fun, - py::object inputs, - py::object outputs, + const nb::callable& fun, + nb::object inputs, + nb::object outputs, bool shapeless) : fun(fun), fun_id(reinterpret_cast(fun.ptr())), @@ -342,7 +350,7 @@ struct PyCompiledFun { num_outputs = other.num_outputs; }; - py::object operator()(const py::args& args, const py::kwargs& kwargs) { + nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { // Flat array inputs std::vector inputs; @@ -358,45 +366,45 @@ struct PyCompiledFun { constexpr uint64_t dict_identifier = 18446744073709551521UL; // Flatten the tree with hashed constants and structure - std::function recurse; - recurse = [&](py::handle obj) { - if (py::isinstance(obj)) { - auto l = py::cast(obj); + std::function recurse; + recurse = [&](nb::handle obj) { + if (nb::isinstance(obj)) { + auto l = nb::cast(obj); constants.push_back(list_identifier); for (int i = 0; i < l.size(); ++i) { recurse(l[i]); } - } else if (py::isinstance(obj)) { - auto l = py::cast(obj); + } else if (nb::isinstance(obj)) { + auto l = nb::cast(obj); constants.push_back(list_identifier); for (auto item : obj) { recurse(item); } - } else if (py::isinstance(obj)) { - auto d = py::cast(obj); + } else if (nb::isinstance(obj)) { + auto d = nb::cast(obj); constants.push_back(dict_identifier); for (auto item : d) { - auto r = py::hash(item.first); + auto r = item.first.attr("__hash__"); constants.push_back(*reinterpret_cast(&r)); recurse(item.second); } - } else if (py::isinstance(obj)) { - inputs.push_back(py::cast(obj)); + } else if (nb::isinstance(obj)) { + inputs.push_back(nb::cast(obj)); constants.push_back(array_identifier); - } else if (py::isinstance(obj)) { - auto r = py::hash(obj); + } else if (nb::isinstance(obj)) { + auto r = obj.attr("__hash__"); constants.push_back(*reinterpret_cast(&r)); - } else if (py::isinstance(obj)) { - auto r = obj.cast(); + } else if (nb::isinstance(obj)) { + auto r = nb::cast(obj); constants.push_back(*reinterpret_cast(&r)); - } else if (py::isinstance(obj)) { - auto r = obj.cast(); + } else if (nb::isinstance(obj)) { + auto r = nb::cast(obj); constants.push_back(*reinterpret_cast(&r)); } else { std::ostringstream msg; msg << "[compile] Function arguments must be trees of arrays " << "or constants (floats, ints, or strings), but received " - << "type " << obj.get_type() << "."; + << "type " << type_name_str(obj) << "."; throw std::invalid_argument(msg.str()); } }; @@ -404,13 +412,12 @@ struct PyCompiledFun { recurse(args); int num_args = inputs.size(); recurse(kwargs); - auto compile_fun = [this, &args, &kwargs, num_args]( const std::vector& a) { // Put tracers into captured inputs std::vector flat_in_captures; std::vector trace_captures; - if (!py::isinstance(captured_inputs)) { + if (!captured_inputs.is_none()) { flat_in_captures = tree_flatten(captured_inputs, false); trace_captures.insert( trace_captures.end(), a.end() - 
flat_in_captures.size(), a.end()); @@ -425,7 +432,7 @@ struct PyCompiledFun { tree_cache().insert({fun_id, py_outputs}); num_outputs = outputs.size(); - if (!py::isinstance(captured_outputs)) { + if (!captured_outputs.is_none()) { auto flat_out_captures = tree_flatten(captured_outputs, false); outputs.insert( outputs.end(), @@ -434,13 +441,13 @@ struct PyCompiledFun { } // Replace tracers with originals in captured inputs - if (!py::isinstance(captured_inputs)) { + if (!captured_inputs.is_none()) { tree_replace(captured_inputs, trace_captures, flat_in_captures); } return outputs; }; - if (!py::isinstance(captured_inputs)) { + if (!captured_inputs.is_none()) { auto flat_in_captures = tree_flatten(captured_inputs, false); inputs.insert( inputs.end(), @@ -451,7 +458,7 @@ struct PyCompiledFun { // Compile and call auto outputs = detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); - if (!py::isinstance(captured_outputs)) { + if (!captured_outputs.is_none()) { std::vector captures( std::make_move_iterator(outputs.begin() + num_outputs), std::make_move_iterator(outputs.end())); @@ -459,12 +466,16 @@ struct PyCompiledFun { } // Put the outputs back in the container - py::object py_outputs = tree_cache().at(fun_id); + nb::object py_outputs = tree_cache().at(fun_id); return tree_unflatten_from_structure(py_outputs, outputs); + } + + nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const { + return const_cast(this)->call_impl(args, kwargs); }; ~PyCompiledFun() { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; tree_cache().erase(fun_id); detail::compile_erase(fun_id); @@ -476,35 +487,35 @@ struct PyCompiledFun { class PyCheckpointedFun { public: - PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {} + PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {} ~PyCheckpointedFun() { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; fun_.release().dec_ref(); } struct InnerFunction { - py::object fun_; - py::object args_structure_; - std::weak_ptr output_structure_; + nb::object fun_; + nb::object args_structure_; + std::weak_ptr output_structure_; InnerFunction( - py::object fun, - py::object args_structure, - std::weak_ptr output_structure) + nb::object fun, + nb::object args_structure, + std::weak_ptr output_structure) : fun_(std::move(fun)), args_structure_(std::move(args_structure)), output_structure_(output_structure) {} ~InnerFunction() { - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; fun_.release().dec_ref(); args_structure_.release().dec_ref(); } std::vector operator()(const std::vector& inputs) { - auto args = py::cast( + auto args = nb::cast( tree_unflatten_from_structure(args_structure_, inputs)); auto [outputs, output_structure] = tree_flatten_with_structure(fun_(*args[0], **args[1]), false); @@ -515,9 +526,9 @@ class PyCheckpointedFun { } }; - py::object operator()(const py::args& args, const py::kwargs& kwargs) { - auto output_structure = std::make_shared(); - auto full_args = py::make_tuple(args, kwargs); + nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { + auto output_structure = std::make_shared(); + auto full_args = nb::make_tuple(args, kwargs); auto [inputs, args_structure] = tree_flatten_with_structure(full_args, false); @@ -527,26 +538,27 @@ class PyCheckpointedFun { return tree_unflatten_from_structure(*output_structure, outputs); } + nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const { + return const_cast(this)->call_impl(args, kwargs); + } + private: 
- py::function fun_; + nb::callable fun_; }; -void init_transforms(py::module_& m) { - py::options options; - options.disable_function_signatures(); - +void init_transforms(nb::module_& m) { m.def( "eval", - [](const py::args& args) { + [](const nb::args& args) { std::vector arrays = tree_flatten(args, false); { - py::gil_scoped_release nogil; + nb::gil_scoped_release nogil; eval(arrays); } }, + nb::arg(), + nb::sig("def eval(*args) -> None"), R"pbdoc( - eval(*args) -> None - Evaluate an :class:`array` or tree of :class:`array`. Args: @@ -557,19 +569,15 @@ void init_transforms(py::module_& m) { )pbdoc"); m.def( "jvp", - [](const py::function& fun, + [](const nb::callable& fun, const std::vector& primals, const std::vector& tangents) { auto vfun = [&fun](const std::vector& primals) { - py::args args = py::tuple(primals.size()); - for (int i = 0; i < primals.size(); ++i) { - args[i] = primals[i]; - } - auto out = fun(*args); - if (py::isinstance(out)) { - return std::vector{py::cast(out)}; + auto out = fun(*nb::cast(primals)); + if (nb::isinstance(out)) { + return std::vector{nb::cast(out)}; } else { - return py::cast>(out); + return nb::cast>(out); } }; return jvp(vfun, primals, tangents); @@ -577,17 +585,16 @@ void init_transforms(py::module_& m) { "fun"_a, "primals"_a, "tangents"_a, + nb::sig( + "def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"), R"pbdoc( - jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]] - - Compute the Jacobian-vector product. This computes the product of the Jacobian of a function ``fun`` evaluated at ``primals`` with the ``tangents``. Args: - fun (function): A function which takes a variable number of :class:`array` + fun (callable): A function which takes a variable number of :class:`array` and returns a single :class:`array` or list of :class:`array`. primals (list(array)): A list of :class:`array` at which to evaluate the Jacobian. @@ -601,19 +608,15 @@ void init_transforms(py::module_& m) { )pbdoc"); m.def( "vjp", - [](const py::function& fun, + [](const nb::callable& fun, const std::vector& primals, const std::vector& cotangents) { auto vfun = [&fun](const std::vector& primals) { - py::args args = py::tuple(primals.size()); - for (int i = 0; i < primals.size(); ++i) { - args[i] = primals[i]; - } - auto out = fun(*args); - if (py::isinstance(out)) { - return std::vector{py::cast(out)}; + auto out = fun(*nb::cast(primals)); + if (nb::isinstance(out)) { + return std::vector{nb::cast(out)}; } else { - return py::cast>(out); + return nb::cast>(out); } }; return vjp(vfun, primals, cotangents); @@ -621,16 +624,16 @@ void init_transforms(py::module_& m) { "fun"_a, "primals"_a, "cotangents"_a, + nb::sig( + "def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"), R"pbdoc( - vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]] - Compute the vector-Jacobian product. Computes the product of the ``cotangents`` with the Jacobian of a function ``fun`` evaluated at ``primals``. Args: - fun (function): A function which takes a variable number of :class:`array` + fun (callable): A function which takes a variable number of :class:`array` and returns a single :class:`array` or list of :class:`array`. primals (list(array)): A list of :class:`array` at which to evaluate the Jacobian. 
@@ -644,20 +647,20 @@ void init_transforms(py::module_& m) { )pbdoc"); m.def( "value_and_grad", - [](const py::function& fun, + [](const nb::callable& fun, const std::optional& argnums, const StrOrVec& argnames) { auto [argnums_vec, argnames_vec] = validate_argnums_argnames(argnums, argnames); - return py::cpp_function(py_value_and_grad( + return nb::cpp_function(py_value_and_grad( fun, argnums_vec, argnames_vec, "[value_and_grad]", false)); }, "fun"_a, - "argnums"_a = std::nullopt, + "argnums"_a = nb::none(), "argnames"_a = std::vector{}, + nb::sig( + "def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"), R"pbdoc( - value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function - Returns a function which computes the value and gradient of ``fun``. The function passed to :func:`value_and_grad` should return either @@ -688,7 +691,7 @@ void init_transforms(py::module_& m) { (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) Args: - fun (function): A function which takes a variable number of + fun (callable): A function which takes a variable number of :class:`array` or trees of :class:`array` and returns a scalar output :class:`array` or a tuple the first element of which should be a scalar :class:`array`. @@ -702,34 +705,34 @@ void init_transforms(py::module_& m) { no gradients for keyword arguments by default. Returns: - function: A function which returns a tuple where the first element + callable: A function which returns a tuple where the first element is the output of `fun` and the second element is the gradients w.r.t. the loss. )pbdoc"); m.def( "grad", - [](const py::function& fun, + [](const nb::callable& fun, const std::optional& argnums, const StrOrVec& argnames) { auto [argnums_vec, argnames_vec] = validate_argnums_argnames(argnums, argnames); auto fn = py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true); - return py::cpp_function( - [fn](const py::args& args, const py::kwargs& kwargs) { + return nb::cpp_function( + [fn](const nb::args& args, const nb::kwargs& kwargs) { return fn(args, kwargs).second; }); }, "fun"_a, - "argnums"_a = std::nullopt, + "argnums"_a = nb::none(), "argnames"_a = std::vector{}, + nb::sig( + "def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"), R"pbdoc( - grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function - Returns a function which computes the gradient of ``fun``. Args: - fun (function): A function which takes a variable number of + fun (callable): A function which takes a variable number of :class:`array` or trees of :class:`array` and returns a scalar output :class:`array`. argnums (int or list(int), optional): Specify the index (or indices) @@ -742,26 +745,26 @@ void init_transforms(py::module_& m) { no gradients for keyword arguments by default. Returns: - function: A function which has the same input arguments as ``fun`` and + callable: A function which has the same input arguments as ``fun`` and returns the gradient(s). 
)pbdoc"); m.def( "vmap", - [](const py::function& fun, - const py::object& in_axes, - const py::object& out_axes) { - return py::cpp_function(py_vmap(fun, in_axes, out_axes)); + [](const nb::callable& fun, + const nb::object& in_axes, + const nb::object& out_axes) { + return nb::cpp_function(py_vmap(fun, in_axes, out_axes)); }, "fun"_a, "in_axes"_a = 0, "out_axes"_a = 0, + nb::sig( + "def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"), R"pbdoc( - vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function - Returns a vectorized version of ``fun``. Args: - fun (function): A function which takes a variable number of + fun (callable): A function which takes a variable number of :class:`array` or a tree of :class:`array` and returns a variable number of :class:`array` or a tree of :class:`array`. in_axes (int, optional): An integer or a valid prefix tree of the @@ -774,16 +777,16 @@ void init_transforms(py::module_& m) { Defaults to ``0``. Returns: - function: The vectorized function. + callable: The vectorized function. )pbdoc"); m.def( "export_to_dot", - [](py::object file, const py::args& args) { + [](nb::object file, const nb::args& args) { std::vector arrays = tree_flatten(args); - if (py::isinstance(file)) { - std::ofstream out(py::cast(file)); + if (nb::isinstance(file)) { + std::ofstream out(nb::cast(file)); export_to_dot(out, arrays); - } else if (py::hasattr(file, "write")) { + } else if (nb::hasattr(file, "write")) { std::ostringstream out; export_to_dot(out, arrays); auto write = file.attr("write"); @@ -793,57 +796,50 @@ void init_transforms(py::module_& m) { "export_to_dot accepts file-like objects or strings to be used as filenames"); } }, - "file"_a); + "file"_a, + "args"_a); m.def( "compile", - [](const py::function& fun, - const py::object& inputs, - const py::object& outputs, + [](const nb::callable& fun, + const nb::object& inputs, + const nb::object& outputs, bool shapeless) { - py::options options; - options.disable_function_signatures(); - - std::ostringstream doc; - auto name = fun.attr("__name__").cast(); - doc << name; + // Try to get the name + auto n = fun.attr("__name__"); + auto name = n.is_none() ? "compiled" : nb::cast(n); // Try to get the signature - auto inspect = py::module::import("inspect"); - if (!inspect.attr("isbuiltin")(fun).cast()) { - doc << inspect.attr("signature")(fun) - .attr("__str__")() - .cast(); + std::ostringstream sig; + sig << "def " << name; + auto inspect = nb::module_::import_("inspect"); + if (nb::cast(inspect.attr("isroutine")(fun))) { + sig << nb::cast( + inspect.attr("signature")(fun).attr("__str__")()); + } else { + sig << "(*args, **kwargs)"; } // Try to get the doc string - if (auto d = fun.attr("__doc__"); py::isinstance(d)) { - doc << "\n\n"; - auto dstr = d.cast(); - // Add spaces to match first line indentation with remainder of - // docstring - int i = 0; - for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) { - doc << ' '; - } - doc << dstr; - } - auto doc_str = doc.str(); - return py::cpp_function( + auto d = inspect.attr("getdoc")(fun); + std::string doc = + d.is_none() ? "MLX compiled function." 
: nb::cast(d); + + auto sig_str = sig.str(); + return nb::cpp_function( PyCompiledFun{fun, inputs, outputs, shapeless}, - py::name(name.c_str()), - py::doc(doc_str.c_str())); + nb::name(name.c_str()), + nb::sig(sig_str.c_str()), + doc.c_str()); }, "fun"_a, - "inputs"_a = std::nullopt, - "outputs"_a = std::nullopt, + "inputs"_a = nb::none(), + "outputs"_a = nb::none(), "shapeless"_a = false, R"pbdoc( - compile(fun: function) -> function - Returns a compiled function which produces the same output as ``fun``. Args: - fun (function): A function which takes a variable number of + fun (callable): A function which takes a variable number of :class:`array` or trees of :class:`array` and returns a variable number of :class:`array` or trees of :class:`array`. inputs (list or dict, optional): These inputs will be captured during @@ -864,15 +860,13 @@ void init_transforms(py::module_& m) { ``shapeless`` set to ``True``. Default: ``False`` Returns: - function: A compiled function which has the same input arguments + callable: A compiled function which has the same input arguments as ``fun`` and returns the the same output(s). )pbdoc"); m.def( "disable_compile", &disable_compile, R"pbdoc( - disable_compile() -> None - Globally disable compilation. Setting the environment variable ``MLX_DISABLE_COMPILE`` can also be used to disable compilation. )pbdoc"); @@ -880,17 +874,15 @@ void init_transforms(py::module_& m) { "enable_compile", &enable_compile, R"pbdoc( - enable_compile() -> None - Globally enable compilation. This will override the environment variable ``MLX_DISABLE_COMPILE`` if set. )pbdoc"); m.def( "checkpoint", - [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); }, + [](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); }, "fun"_a); // Register static Python object cleanup before the interpreter exits - auto atexit = py::module_::import("atexit"); - atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); })); + auto atexit = nb::module_::import_("atexit"); + atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); })); } diff --git a/python/src/trees.cpp b/python/src/trees.cpp index bd2c3f975..29fe9d4bb 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -2,16 +2,16 @@ #include "python/src/trees.h" -void tree_visit(py::object tree, std::function visitor) { - std::function recurse; - recurse = [&](py::handle subtree) { - if (py::isinstance(subtree) || - py::isinstance(subtree)) { +void tree_visit(nb::object tree, std::function visitor) { + std::function recurse; + recurse = [&](nb::handle subtree) { + if (nb::isinstance(subtree) || + nb::isinstance(subtree)) { for (auto item : subtree) { recurse(item); } - } else if (py::isinstance(subtree)) { - for (auto item : py::cast(subtree)) { + } else if (nb::isinstance(subtree)) { + for (auto item : nb::cast(subtree)) { recurse(item.second); } } else { @@ -23,63 +23,63 @@ void tree_visit(py::object tree, std::function visitor) { } template -void validate_subtrees(const std::vector& subtrees) { - int len = py::cast(subtrees[0]).size(); +void validate_subtrees(const std::vector& subtrees) { + int len = nb::cast(subtrees[0]).size(); for (auto& subtree : subtrees) { - if ((py::isinstance(subtree) && py::cast(subtree).size() != len) || - py::isinstance(subtree) || py::isinstance(subtree)) { + if ((nb::isinstance(subtree) && nb::cast(subtree).size() != len) || + nb::isinstance(subtree) || nb::isinstance(subtree)) { throw std::invalid_argument( "[tree_map] Additional input tree is not 
a valid prefix of the first tree."); } } } -py::object tree_map( - const std::vector& trees, - std::function&)> transform) { - std::function&)> recurse; +nb::object tree_map( + const std::vector& trees, + std::function&)> transform) { + std::function&)> recurse; - recurse = [&](const std::vector& subtrees) { - if (py::isinstance(subtrees[0])) { - py::list l; - std::vector items(subtrees.size()); - validate_subtrees(subtrees); - for (int i = 0; i < py::cast(subtrees[0]).size(); ++i) { + recurse = [&](const std::vector& subtrees) { + if (nb::isinstance(subtrees[0])) { + nb::list l; + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + for (int i = 0; i < nb::cast(subtrees[0]).size(); ++i) { for (int j = 0; j < subtrees.size(); ++j) { - if (py::isinstance(subtrees[j])) { - items[j] = py::cast(subtrees[j])[i]; + if (nb::isinstance(subtrees[j])) { + items[j] = nb::cast(subtrees[j])[i]; } else { items[j] = subtrees[j]; } } l.append(recurse(items)); } - return py::cast(l); - } else if (py::isinstance(subtrees[0])) { + return nb::cast(l); + } else if (nb::isinstance(subtrees[0])) { // Check the rest of the subtrees - std::vector items(subtrees.size()); - int len = py::cast(subtrees[0]).size(); - py::tuple l(len); - validate_subtrees(subtrees); + std::vector items(subtrees.size()); + int len = nb::cast(subtrees[0]).size(); + nb::list l; + validate_subtrees(subtrees); for (int i = 0; i < len; ++i) { for (int j = 0; j < subtrees.size(); ++j) { - if (py::isinstance(subtrees[j])) { - items[j] = py::cast(subtrees[j])[i]; + if (nb::isinstance(subtrees[j])) { + items[j] = nb::cast(subtrees[j])[i]; } else { items[j] = subtrees[j]; } } - l[i] = recurse(items); + l.append(recurse(items)); } - return py::cast(l); - } else if (py::isinstance(subtrees[0])) { - std::vector items(subtrees.size()); - validate_subtrees(subtrees); - py::dict d; - for (auto item : py::cast(subtrees[0])) { + return nb::cast(nb::tuple(l)); + } else if (nb::isinstance(subtrees[0])) { + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + nb::dict d; + for (auto item : nb::cast(subtrees[0])) { for (int j = 0; j < subtrees.size(); ++j) { - if (py::isinstance(subtrees[j])) { - auto subdict = py::cast(subtrees[j]); + if (nb::isinstance(subtrees[j])) { + auto subdict = nb::cast(subtrees[j]); if (!subdict.contains(item.first)) { throw std::invalid_argument( "[tree_map] Tree is not a valid prefix tree of the first tree."); @@ -91,7 +91,7 @@ py::object tree_map( } d[item.first] = recurse(items); } - return py::cast(d); + return nb::cast(d); } else { return transform(subtrees); } @@ -99,40 +99,40 @@ py::object tree_map( return recurse(trees); } -py::object tree_map( - py::object tree, - std::function transform) { - return tree_map({tree}, [&](std::vector inputs) { +nb::object tree_map( + nb::object tree, + std::function transform) { + return tree_map({tree}, [&](std::vector inputs) { return transform(inputs[0]); }); } void tree_visit_update( - py::object tree, - std::function visitor) { - std::function recurse; - recurse = [&](py::handle subtree) { - if (py::isinstance(subtree)) { - auto l = py::cast(subtree); + nb::object tree, + std::function visitor) { + std::function recurse; + recurse = [&](nb::handle subtree) { + if (nb::isinstance(subtree)) { + auto l = nb::cast(subtree); for (int i = 0; i < l.size(); ++i) { l[i] = recurse(l[i]); } - return py::cast(l); - } else if (py::isinstance(subtree)) { + return nb::cast(l); + } else if (nb::isinstance(subtree)) { for (auto item : subtree) { recurse(item); } - return 
py::cast(subtree); - } else if (py::isinstance(subtree)) { - auto d = py::cast(subtree); + return nb::cast(subtree); + } else if (nb::isinstance(subtree)) { + auto d = nb::cast(subtree); for (auto item : d) { d[item.first] = recurse(item.second); } - return py::cast(d); - } else if (py::isinstance(subtree)) { + return nb::cast(d); + } else if (nb::isinstance(subtree)) { return visitor(subtree); } else { - return py::cast(subtree); + return nb::cast(subtree); } }; recurse(tree); @@ -141,36 +141,36 @@ void tree_visit_update( // Fill a pytree (recursive dict or list of dict or list) // in place with the given arrays // Non dict or list nodes are ignored -void tree_fill(py::object& tree, const std::vector& values) { +void tree_fill(nb::object& tree, const std::vector& values) { size_t index = 0; tree_visit_update( - tree, [&](py::handle node) { return py::cast(values[index++]); }); + tree, [&](nb::handle node) { return nb::cast(values[index++]); }); } // Replace all the arrays from the src values with the dst values in the tree void tree_replace( - py::object& tree, + nb::object& tree, const std::vector& src, const std::vector& dst) { std::unordered_map src_to_dst; for (int i = 0; i < src.size(); ++i) { src_to_dst.insert({src[i].id(), dst[i]}); } - tree_visit_update(tree, [&](py::handle node) { - auto arr = py::cast(node); + tree_visit_update(tree, [&](nb::handle node) { + auto arr = nb::cast(node); if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { - return py::cast(it->second); + return nb::cast(it->second); } - return py::cast(arr); + return nb::cast(arr); }); } -std::vector tree_flatten(py::object tree, bool strict /* = true */) { +std::vector tree_flatten(nb::object tree, bool strict /* = true */) { std::vector flat_tree; - tree_visit(tree, [&](py::handle obj) { - if (py::isinstance(obj)) { - flat_tree.push_back(py::cast(obj)); + tree_visit(tree, [&](nb::handle obj) { + if (nb::isinstance(obj)) { + flat_tree.push_back(nb::cast(obj)); } else if (strict) { throw std::invalid_argument( "[tree_flatten] The argument should contain only arrays"); @@ -180,24 +180,24 @@ std::vector tree_flatten(py::object tree, bool strict /* = true */) { return flat_tree; } -py::object tree_unflatten( - py::object tree, +nb::object tree_unflatten( + nb::object tree, const std::vector& values, int index /* = 0 */) { - return tree_map(tree, [&](py::handle obj) { - if (py::isinstance(obj)) { - return py::cast(values[index++]); + return tree_map(tree, [&](nb::handle obj) { + if (nb::isinstance(obj)) { + return nb::cast(values[index++]); } else { - return py::cast(obj); + return nb::cast(obj); } }); } -py::object structure_sentinel() { - static py::object sentinel; +nb::object structure_sentinel() { + static nb::object sentinel; if (sentinel.ptr() == nullptr) { - sentinel = py::capsule(&sentinel); + sentinel = nb::capsule(&sentinel); // probably not needed but this should make certain that we won't ever // delete the sentinel sentinel.inc_ref(); @@ -206,19 +206,19 @@ py::object structure_sentinel() { return sentinel; } -std::pair, py::object> tree_flatten_with_structure( - py::object tree, +std::pair, nb::object> tree_flatten_with_structure( + nb::object tree, bool strict /* = true */) { auto sentinel = structure_sentinel(); std::vector flat_tree; auto structure = tree_map( tree, - [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) { - if (py::isinstance(obj)) { - flat_tree.push_back(py::cast(obj)); + [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) { + if 
(nb::isinstance(obj)) { + flat_tree.push_back(nb::cast(obj)); return sentinel; } else if (!strict) { - return py::cast(obj); + return nb::cast(obj); } else { throw std::invalid_argument( "[tree_flatten] The argument should contain only arrays"); @@ -228,16 +228,16 @@ std::pair, py::object> tree_flatten_with_structure( return {flat_tree, structure}; } -py::object tree_unflatten_from_structure( - py::object structure, +nb::object tree_unflatten_from_structure( + nb::object structure, const std::vector& values, int index /* = 0 */) { auto sentinel = structure_sentinel(); - return tree_map(structure, [&](py::handle obj) { + return tree_map(structure, [&](nb::handle obj) { if (obj.is(sentinel)) { - return py::cast(values[index++]); + return nb::cast(values[index++]); } else { - return py::cast(obj); + return nb::cast(obj); } }); } diff --git a/python/src/trees.h b/python/src/trees.h index bb44f2320..44d9d9b0e 100644 --- a/python/src/trees.h +++ b/python/src/trees.h @@ -1,38 +1,37 @@ // Copyright © 2023-2024 Apple Inc. #pragma once -#include -#include +#include #include "mlx/array.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlx::core; -void tree_visit(py::object tree, std::function visitor); +void tree_visit(nb::object tree, std::function visitor); -py::object tree_map( - const std::vector& trees, - std::function&)> transform); +nb::object tree_map( + const std::vector& trees, + std::function&)> transform); -py::object tree_map( - py::object tree, - std::function transform); +nb::object tree_map( + nb::object tree, + std::function transform); void tree_visit_update( - py::object tree, - std::function visitor); + nb::object tree, + std::function visitor); /** * Fill a pytree (recursive dict or list of dict or list) in place with the * given arrays. */ -void tree_fill(py::object& tree, const std::vector& values); +void tree_fill(nb::object& tree, const std::vector& values); /** * Replace all the arrays from the src values with the dst values in the * tree. */ void tree_replace( - py::object& tree, + nb::object& tree, const std::vector& src, const std::vector& dst); @@ -40,21 +39,21 @@ void tree_replace( * Flatten a tree into a vector of arrays. If strict is true, then the * function will throw if the tree contains a leaf which is not an array. */ -std::vector tree_flatten(py::object tree, bool strict = true); +std::vector tree_flatten(nb::object tree, bool strict = true); /** * Unflatten a tree from a vector of arrays. */ -py::object tree_unflatten( - py::object tree, +nb::object tree_unflatten( + nb::object tree, const std::vector& values, int index = 0); -std::pair, py::object> tree_flatten_with_structure( - py::object tree, +std::pair, nb::object> tree_flatten_with_structure( + nb::object tree, bool strict = true); -py::object tree_unflatten_from_structure( - py::object structure, +nb::object tree_unflatten_from_structure( + nb::object structure, const std::vector& values, int index = 0); diff --git a/python/src/utils.cpp b/python/src/utils.cpp deleted file mode 100644 index c07016709..000000000 --- a/python/src/utils.cpp +++ /dev/null @@ -1,81 +0,0 @@ - -#include "mlx/utils.h" -#include -#include -#include - -namespace py = pybind11; -using namespace py::literals; -using namespace mlx::core; - -// Slightly different from the original, with python context on init we are not -// in the context yet. Only create the inner context on enter then delete on -// exit. 
diff --git a/python/src/utils.h b/python/src/utils.h
index 5ac878979..8b52cba12 100644
--- a/python/src/utils.h
+++ b/python/src/utils.h
@@ -1,23 +1,22 @@
-// Copyright © 2023 Apple Inc.
-
+// Copyright © 2023-2024 Apple Inc.
 #pragma once

 #include
+#include
 #include
-#include
-#include
-#include
+#include
+#include
+#include

 #include "mlx/array.h"

-namespace py = pybind11;
+namespace nb = nanobind;
 using namespace mlx::core;

 using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
 using ScalarOrArray = std::
-    variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
-static constexpr std::monostate none{};
+    variant<nb::bool_, nb::int_, nb::float_, std::complex<float>, nb::object>;

 inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
   std::vector<int> axes;
@@ -32,31 +31,36 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
   return axes;
 }

-inline array to_array_with_accessor(py::object obj) {
-  if (py::hasattr(obj, "__mlx_array__")) {
-    return obj.attr("__mlx_array__")().cast<array>();
+inline array to_array_with_accessor(nb::object obj) {
+  if (nb::hasattr(obj, "__mlx_array__")) {
+    return nb::cast<array>(obj.attr("__mlx_array__")());
+  } else if (nb::isinstance<array>(obj)) {
+    return nb::cast<array>(obj);
   } else {
-    return obj.cast<array>();
+    std::ostringstream msg;
+    msg << "Invalid type " << nb::type_name(obj.type()).c_str()
+        << " received in array initialization.";
+    throw std::invalid_argument(msg.str());
   }
 }

 inline array to_array(
     const ScalarOrArray& v,
     std::optional<Dtype> dtype = std::nullopt) {
-  if (auto pv = std::get_if<py::bool_>(&v); pv) {
-    return array(py::cast<bool>(*pv), dtype.value_or(bool_));
-  } else if (auto pv = std::get_if<py::int_>(&v); pv) {
+  if (auto pv = std::get_if<nb::bool_>(&v); pv) {
+    return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
+  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
     auto out_t = dtype.value_or(int32);
     // bool_ is an exception and is always promoted
-    return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
-  } else if (auto pv = std::get_if<py::float_>(&v); pv) {
+    return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
+  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
     auto out_t = dtype.value_or(float32);
     return array(
-        py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
+        nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
   } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
     return array(static_cast<complex64_t>(*pv), complex64);
   } else {
-    return to_array_with_accessor(std::get<py::object>(v));
+    return to_array_with_accessor(std::get<nb::object>(v));
   }
 }
@@ -68,14 +72,14 @@ inline std::pair<array, array> to_arrays(
   // - If a is an array but b is not, treat b as a weak python type
   // - If b is an array but a is not, treat a as a weak python type
   // - If neither is an array convert to arrays but leave their types alone
-  if (auto pa = std::get_if<py::object>(&a); pa) {
+  if (auto pa = std::get_if<nb::object>(&a); pa) {
     auto arr_a = to_array_with_accessor(*pa);
-    if (auto pb = std::get_if<py::object>(&b); pb) {
+    if (auto pb = std::get_if<nb::object>(&b); pb) {
       auto arr_b = to_array_with_accessor(*pb);
       return {arr_a, arr_b};
     }
     return {arr_a, to_array(b, arr_a.dtype())};
-  } else if (auto pb = std::get_if<py::object>(&b); pb) {
+  } else if (auto pb = std::get_if<nb::object>(&b); pb) {
     auto arr_b = to_array_with_accessor(*pb);
     return {to_array(a, arr_b.dtype()), arr_b};
   } else {
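The rewritten `to_array`/`to_arrays` helpers keep MLX's scalar-promotion rules: Python bools default to `bool_`, ints to `int32`, floats to `float32`, and a bare Python scalar adopts the dtype of an array operand. A rough sketch of how an operator binding might use them, assuming this header's `ScalarOrArray`/`to_arrays`, an assumed include path, and an illustrative `init_example_ops` entry point (this is not the patch's actual ops code):

#include <nanobind/nanobind.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/variant.h>

#include "mlx/ops.h"
#include "python/src/utils.h"  // assumed include path for the header above

namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;

void init_example_ops(nb::module_& m) {  // hypothetical entry point
  m.def(
      "add",
      [](const ScalarOrArray& a_, const ScalarOrArray& b_) {
        // A Python scalar is "weak": a float16 array plus 3.0 stays float16.
        auto [a, b] = to_arrays(a_, b_);
        return add(a, b);
      },
      "a"_a,
      "b"_a);
}

Called as `add(x, 3.0)`, the scalar adopts `x.dtype`, which is the behavior the comment block in `to_arrays` describes.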
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 07c7bd18d..ebbe25806 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -308,9 +308,9 @@ class TestArray(mlx_tests.MLXTestCase):
         self.assertEqual(y.dtype, mx.bool_)
         self.assertEqual(y.item(), True)

-        # y = mx.array(x, mx.complex64)
-        # self.assertEqual(y.dtype, mx.complex64)
-        # self.assertEqual(y.item(), 3.0+0j)
+        y = mx.array(x, mx.complex64)
+        self.assertEqual(y.dtype, mx.complex64)
+        self.assertEqual(y.item(), 3.0 + 0j)

     def test_array_repr(self):
         x = mx.array(True)
@@ -682,7 +682,7 @@ class TestArray(mlx_tests.MLXTestCase):

         # check if it throws an error when dtype is not supported (bfloat16)
         x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(TypeError):
             pickle.dumps(x)

     def test_array_copy(self):
@@ -711,6 +711,11 @@ class TestArray(mlx_tests.MLXTestCase):
         self.assertEqualArray(y, x - 1)

     def test_indexing(self):
+        # Only ellipsis is a no-op
+        a_mlx = mx.array([1])[...]
+        self.assertEqual(a_mlx.shape, (1,))
+        self.assertEqual(a_mlx.item(), 1)
+
         # Basic content check, slice indexing
         a_npy = np.arange(64, dtype=np.float32)
         a_mlx = mx.array(a_npy)
@@ -1360,7 +1365,7 @@ class TestArray(mlx_tests.MLXTestCase):
         for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
             a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
             a_tf = tf.constant(a_np, dtype=tf_dtype)
-            a_mx = mx.array(a_tf)
+            a_mx = mx.array(np.array(a_tf))
             for f in [
                 lambda x: x,
                 lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,
diff --git a/setup.py b/setup.py
index 7abdcb566..5623f1e65 100644
--- a/setup.py
+++ b/setup.py
@@ -134,7 +134,9 @@ class GenerateStubs(Command):
         pass

     def run(self) -> None:
-        subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"])
+        subprocess.run(
+            ["python", "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"]
+        )


 # Read the content of README.md
@@ -165,7 +167,7 @@ if __name__ == "__main__":
         include_package_data=True,
         extras_require={
             "testing": ["numpy", "torch"],
-            "dev": ["pre-commit", "pybind11-stubgen"],
+            "dev": ["pre-commit"],
         },
         ext_modules=[CMakeExtension("mlx.core")],
         cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},