Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun 2024-03-18 20:12:25 -07:00 committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 2343 additions and 2344 deletions

View File

@ -31,8 +31,7 @@ jobs:
name: Install dependencies name: Install dependencies
command: | command: |
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install pybind11-stubgen
pip install numpy pip install numpy
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@ -44,7 +43,8 @@ jobs:
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
python3 setup.py generate_stubs echo "stubs"
python -m nanobind.stubgen -m mlx.core -r -O python
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
@ -80,8 +80,7 @@ jobs:
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install pybind11-stubgen
pip install numpy pip install numpy
pip install torch pip install torch
pip install tensorflow pip install tensorflow
@ -95,7 +94,7 @@ jobs:
name: Generate package stubs name: Generate package stubs
command: | command: |
source env/bin/activate source env/bin/activate
python setup.py generate_stubs python -m nanobind.stubgen -m mlx.core -r -O python
- run: - run:
name: Run Python tests name: Run Python tests
command: | command: |
@ -144,9 +143,8 @@ jobs:
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install --upgrade setuptools pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy pip install numpy
pip install twine pip install twine
pip install build pip install build
@ -161,7 +159,7 @@ jobs:
name: Generate package stubs name: Generate package stubs
command: | command: |
source env/bin/activate source env/bin/activate
python setup.py generate_stubs python -m nanobind.stubgen -m mlx.core -r -O python
- run: - run:
name: Build Python package name: Build Python package
command: | command: |
@ -209,9 +207,8 @@ jobs:
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install --upgrade pybind11[global] pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install --upgrade setuptools pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy pip install numpy
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
@ -219,7 +216,7 @@ jobs:
<< parameters.extra_env >> \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \ CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v pip install . -v
python setup.py generate_stubs python -m nanobind.stubgen -m mlx.core -r -O python
<< parameters.extra_env >> \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \ CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel python -m build --wheel

View File

@ -146,8 +146,12 @@ target_include_directories(
if (MLX_BUILD_PYTHON_BINDINGS) if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package(Python COMPONENTS Interpreter Development) find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(pybind11 CONFIG REQUIRED) execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif() endif()

View File

@ -29,8 +29,8 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"} autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = { intersphinx_mapping = {
"https://docs.python.org/3": None, "python": ("https://docs.python.org/3", None),
"https://numpy.org/doc/stable/": None, "numpy": ("https://numpy.org/doc/stable/", None),
} }
templates_path = ["_templates"] templates_path = ["_templates"]
@ -59,3 +59,14 @@ html_theme_options = {
# -- Options for HTMLHelp output --------------------------------------------- # -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc" htmlhelp_basename = "mlx_doc"
def setup(app):
wrapped = app.registry.documenters["function"].can_document_member
def nanobind_function_patch(member: Any, *args, **kwargs) -> bool:
return "nanobind.nb_func" in str(type(member)) or wrapped(
member, *args, **kwargs
)
app.registry.documenters["function"].can_document_member = nanobind_function_patch

View File

@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_ Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
.. code-block:: shell .. code-block:: shell
pip install "pybind11[global]" pip install git+https://github.com/wjakob/nanobind.git
conda install pybind11
brew install pybind11
Then simply build and install it using pip: Then simply build and install MLX using pip:
.. code-block:: shell .. code-block:: shell

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <fstream> #include <fstream>
@ -122,8 +121,6 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length()); out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
out_stream->write(header.str().c_str(), header.str().length()); out_stream->write(header.str().c_str(), header.str().length());
out_stream->write(a.data<char>(), a.nbytes()); out_stream->write(a.data<char>(), a.nbytes());
return;
} }
/** Save array to file in .npy format */ /** Save array to file in .npy format */

View File

@ -7,6 +7,25 @@
namespace mlx::core { namespace mlx::core {
struct complex64_t; struct complex64_t;
struct complex128_t;

// True for every type other than complex128_t itself that is implicitly
// convertible to double; used to constrain the converting constructor below.
template <typename T>
static constexpr bool can_convert_to_complex128 =
    !std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;

// Double precision complex value built directly on std::complex<double>.
struct complex128_t : public std::complex<double> {
  // Build from explicit real and imaginary parts.
  complex128_t(double re, double im) : std::complex<double>(re, im) {}

  // Wrap an existing std::complex<double>.
  complex128_t(std::complex<double> z) : std::complex<double>(z) {}

  // Build from any scalar convertible to double; the imaginary part is
  // whatever std::complex<double>'s single-argument constructor yields.
  template <
      typename T,
      typename = std::enable_if_t<can_convert_to_complex128<T>>>
  complex128_t(T scalar) : std::complex<double>(scalar) {}

  // Implicit conversion to float keeps only the real part.
  operator float() const {
    return real();
  }
};
template <typename T> template <typename T>
static constexpr bool can_convert_to_complex64 = static constexpr bool can_convert_to_complex64 =

View File

@ -1,3 +1,7 @@
[build-system] [build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"] requires = [
"setuptools>=42",
"nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
"cmake>=3.24",
]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@ -1,7 +1,10 @@
pybind11_add_module( nanobind_add_module(
core core
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@ -15,7 +18,6 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
) )
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)

File diff suppressed because it is too large Load Diff

122
python/src/buffer.h Normal file
View File

@ -0,0 +1,122 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <optional>
#include <nanobind/nanobind.h>
#include "mlx/array.h"
#include "mlx/utils.h"
// Only defined in >= Python 3.9
// https://github.com/python/cpython/blob/f6cdc6b4a191b75027de342aa8b5d344fb31313e/Include/typeslots.h#L2-L3
#ifndef Py_bf_getbuffer
#define Py_bf_getbuffer 1
#define Py_bf_releasebuffer 2
#endif
namespace nb = nanobind;
using namespace mlx::core;
// Return the buffer-protocol format string describing the element type of
// `a`. Throws std::runtime_error for a dtype that has no entry below.
//
// Declared `inline` because this definition lives in a header: without it,
// including the header from more than one translation unit produces multiple
// definitions of the same function.
inline std::string buffer_format(const array& a) {
  // https://docs.python.org/3.10/library/struct.html#format-characters
  switch (a.dtype()) {
    case bool_:
      return "?";
    case uint8:
      return "B";
    case uint16:
      return "H";
    case uint32:
      return "I";
    case uint64:
      return "Q";
    case int8:
      return "b";
    case int16:
      return "h";
    case int32:
      return "i";
    case int64:
      return "q";
    case float16:
      return "e";
    case float32:
      return "f";
    case bfloat16:
      // NOTE(review): this is the same code returned for uint8 above; the
      // struct module has no bfloat16 character. Confirm that consumers are
      // expected to disambiguate via the item size.
      return "B";
    case complex64:
      return "Zf\0";
    default: {
      std::ostringstream os;
      os << "bad dtype: " << a.dtype();
      throw std::runtime_error(os.str());
    }
  }
}
// Owns the format string, shape and strides that a buffer view points into.
// Move-only: copying is disabled, moving transfers all three members.
struct buffer_info {
  std::string format;
  std::vector<ssize_t> shape;
  std::vector<ssize_t> strides;

  // Copies the format string and takes ownership of the two vectors.
  buffer_info(
      const std::string& format,
      std::vector<ssize_t> shape_in,
      std::vector<ssize_t> strides_in)
      : format(format),
        shape(std::move(shape_in)),
        strides(std::move(strides_in)) {}

  buffer_info(const buffer_info&) = delete;
  buffer_info& operator=(const buffer_info&) = delete;

  // Move construction: steal every member from `other`.
  buffer_info(buffer_info&& other) noexcept
      : format(std::move(other.format)),
        shape(std::move(other.shape)),
        strides(std::move(other.strides)) {}

  // Move assignment: member-wise move from `rhs`.
  buffer_info& operator=(buffer_info&& rhs) noexcept {
    strides = std::move(rhs.strides);
    shape = std::move(rhs.shape);
    format = std::move(rhs.format);
    return *this;
  }
};
// Buffer-protocol "get buffer" hook: fills `view` with a description of the
// memory behind the array wrapped by `obj` and returns 0.
//
// The array is evaluated first if needed (with the GIL released during the
// evaluation). Shape and byte strides are stored in a heap-allocated
// buffer_info hung off `view->internal`; releasebuffer frees it.
//
// NOTE(review): nb::cast, eval() and buffer_format() can throw, and nothing
// here catches — confirm how exceptions are expected to cross this
// extern "C" boundary.
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
  std::memset(view, 0, sizeof(Py_buffer));
  auto a = nb::cast<array>(nb::handle(obj));

  if (!a.is_evaled()) {
    nb::gil_scoped_release nogil;
    a.eval();
  }

  std::vector<ssize_t> shape(a.shape().begin(), a.shape().end());
  std::vector<ssize_t> strides(a.strides().begin(), a.strides().end());
  // The array reports strides in elements; Py_buffer wants them in bytes.
  for (auto& s : strides) {
    s *= a.itemsize();
  }
  buffer_info* info =
      new buffer_info(buffer_format(a), std::move(shape), std::move(strides));

  view->obj = obj;
  view->ndim = a.ndim();
  view->internal = info;
  view->buf = a.data<void>();
  view->itemsize = a.itemsize();
  // Py_buffer.len is the total size in bytes (product(shape) * itemsize),
  // not the number of elements.
  view->len = a.nbytes();
  view->readonly = false;
  if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
    view->format = const_cast<char*>(info->format.c_str());
  }
  if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
    view->strides = info->strides.data();
    view->shape = info->shape.data();
  }
  // The view holds its own reference to the exporting object.
  Py_INCREF(view->obj);
  return 0;
}
// Buffer-protocol "release buffer" hook: destroys the buffer_info that was
// stored in `view->internal`.
extern "C" inline void releasebuffer(PyObject*, Py_buffer* view) {
  auto* info = static_cast<buffer_info*>(view->internal);
  delete info;
}

View File

@ -1,11 +1,11 @@
// init_constants.cpp // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <limits> #include <limits>
namespace py = pybind11; namespace nb = nanobind;
void init_constants(py::module_& m) { void init_constants(nb::module_& m) {
m.attr("Inf") = std::numeric_limits<double>::infinity(); m.attr("Inf") = std::numeric_limits<double>::infinity();
m.attr("Infinity") = std::numeric_limits<double>::infinity(); m.attr("Infinity") = std::numeric_limits<double>::infinity();
m.attr("NAN") = NAN; m.attr("NAN") = NAN;
@ -19,6 +19,6 @@ void init_constants(py::module_& m) {
m.attr("inf") = std::numeric_limits<double>::infinity(); m.attr("inf") = std::numeric_limits<double>::infinity();
m.attr("infty") = std::numeric_limits<double>::infinity(); m.attr("infty") = std::numeric_limits<double>::infinity();
m.attr("nan") = NAN; m.attr("nan") = NAN;
m.attr("newaxis") = pybind11::none(); m.attr("newaxis") = nb::none();
m.attr("pi") = 3.1415926535897932384626433; m.attr("pi") = 3.1415926535897932384626433;
} }

155
python/src/convert.cpp Normal file
View File

@ -0,0 +1,155 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/stl/complex.h>
#include "python/src/convert.h"
namespace nanobind {
// ndarray_traits specialization for float16_t: flagged as a signed,
// non-complex floating point element type.
template <>
struct ndarray_traits<float16_t> {
  static constexpr bool is_complex = false;
  static constexpr bool is_float = true;
  static constexpr bool is_bool = false;
  static constexpr bool is_int = false;
  static constexpr bool is_signed = true;
};
// ndarray_traits specialization for bfloat16_t, with the same flags as the
// float16_t specialization above.
template <>
struct ndarray_traits<bfloat16_t> {
  static constexpr bool is_complex = false;
  static constexpr bool is_float = true;
  static constexpr bool is_bool = false;
  static constexpr bool is_int = false;
  static constexpr bool is_signed = true;
};
// Explicit dlpack dtype passed when exporting bfloat16 arrays.
// NOTE(review): the fields are presumably {type code, bits, lanes}, with 4
// being the DLPack bfloat type code — confirm against nanobind's
// dlpack::dtype definition.
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
// Build an array of the given shape and dtype from the ndarray's buffer,
// reading the buffer as elements of type T.
template <typename T>
array nd_array_to_mlx_contiguous(
    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
    const std::vector<int>& shape,
    Dtype dtype) {
  // Reinterpret the raw buffer pointer as T and hand it to the array
  // constructor, which makes its own copy of the data.
  const T* typed_data = static_cast<const T*>(nd_array.data());
  return array(typed_data, shape, dtype);
}
// Convert a read-only, C-contiguous CPU ndarray into an array.
//
// The element type of the source buffer selects how it is read; `dtype`, when
// provided, overrides the dtype of the resulting array. Without an override,
// double input becomes float32 and complex<double> input becomes complex64.
// Throws std::invalid_argument for an element type not listed below.
array nd_array_to_mlx(
    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
    std::optional<Dtype> dtype) {
  // Gather the dimensions of the source buffer.
  std::vector<int> shape;
  for (size_t axis = 0; axis < nd_array.ndim(); ++axis) {
    shape.push_back(nd_array.shape(axis));
  }
  const auto src_type = nd_array.dtype();

  // Dispatch on the source element type; each branch copies the data.
  if (src_type == nb::dtype<bool>()) {
    return nd_array_to_mlx_contiguous<bool>(
        nd_array, shape, dtype.value_or(bool_));
  }
  if (src_type == nb::dtype<uint8_t>()) {
    return nd_array_to_mlx_contiguous<uint8_t>(
        nd_array, shape, dtype.value_or(uint8));
  }
  if (src_type == nb::dtype<uint16_t>()) {
    return nd_array_to_mlx_contiguous<uint16_t>(
        nd_array, shape, dtype.value_or(uint16));
  }
  if (src_type == nb::dtype<uint32_t>()) {
    return nd_array_to_mlx_contiguous<uint32_t>(
        nd_array, shape, dtype.value_or(uint32));
  }
  if (src_type == nb::dtype<uint64_t>()) {
    return nd_array_to_mlx_contiguous<uint64_t>(
        nd_array, shape, dtype.value_or(uint64));
  }
  if (src_type == nb::dtype<int8_t>()) {
    return nd_array_to_mlx_contiguous<int8_t>(
        nd_array, shape, dtype.value_or(int8));
  }
  if (src_type == nb::dtype<int16_t>()) {
    return nd_array_to_mlx_contiguous<int16_t>(
        nd_array, shape, dtype.value_or(int16));
  }
  if (src_type == nb::dtype<int32_t>()) {
    return nd_array_to_mlx_contiguous<int32_t>(
        nd_array, shape, dtype.value_or(int32));
  }
  if (src_type == nb::dtype<int64_t>()) {
    return nd_array_to_mlx_contiguous<int64_t>(
        nd_array, shape, dtype.value_or(int64));
  }
  if (src_type == nb::dtype<float16_t>()) {
    return nd_array_to_mlx_contiguous<float16_t>(
        nd_array, shape, dtype.value_or(float16));
  }
  if (src_type == nb::dtype<bfloat16_t>()) {
    return nd_array_to_mlx_contiguous<bfloat16_t>(
        nd_array, shape, dtype.value_or(bfloat16));
  }
  if (src_type == nb::dtype<float>()) {
    return nd_array_to_mlx_contiguous<float>(
        nd_array, shape, dtype.value_or(float32));
  }
  if (src_type == nb::dtype<double>()) {
    return nd_array_to_mlx_contiguous<double>(
        nd_array, shape, dtype.value_or(float32));
  }
  if (src_type == nb::dtype<std::complex<float>>()) {
    return nd_array_to_mlx_contiguous<complex64_t>(
        nd_array, shape, dtype.value_or(complex64));
  }
  if (src_type == nb::dtype<std::complex<double>>()) {
    return nd_array_to_mlx_contiguous<complex128_t>(
        nd_array, shape, dtype.value_or(complex64));
  }
  throw std::invalid_argument("Cannot convert numpy array to mlx array.");
}
// Wrap the data of `a` in an nb::ndarray<Lib> whose elements are of type T.
// `t`, when given, replaces the dtype that would be derived from T.
template <typename Lib, typename T>
nb::ndarray<Lib> mlx_to_nd_array(
    array a,
    std::optional<nb::dlpack::dtype> t = {}) {
  // Make sure the data exists; the GIL is released while evaluating.
  if (!a.is_evaled()) {
    nb::gil_scoped_release nogil;
    a.eval();
  }

  // Convert shape and strides to the integer types the ndarray
  // constructor takes.
  std::vector<size_t> dims(a.shape().begin(), a.shape().end());
  std::vector<int64_t> steps(a.strides().begin(), a.strides().end());
  const nb::dlpack::dtype element_type = t.value_or(nb::dtype<T>());

  return nb::ndarray<Lib>(
      a.data<T>(), a.ndim(), dims.data(), nb::handle(), steps.data(), element_type);
}
// Convert `a` to an nb::ndarray<Lib>, choosing the element type from the
// array's dtype. bfloat16 is exported with the explicit nb::bfloat16 dlpack
// dtype. Throws std::invalid_argument for a dtype with no case below.
template <typename Lib>
nb::ndarray<Lib> mlx_to_nd_array(const array& a) {
  switch (a.dtype()) {
    case bool_:
      return mlx_to_nd_array<Lib, bool>(a);
    case uint8:
      return mlx_to_nd_array<Lib, uint8_t>(a);
    case uint16:
      return mlx_to_nd_array<Lib, uint16_t>(a);
    case uint32:
      return mlx_to_nd_array<Lib, uint32_t>(a);
    case uint64:
      return mlx_to_nd_array<Lib, uint64_t>(a);
    case int8:
      return mlx_to_nd_array<Lib, int8_t>(a);
    case int16:
      return mlx_to_nd_array<Lib, int16_t>(a);
    case int32:
      return mlx_to_nd_array<Lib, int32_t>(a);
    case int64:
      return mlx_to_nd_array<Lib, int64_t>(a);
    case float16:
      return mlx_to_nd_array<Lib, float16_t>(a);
    case bfloat16:
      return mlx_to_nd_array<Lib, bfloat16_t>(a, nb::bfloat16);
    case float32:
      return mlx_to_nd_array<Lib, float>(a);
    case complex64:
      return mlx_to_nd_array<Lib, std::complex<float>>(a);
  }
  // Reached only for a dtype without a case above. Throwing here (instead of
  // a `default:` label) keeps the function from running off its end while
  // still letting the compiler warn about unhandled enumerators.
  throw std::invalid_argument("Cannot convert mlx array to nd array.");
}
// Convert `a` to a NumPy-flavoured ndarray by forwarding to the dtype
// dispatching mlx_to_nd_array overload with Lib = nb::numpy.
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
  return mlx_to_nd_array<nb::numpy>(a);
}

16
python/src/convert.h Normal file
View File

@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
#include <optional>
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include "mlx/array.h"
namespace nb = nanobind;
using namespace mlx::core;
// Convert a read-only, C-contiguous, CPU-resident ndarray into an array.
// `dtype`, when set, is the dtype requested for the result.
array nd_array_to_mlx(
    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
    std::optional<Dtype> dtype);
// Convert an array into a NumPy-flavoured ndarray.
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);

View File

@ -1,32 +1,34 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <sstream> #include <sstream>
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
void init_device(py::module_& m) { void init_device(nb::module_& m) {
auto device_class = py::class_<Device>( auto device_class = nb::class_<Device>(
m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
py::enum_<Device::DeviceType>(m, "DeviceType") nb::enum_<Device::DeviceType>(m, "DeviceType")
.value("cpu", Device::DeviceType::cpu) .value("cpu", Device::DeviceType::cpu)
.value("gpu", Device::DeviceType::gpu) .value("gpu", Device::DeviceType::gpu)
.export_values() .export_values()
.def( .def("__eq__", [](const Device::DeviceType& d, const nb::object& other) {
"__eq__", if (!nb::isinstance<Device>(other) &&
[](const Device::DeviceType& d1, const Device& d2) { !nb::isinstance<Device::DeviceType>(other)) {
return d1 == d2; return false;
}, }
py::prepend()); return d == nb::cast<Device>(other);
});
device_class.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0) device_class.def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_readonly("type", &Device::type) .def_ro("type", &Device::type)
.def( .def(
"__repr__", "__repr__",
[](const Device& d) { [](const Device& d) {
@ -34,11 +36,15 @@ void init_device(py::module_& m) {
os << d; os << d;
return os.str(); return os.str();
}) })
.def("__eq__", [](const Device& d1, const Device& d2) { .def("__eq__", [](const Device& d, const nb::object& other) {
return d1 == d2; if (!nb::isinstance<Device>(other) &&
!nb::isinstance<Device::DeviceType>(other)) {
return false;
}
return d == nb::cast<Device>(other);
}); });
py::implicitly_convertible<Device::DeviceType, Device>(); nb::implicitly_convertible<Device::DeviceType, Device>();
m.def( m.def(
"default_device", "default_device",

View File

@ -1,20 +1,17 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <pybind11/stl.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include "mlx/fast.h" #include "mlx/fast.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "python/src/utils.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
void init_extensions(py::module_& parent_module) { void init_fast(nb::module_& parent_module) {
py::options options;
options.disable_function_signatures();
auto m = auto m =
parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@ -31,15 +28,15 @@ void init_extensions(py::module_& parent_module) {
}, },
"a"_a, "a"_a,
"dims"_a, "dims"_a,
py::kw_only(), nb::kw_only(),
"traditional"_a, "traditional"_a,
"base"_a, "base"_a,
"scale"_a, "scale"_a,
"offset"_a, "offset"_a,
"stream"_a = none, "stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array
Apply rotary positional encoding to the input. Apply rotary positional encoding to the input.
Args: Args:
@ -70,20 +67,25 @@ void init_extensions(py::module_& parent_module) {
"q"_a, "q"_a,
"k"_a, "k"_a,
"v"_a, "v"_a,
py::kw_only(), nb::kw_only(),
"scale"_a, "scale"_a,
"mask"_a = none, "mask"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V. Supports:
Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). * [Multi-Head Attention](https://arxiv.org/abs/1706.03762)
* [Grouped Query Attention](https://arxiv.org/abs/2305.13245)
* [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. Note: The softmax operation is performed in ``float32`` regardless of
input precision.
Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. and ``v`` inputs should not be pre-tiled to match ``q``.
Args: Args:
q (array): Input query array. q (array): Input query array.
@ -94,6 +96,5 @@ void init_extensions(py::module_& parent_module) {
Returns: Returns:
array: The output array. array: The output array.
)pbdoc"); )pbdoc");
} }

View File

@ -1,19 +1,20 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <pybind11/stl.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include "python/src/utils.h" #include <nanobind/stl/vector.h>
#include <numeric>
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/ops.h" #include "mlx/ops.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
void init_fft(py::module_& parent_module) { void init_fft(nb::module_& parent_module) {
auto m = parent_module.def_submodule( auto m = parent_module.def_submodule(
"fft", "mlx.core.fft: Fast Fourier Transforms."); "fft", "mlx.core.fft: Fast Fourier Transforms.");
m.def( m.def(
@ -29,9 +30,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"n"_a = none, "n"_a = nb::none(),
"axis"_a = -1, "axis"_a = -1,
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
One dimensional discrete Fourier Transform. One dimensional discrete Fourier Transform.
@ -59,9 +60,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"n"_a = none, "n"_a = nb::none(),
"axis"_a = -1, "axis"_a = -1,
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
One dimensional inverse discrete Fourier Transform. One dimensional inverse discrete Fourier Transform.
@ -95,9 +96,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = std::vector<int>{-2, -1}, "axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Two dimensional discrete Fourier Transform. Two dimensional discrete Fourier Transform.
@ -132,9 +133,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = std::vector<int>{-2, -1}, "axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Two dimensional inverse discrete Fourier Transform. Two dimensional inverse discrete Fourier Transform.
@ -169,9 +170,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = none, "axes"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
n-dimensional discrete Fourier Transform. n-dimensional discrete Fourier Transform.
@ -207,9 +208,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = none, "axes"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
n-dimensional inverse discrete Fourier Transform. n-dimensional inverse discrete Fourier Transform.
@ -239,9 +240,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"n"_a = none, "n"_a = nb::none(),
"axis"_a = -1, "axis"_a = -1,
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
One dimensional discrete Fourier Transform on a real input. One dimensional discrete Fourier Transform on a real input.
@ -274,9 +275,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"n"_a = none, "n"_a = nb::none(),
"axis"_a = -1, "axis"_a = -1,
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
The inverse of :func:`rfft`. The inverse of :func:`rfft`.
@ -314,9 +315,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = std::vector<int>{-2, -1}, "axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Two dimensional real discrete Fourier Transform. Two dimensional real discrete Fourier Transform.
@ -357,9 +358,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = std::vector<int>{-2, -1}, "axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
The inverse of :func:`rfft2`. The inverse of :func:`rfft2`.
@ -400,9 +401,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = none, "axes"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
n-dimensional real discrete Fourier Transform. n-dimensional real discrete Fourier Transform.
@ -443,9 +444,9 @@ void init_fft(py::module_& parent_module) {
} }
}, },
"a"_a, "a"_a,
"s"_a = none, "s"_a = nb::none(),
"axes"_a = none, "axes"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
The inverse of :func:`rfftn`. The inverse of :func:`rfftn`.

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
@ -7,19 +7,19 @@
#include "mlx/ops.h" #include "mlx/ops.h"
bool is_none_slice(const py::slice& in_slice) { bool is_none_slice(const nb::slice& in_slice) {
return ( return (
py::getattr(in_slice, "start").is_none() && nb::getattr(in_slice, "start").is_none() &&
py::getattr(in_slice, "stop").is_none() && nb::getattr(in_slice, "stop").is_none() &&
py::getattr(in_slice, "step").is_none()); nb::getattr(in_slice, "step").is_none());
} }
int get_slice_int(py::object obj, int default_val) { int get_slice_int(nb::object obj, int default_val) {
if (!obj.is_none()) { if (!obj.is_none()) {
if (!py::isinstance<py::int_>(obj)) { if (!nb::isinstance<nb::int_>(obj)) {
throw std::invalid_argument("Slice indices must be integers or None."); throw std::invalid_argument("Slice indices must be integers or None.");
} }
return py::cast<int>(py::cast<py::int_>(obj)); return nb::cast<int>(nb::cast<nb::int_>(obj));
} }
return default_val; return default_val;
} }
@ -28,7 +28,7 @@ void get_slice_params(
int& starts, int& starts,
int& ends, int& ends,
int& strides, int& strides,
const py::slice& in_slice, const nb::slice& in_slice,
int axis_size) { int axis_size) {
// Following numpy's convention // Following numpy's convention
// Assume n is the number of elements in the dimension being sliced. // Assume n is the number of elements in the dimension being sliced.
@ -36,26 +36,26 @@ void get_slice_params(
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for // k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
// k < 0 . If k is not given it defaults to 1 // k < 0 . If k is not given it defaults to 1
strides = get_slice_int(py::getattr(in_slice, "step"), 1); strides = get_slice_int(nb::getattr(in_slice, "step"), 1);
starts = get_slice_int( starts = get_slice_int(
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0); nb::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
ends = get_slice_int( ends = get_slice_int(
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
} }
array get_int_index(py::object idx, int axis_size) { array get_int_index(nb::object idx, int axis_size) {
int idx_ = py::cast<int>(idx); int idx_ = nb::cast<int>(idx);
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_; idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
return array(idx_, uint32); return array(idx_, uint32);
} }
bool is_valid_index_type(const py::object& obj) { bool is_valid_index_type(const nb::object& obj) {
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) || return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj); nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
} }
array mlx_get_item_slice(const array& src, const py::slice& in_slice) { array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
// Check input and raise error if 0 dim for parity with np // Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
@ -92,7 +92,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
return take(src, indices, 0); return take(src, indices, 0);
} }
array mlx_get_item_int(const array& src, const py::int_& idx) { array mlx_get_item_int(const array& src, const nb::int_& idx) {
// Check input and raise error if 0 dim for parity with np // Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
@ -106,7 +106,7 @@ array mlx_get_item_int(const array& src, const py::int_& idx) {
array mlx_gather_nd( array mlx_gather_nd(
array src, array src,
const std::vector<py::object>& indices, const std::vector<nb::object>& indices,
bool gather_first, bool gather_first,
int& max_dims) { int& max_dims) {
max_dims = 0; max_dims = 0;
@ -117,9 +117,10 @@ array mlx_gather_nd(
for (int i = 0; i < indices.size(); i++) { for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i]; auto& idx = indices[i];
if (py::isinstance<py::slice>(idx)) { if (nb::isinstance<nb::slice>(idx)) {
int start, end, stride; int start, end, stride;
get_slice_params(start, end, stride, idx, src.shape(i)); get_slice_params(
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
// Handle negative indices // Handle negative indices
start = (start < 0) ? start + src.shape(i) : start; start = (start < 0) ? start + src.shape(i) : start;
@ -128,10 +129,10 @@ array mlx_gather_nd(
gather_indices.push_back(arange(start, end, stride, uint32)); gather_indices.push_back(arange(start, end, stride, uint32));
num_slices++; num_slices++;
is_slice[i] = true; is_slice[i] = true;
} else if (py::isinstance<py::int_>(idx)) { } else if (nb::isinstance<nb::int_>(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i))); gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (py::isinstance<array>(idx)) { } else if (nb::isinstance<array>(idx)) {
auto arr = py::cast<array>(idx); auto arr = nb::cast<array>(idx);
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims); max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
gather_indices.push_back(arr); gather_indices.push_back(arr);
} }
@ -185,7 +186,7 @@ array mlx_gather_nd(
return src; return src;
} }
array mlx_get_item_nd(array src, const py::tuple& entries) { array mlx_get_item_nd(array src, const nb::tuple& entries) {
// No indices make this a noop // No indices make this a noop
if (entries.size() == 0) { if (entries.size() == 0) {
return src; return src;
@ -197,11 +198,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
// 3. Calculate the remaining slices and reshapes // 3. Calculate the remaining slices and reshapes
// Ellipsis handling // Ellipsis handling
std::vector<py::object> indices; std::vector<nb::object> indices;
{ {
int non_none_indices_before = 0; int non_none_indices_before = 0;
int non_none_indices_after = 0; int non_none_indices_after = 0;
std::vector<py::object> r_indices; std::vector<nb::object> r_indices;
int i = 0; int i = 0;
for (; i < entries.size(); i++) { for (; i < entries.size(); i++) {
auto idx = entries[i]; auto idx = entries[i];
@ -209,7 +210,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
throw std::invalid_argument( throw std::invalid_argument(
"Cannot index mlx array using the given type yet"); "Cannot index mlx array using the given type yet");
} }
if (!py::ellipsis().is(idx)) { if (!nb::ellipsis().is(idx)) {
indices.push_back(idx); indices.push_back(idx);
non_none_indices_before += !idx.is_none(); non_none_indices_before += !idx.is_none();
} else { } else {
@ -222,7 +223,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
throw std::invalid_argument( throw std::invalid_argument(
"Cannot index mlx array using the given type yet"); "Cannot index mlx array using the given type yet");
} }
if (py::ellipsis().is(idx)) { if (nb::ellipsis().is(idx)) {
throw std::invalid_argument( throw std::invalid_argument(
"An index can only have a single ellipsis (...)"); "An index can only have a single ellipsis (...)");
} }
@ -232,7 +233,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
for (int axis = non_none_indices_before; for (int axis = non_none_indices_before;
axis < src.ndim() - non_none_indices_after; axis < src.ndim() - non_none_indices_after;
axis++) { axis++) {
indices.push_back(py::slice(0, src.shape(axis), 1)); indices.push_back(nb::slice(0, src.shape(axis), 1));
} }
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend()); indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
} }
@ -256,7 +257,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
// //
// Check whether we have arrays or integer indices and delegate to gather_nd // Check whether we have arrays or integer indices and delegate to gather_nd
// after removing the slices at the end and all Nones. // after removing the slices at the end and all Nones.
std::vector<py::object> remaining_indices; std::vector<nb::object> remaining_indices;
bool have_array = false; bool have_array = false;
{ {
// First check whether the results of gather are going to be 1st or // First check whether the results of gather are going to be 1st or
@ -264,7 +265,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
bool have_non_array = false; bool have_non_array = false;
bool gather_first = false; bool gather_first = false;
for (auto& idx : indices) { for (auto& idx : indices) {
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) { if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (have_array && have_non_array) { if (have_array && have_non_array) {
gather_first = true; gather_first = true;
break; break;
@ -280,12 +281,12 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
// Then find the last array // Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) { for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array]; auto& idx = indices[last_array];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) { if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
break; break;
} }
} }
std::vector<py::object> gather_indices; std::vector<nb::object> gather_indices;
for (int i = 0; i <= last_array; i++) { for (int i = 0; i <= last_array; i++) {
auto& idx = indices[i]; auto& idx = indices[i];
if (!idx.is_none()) { if (!idx.is_none()) {
@ -299,15 +300,15 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
if (gather_first) { if (gather_first) {
for (int i = 0; i < max_dims; i++) { for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back( remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none())); nb::slice(nb::none(), nb::none(), nb::none()));
} }
for (int i = 0; i < last_array; i++) { for (int i = 0; i < last_array; i++) {
auto& idx = indices[i]; auto& idx = indices[i];
if (idx.is_none()) { if (idx.is_none()) {
remaining_indices.push_back(indices[i]); remaining_indices.push_back(indices[i]);
} else if (py::isinstance<py::slice>(idx)) { } else if (nb::isinstance<nb::slice>(idx)) {
remaining_indices.push_back( remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none())); nb::slice(nb::none(), nb::none(), nb::none()));
} }
} }
for (int i = last_array + 1; i < indices.size(); i++) { for (int i = last_array + 1; i < indices.size(); i++) {
@ -316,18 +317,18 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
} else { } else {
for (int i = 0; i < indices.size(); i++) { for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i]; auto& idx = indices[i];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) { if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
break; break;
} else if (idx.is_none()) { } else if (idx.is_none()) {
remaining_indices.push_back(idx); remaining_indices.push_back(idx);
} else { } else {
remaining_indices.push_back( remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none())); nb::slice(nb::none(), nb::none(), nb::none()));
} }
} }
for (int i = 0; i < max_dims; i++) { for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back( remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none())); nb::slice(nb::none(), nb::none(), nb::none()));
} }
for (int i = last_array + 1; i < indices.size(); i++) { for (int i = last_array + 1; i < indices.size(); i++) {
remaining_indices.push_back(indices[i]); remaining_indices.push_back(indices[i]);
@ -351,7 +352,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
for (auto& idx : remaining_indices) { for (auto& idx : remaining_indices) {
if (!idx.is_none()) { if (!idx.is_none()) {
get_slice_params( get_slice_params(
starts[axis], ends[axis], strides[axis], idx, ends[axis]); starts[axis],
ends[axis],
strides[axis],
nb::cast<nb::slice>(idx),
ends[axis]);
axis++; axis++;
} }
} }
@ -375,15 +380,17 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
return src; return src;
} }
array mlx_get_item(const array& src, const py::object& obj) { array mlx_get_item(const array& src, const nb::object& obj) {
if (py::isinstance<py::slice>(obj)) { if (nb::isinstance<nb::slice>(obj)) {
return mlx_get_item_slice(src, obj); return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
} else if (py::isinstance<array>(obj)) { } else if (nb::isinstance<array>(obj)) {
return mlx_get_item_array(src, py::cast<array>(obj)); return mlx_get_item_array(src, nb::cast<array>(obj));
} else if (py::isinstance<py::int_>(obj)) { } else if (nb::isinstance<nb::int_>(obj)) {
return mlx_get_item_int(src, obj); return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
} else if (py::isinstance<py::tuple>(obj)) { } else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_get_item_nd(src, obj); return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
} else if (nb::isinstance<nb::ellipsis>(obj)) {
return src;
} else if (obj.is_none()) { } else if (obj.is_none()) {
std::vector<int> s(1, 1); std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end()); s.insert(s.end(), src.shape().begin(), src.shape().end());
@ -394,7 +401,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int( std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
const array& src, const array& src,
const py::int_& idx, const nb::int_& idx,
const array& update) { const array& update) {
if (src.ndim() == 0) { if (src.ndim() == 0) {
throw std::invalid_argument( throw std::invalid_argument(
@ -446,7 +453,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice( std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
const array& src, const array& src,
const py::slice& in_slice, const nb::slice& in_slice,
const array& update) { const array& update) {
// Check input and raise error if 0 dim for parity with np // Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) { if (src.ndim() == 0) {
@ -478,9 +485,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd( std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
const array& src, const array& src,
const py::tuple& entries, const nb::tuple& entries,
const array& update) { const array& update) {
std::vector<py::object> indices; std::vector<nb::object> indices;
int non_none_indices = 0; int non_none_indices = 0;
// Expand ellipses into a series of ':' slices // Expand ellipses into a series of ':' slices
@ -494,7 +501,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
if (!is_valid_index_type(idx)) { if (!is_valid_index_type(idx)) {
throw std::invalid_argument( throw std::invalid_argument(
"Cannot index mlx array using the given type yet"); "Cannot index mlx array using the given type yet");
} else if (!py::ellipsis().is(idx)) { } else if (!nb::ellipsis().is(idx)) {
if (!has_ellipsis) { if (!has_ellipsis) {
indices_before++; indices_before++;
non_none_indices_before += !idx.is_none(); non_none_indices_before += !idx.is_none();
@ -514,7 +521,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
axis < src.ndim() - non_none_indices_after; axis < src.ndim() - non_none_indices_after;
axis++) { axis++) {
indices.insert( indices.insert(
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1)); indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1));
} }
non_none_indices = src.ndim(); non_none_indices = src.ndim();
} else { } else {
@ -549,15 +556,15 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
bool have_array = false; bool have_array = false;
bool have_non_array = false; bool have_non_array = false;
for (auto& idx : indices) { for (auto& idx : indices) {
if (py::isinstance<py::slice>(idx) || idx.is_none()) { if (nb::isinstance<nb::slice>(idx) || idx.is_none()) {
have_non_array = have_array; have_non_array = have_array;
num_slices++; num_slices++;
} else if (py::isinstance<array>(idx)) { } else if (nb::isinstance<array>(idx)) {
have_array = true; have_array = true;
if (have_array && have_non_array) { if (have_array && have_non_array) {
arrays_first = true; arrays_first = true;
} }
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim); max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
num_arrays++; num_arrays++;
} }
} }
@ -569,10 +576,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
int ax = 0; int ax = 0;
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i]; auto& pyidx = indices[i];
if (py::isinstance<py::slice>(pyidx)) { if (nb::isinstance<nb::slice>(pyidx)) {
int start, end, stride; int start, end, stride;
auto axis_size = src.shape(ax++); auto axis_size = src.shape(ax++);
get_slice_params(start, end, stride, pyidx, axis_size); get_slice_params(
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
// Handle negative indices // Handle negative indices
start = (start < 0) ? start + axis_size : start; start = (start < 0) ? start + axis_size : start;
@ -584,13 +592,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
slice_num++; slice_num++;
idx_shape[loc] = idx.size(); idx_shape[loc] = idx.size();
arr_indices.push_back(reshape(idx, idx_shape)); arr_indices.push_back(reshape(idx, idx_shape));
} else if (py::isinstance<py::int_>(pyidx)) { } else if (nb::isinstance<nb::int_>(pyidx)) {
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
} else if (pyidx.is_none()) { } else if (pyidx.is_none()) {
slice_num++; slice_num++;
} else if (py::isinstance<array>(pyidx)) { } else if (nb::isinstance<array>(pyidx)) {
ax++; ax++;
auto idx = py::cast<array>(pyidx); auto idx = nb::cast<array>(pyidx);
std::vector<int> idx_shape; std::vector<int> idx_shape;
if (!arrays_first) { if (!arrays_first) {
idx_shape.insert(idx_shape.end(), slice_num, 1); idx_shape.insert(idx_shape.end(), slice_num, 1);
@ -629,24 +637,24 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
std::tuple<std::vector<array>, array, std::vector<int>> std::tuple<std::vector<array>, array, std::vector<int>>
mlx_compute_scatter_args( mlx_compute_scatter_args(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype()); auto vals = to_array(v, src.dtype());
if (py::isinstance<py::slice>(obj)) { if (nb::isinstance<nb::slice>(obj)) {
return mlx_scatter_args_slice(src, obj, vals); return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
} else if (py::isinstance<array>(obj)) { } else if (nb::isinstance<array>(obj)) {
return mlx_scatter_args_array(src, py::cast<array>(obj), vals); return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
} else if (py::isinstance<py::int_>(obj)) { } else if (nb::isinstance<nb::int_>(obj)) {
return mlx_scatter_args_int(src, obj, vals); return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
} else if (py::isinstance<py::tuple>(obj)) { } else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_scatter_args_nd(src, obj, vals); return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
} else if (obj.is_none()) { } else if (obj.is_none()) {
return {{}, broadcast_to(vals, src.shape()), {}}; return {{}, broadcast_to(vals, src.shape()), {}};
} }
throw std::invalid_argument("Cannot index mlx array using the given type."); throw std::invalid_argument("Cannot index mlx array using the given type.");
} }
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) { void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) { if (indices.size() > 0) {
auto out = scatter(src, indices, updates, axes); auto out = scatter(src, indices, updates, axes);
@ -658,7 +666,7 @@ void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
array mlx_add_item( array mlx_add_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) { if (indices.size() > 0) {
@ -670,7 +678,7 @@ array mlx_add_item(
array mlx_subtract_item( array mlx_subtract_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) { if (indices.size() > 0) {
@ -682,7 +690,7 @@ array mlx_subtract_item(
array mlx_multiply_item( array mlx_multiply_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) { if (indices.size() > 0) {
@ -694,7 +702,7 @@ array mlx_multiply_item(
array mlx_divide_item( array mlx_divide_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) { if (indices.size() > 0) {
@ -706,7 +714,7 @@ array mlx_divide_item(
array mlx_maximum_item( array mlx_maximum_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) { if (indices.size() > 0) {
@ -718,7 +726,7 @@ array mlx_maximum_item(
array mlx_minimum_item( array mlx_minimum_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) { if (indices.size() > 0) {

View File

@ -1,38 +1,38 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include "mlx/array.h" #include "mlx/array.h"
#include "python/src/utils.h" #include "python/src/utils.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace mlx::core; using namespace mlx::core;
array mlx_get_item(const array& src, const py::object& obj); array mlx_get_item(const array& src, const nb::object& obj);
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v); void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v);
array mlx_add_item( array mlx_add_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_subtract_item( array mlx_subtract_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_multiply_item( array mlx_multiply_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_divide_item( array mlx_divide_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_maximum_item( array mlx_maximum_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);
array mlx_minimum_item( array mlx_minimum_item(
const array& src, const array& src,
const py::object& obj, const nb::object& obj,
const ScalarOrArray& v); const ScalarOrArray& v);

View File

@ -1,32 +1,29 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <variant> #include <variant>
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <pybind11/stl.h> #include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "python/src/load.h" namespace nb = nanobind;
#include "python/src/utils.h" using namespace nb::literals;
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core; using namespace mlx::core;
using namespace mlx::core::linalg; using namespace mlx::core::linalg;
namespace { namespace {
py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) { nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
const auto result = svd(a, s); const auto result = svd(a, s);
return py::make_tuple(result.at(0), result.at(1), result.at(2)); return nb::make_tuple(result.at(0), result.at(1), result.at(2));
} }
} // namespace } // namespace
void init_linalg(py::module_& parent_module) { void init_linalg(nb::module_& parent_module) {
py::options options;
options.disable_function_signatures();
auto m = parent_module.def_submodule( auto m = parent_module.def_submodule(
"linalg", "mlx.core.linalg: linear algebra routines."); "linalg", "mlx.core.linalg: linear algebra routines.");
@ -59,16 +56,15 @@ void init_linalg(py::module_& parent_module) {
return norm(a, ord, axis, keepdims, stream); return norm(a, ord, axis, keepdims, stream);
} }
}, },
"a"_a, nb::arg(),
py::pos_only(), "ord"_a = nb::none(),
"ord"_a = none, "axis"_a = nb::none(),
"axis"_a = none,
"keepdims"_a = false, "keepdims"_a = false,
py::kw_only(), nb::kw_only(),
"stream"_a = none, "stream"_a = nb::none(),
nb::sig(
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Matrix or vector norm. Matrix or vector norm.
This function computes vector or matrix norms depending on the value of This function computes vector or matrix norms depending on the value of
@ -188,11 +184,11 @@ void init_linalg(py::module_& parent_module) {
"qr", "qr",
&qr, &qr,
"a"_a, "a"_a,
py::kw_only(), nb::kw_only(),
"stream"_a = none, "stream"_a = nb::none(),
nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
R"pbdoc( R"pbdoc(
qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
The QR factorization of the input matrix. The QR factorization of the input matrix.
This function supports arrays with at least 2 dimensions. The matrices This function supports arrays with at least 2 dimensions. The matrices
@ -221,11 +217,11 @@ void init_linalg(py::module_& parent_module) {
"svd", "svd",
&svd_helper, &svd_helper,
"a"_a, "a"_a,
py::kw_only(), nb::kw_only(),
"stream"_a = none, "stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
R"pbdoc( R"pbdoc(
svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)
The Singular Value Decomposition (SVD) of the input matrix. The Singular Value Decomposition (SVD) of the input matrix.
This function supports arrays with at least 2 dimensions. When the input This function supports arrays with at least 2 dimensions. When the input
@ -245,11 +241,11 @@ void init_linalg(py::module_& parent_module) {
"inv", "inv",
&inv, &inv,
"a"_a, "a"_a,
py::kw_only(), nb::kw_only(),
"stream"_a = none, "stream"_a = nb::none(),
nb::sig(
"def inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array
Compute the inverse of a square matrix. Compute the inverse of a square matrix.
This function supports arrays with at least 2 dimensions. When the input This function supports arrays with at least 2 dimensions. When the input

View File

@ -1,8 +1,6 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/stl/vector.h>
#include <cstring> #include <cstring>
#include <fstream> #include <fstream>
#include <stdexcept> #include <stdexcept>
@ -16,39 +14,39 @@
#include "python/src/load.h" #include "python/src/load.h"
#include "python/src/utils.h" #include "python/src/utils.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Helpers // Helpers
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
bool is_istream_object(const py::object& file) { bool is_istream_object(const nb::object& file) {
return py::hasattr(file, "readinto") && py::hasattr(file, "seek") && return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed"); nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
} }
bool is_ostream_object(const py::object& file) { bool is_ostream_object(const nb::object& file) {
return py::hasattr(file, "write") && py::hasattr(file, "seek") && return nb::hasattr(file, "write") && nb::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed"); nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
} }
bool is_zip_file(const py::module_& zipfile, const py::object& file) { bool is_zip_file(const nb::module_& zipfile, const nb::object& file) {
if (is_istream_object(file)) { if (is_istream_object(file)) {
auto st_pos = file.attr("tell")(); auto st_pos = file.attr("tell")();
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>(); bool r = nb::cast<bool>(zipfile.attr("is_zipfile")(file));
file.attr("seek")(st_pos, 0); file.attr("seek")(st_pos, 0);
return r; return r;
} }
return zipfile.attr("is_zipfile")(file).cast<bool>(); return nb::cast<bool>(zipfile.attr("is_zipfile")(file));
} }
class ZipFileWrapper { class ZipFileWrapper {
public: public:
ZipFileWrapper( ZipFileWrapper(
const py::module_& zipfile, const nb::module_& zipfile,
const py::object& file, const nb::object& file,
char mode = 'r', char mode = 'r',
int compression = 0) int compression = 0)
: zipfile_module_(zipfile), : zipfile_module_(zipfile),
@ -63,10 +61,10 @@ class ZipFileWrapper {
close_func_(zipfile_object_.attr("close")) {} close_func_(zipfile_object_.attr("close")) {}
std::vector<std::string> namelist() const { std::vector<std::string> namelist() const {
return files_list_.cast<std::vector<std::string>>(); return nb::cast<std::vector<std::string>>(files_list_);
} }
py::object open(const std::string& key, char mode = 'r') { nb::object open(const std::string& key, char mode = 'r') {
// Following numpy : // Following numpy :
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47 // https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
if (mode == 'w') { if (mode == 'w') {
@ -76,12 +74,12 @@ class ZipFileWrapper {
} }
private: private:
py::module_ zipfile_module_; nb::module_ zipfile_module_;
py::object zipfile_object_; nb::object zipfile_object_;
py::list files_list_; nb::list files_list_;
py::object open_func_; nb::object open_func_;
py::object read_func_; nb::object read_func_;
py::object close_func_; nb::object close_func_;
}; };
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -90,14 +88,14 @@ class ZipFileWrapper {
class PyFileReader : public io::Reader { class PyFileReader : public io::Reader {
public: public:
PyFileReader(py::object file) PyFileReader(nb::object file)
: pyistream_(file), : pyistream_(file),
readinto_func_(file.attr("readinto")), readinto_func_(file.attr("readinto")),
seek_func_(file.attr("seek")), seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {} tell_func_(file.attr("tell")) {}
~PyFileReader() { ~PyFileReader() {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
pyistream_.release().dec_ref(); pyistream_.release().dec_ref();
readinto_func_.release().dec_ref(); readinto_func_.release().dec_ref();
@ -108,8 +106,8 @@ class PyFileReader : public io::Reader {
bool is_open() const override { bool is_open() const override {
bool out; bool out;
{ {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
out = !pyistream_.attr("closed").cast<bool>(); out = !nb::cast<bool>(pyistream_.attr("closed"));
} }
return out; return out;
} }
@ -117,7 +115,7 @@ class PyFileReader : public io::Reader {
bool good() const override { bool good() const override {
bool out; bool out;
{ {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
out = !pyistream_.is_none(); out = !pyistream_.is_none();
} }
return out; return out;
@ -126,25 +124,24 @@ class PyFileReader : public io::Reader {
size_t tell() const override { size_t tell() const override {
size_t out; size_t out;
{ {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>(); out = nb::cast<size_t>(tell_func_());
} }
return out; return out;
} }
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override { override {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
seek_func_(off, (int)way); seek_func_(off, (int)way);
} }
void read(char* data, size_t n) override { void read(char* data, size_t n) override {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
nb::object bytes_read = readinto_func_(nb::handle(memview));
py::object bytes_read = if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
throw std::runtime_error("[load] Failed to read from python stream"); throw std::runtime_error("[load] Failed to read from python stream");
} }
} }
@ -154,23 +151,23 @@ class PyFileReader : public io::Reader {
} }
private: private:
py::object pyistream_; nb::object pyistream_;
py::object readinto_func_; nb::object readinto_func_;
py::object seek_func_; nb::object seek_func_;
py::object tell_func_; nb::object tell_func_;
}; };
std::pair< std::pair<
std::unordered_map<std::string, array>, std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string>> std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) { mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(py::cast<std::string>(file), s); return load_safetensors(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s); auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
{ {
py::gil_scoped_release gil; nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) { for (auto& [key, arr] : std::get<0>(res)) {
arr.eval(); arr.eval();
} }
@ -182,20 +179,20 @@ mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
"[load_safetensors] Input must be a file-like object, or string"); "[load_safetensors] Input must be a file-like object, or string");
} }
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) { GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(py::cast<std::string>(file), s); return load_gguf(nb::cast<std::string>(file), s);
} }
throw std::invalid_argument("[load_gguf] Input must be a string"); throw std::invalid_argument("[load_gguf] Input must be a string");
} }
std::unordered_map<std::string, array> mlx_load_npz_helper( std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file, nb::object file,
StreamOrDevice s) { StreamOrDevice s) {
bool own_file = py::isinstance<py::str>(file); bool own_file = nb::isinstance<nb::str>(file);
py::module_ zipfile = py::module_::import("zipfile"); nb::module_ zipfile = nb::module_::import_("zipfile");
if (!is_zip_file(zipfile, file)) { if (!is_zip_file(zipfile, file)) {
throw std::invalid_argument( throw std::invalid_argument(
"[load_npz] Input must be a zip file or a file-like object that can be " "[load_npz] Input must be a zip file or a file-like object that can be "
@ -208,7 +205,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
ZipFileWrapper zipfile_object(zipfile, file); ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) { for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream // Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st); nb::object sub_file = zipfile_object.open(st);
// Create array from python file stream // Create array from python file stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s); auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
@ -224,7 +221,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
if (!own_file) { if (!own_file) {
py::gil_scoped_release gil; nb::gil_scoped_release gil;
for (auto& [key, arr] : array_dict) { for (auto& [key, arr] : array_dict) {
arr.eval(); arr.eval();
} }
@ -233,14 +230,14 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
return array_dict; return array_dict;
} }
array mlx_load_npy_helper(py::object file, StreamOrDevice s) { array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .npy file path string if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
return load(py::cast<std::string>(file), s); return load(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately // If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s); auto arr = load(std::make_shared<PyFileReader>(file), s);
{ {
py::gil_scoped_release gil; nb::gil_scoped_release gil;
arr.eval(); arr.eval();
} }
return arr; return arr;
@ -250,16 +247,16 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
} }
LoadOutputTypes mlx_load_helper( LoadOutputTypes mlx_load_helper(
py::object file, nb::object file,
std::optional<std::string> format, std::optional<std::string> format,
bool return_metadata, bool return_metadata,
StreamOrDevice s) { StreamOrDevice s) {
if (!format.has_value()) { if (!format.has_value()) {
std::string fname; std::string fname;
if (py::isinstance<py::str>(file)) { if (nb::isinstance<nb::str>(file)) {
fname = py::cast<std::string>(file); fname = nb::cast<std::string>(file);
} else if (is_istream_object(file)) { } else if (is_istream_object(file)) {
fname = file.attr("name").cast<std::string>(); fname = nb::cast<std::string>(file.attr("name"));
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[load] Input must be a file-like object opened in binary mode, or string"); "[load] Input must be a file-like object opened in binary mode, or string");
@ -304,14 +301,14 @@ LoadOutputTypes mlx_load_helper(
class PyFileWriter : public io::Writer { class PyFileWriter : public io::Writer {
public: public:
PyFileWriter(py::object file) PyFileWriter(nb::object file)
: pyostream_(file), : pyostream_(file),
write_func_(file.attr("write")), write_func_(file.attr("write")),
seek_func_(file.attr("seek")), seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {} tell_func_(file.attr("tell")) {}
~PyFileWriter() { ~PyFileWriter() {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
pyostream_.release().dec_ref(); pyostream_.release().dec_ref();
write_func_.release().dec_ref(); write_func_.release().dec_ref();
@ -322,8 +319,8 @@ class PyFileWriter : public io::Writer {
bool is_open() const override { bool is_open() const override {
bool out; bool out;
{ {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
out = !pyostream_.attr("closed").cast<bool>(); out = !nb::cast<bool>(pyostream_.attr("closed"));
} }
return out; return out;
} }
@ -331,7 +328,7 @@ class PyFileWriter : public io::Writer {
bool good() const override { bool good() const override {
bool out; bool out;
{ {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
out = !pyostream_.is_none(); out = !pyostream_.is_none();
} }
return out; return out;
@ -340,25 +337,26 @@ class PyFileWriter : public io::Writer {
size_t tell() const override { size_t tell() const override {
size_t out; size_t out;
{ {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>(); out = nb::cast<size_t>(tell_func_());
} }
return out; return out;
} }
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override { override {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
seek_func_(off, (int)way); seek_func_(off, (int)way);
} }
void write(const char* data, size_t n) override { void write(const char* data, size_t n) override {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
py::object bytes_written = auto memview =
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ);
nb::object bytes_written = write_func_(nb::handle(memview));
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) { if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {
throw std::runtime_error("[load] Failed to write to python stream"); throw std::runtime_error("[load] Failed to write to python stream");
} }
} }
@ -368,20 +366,20 @@ class PyFileWriter : public io::Writer {
} }
private: private:
py::object pyostream_; nb::object pyostream_;
py::object write_func_; nb::object write_func_;
py::object seek_func_; nb::object seek_func_;
py::object tell_func_; nb::object tell_func_;
}; };
void mlx_save_helper(py::object file, array a) { void mlx_save_helper(nb::object file, array a) {
if (py::isinstance<py::str>(file)) { if (nb::isinstance<nb::str>(file)) {
save(py::cast<std::string>(file), a); save(nb::cast<std::string>(file), a);
return; return;
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
py::gil_scoped_release gil; nb::gil_scoped_release gil;
save(writer, a); save(writer, a);
} }
@ -393,26 +391,26 @@ void mlx_save_helper(py::object file, array a) {
} }
void mlx_savez_helper( void mlx_savez_helper(
py::object file_, nb::object file_,
py::args args, nb::args args,
const py::kwargs& kwargs, const nb::kwargs& kwargs,
bool compressed) { bool compressed) {
// Add .npz to the end of the filename if not already there // Add .npz to the end of the filename if not already there
py::object file = file_; nb::object file = file_;
if (py::isinstance<py::str>(file_)) { if (nb::isinstance<nb::str>(file_)) {
std::string fname = file_.cast<std::string>(); std::string fname = nb::cast<std::string>(file_);
// Add .npz to file name if it is not there // Add .npz to file name if it is not there
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
fname += ".npz"; fname += ".npz";
file = py::str(fname); file = nb::cast(fname);
} }
// Collect args and kwargs // Collect args and kwargs
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>(); auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
auto arrays_list = args.cast<std::vector<array>>(); auto arrays_list = nb::cast<std::vector<array>>(args);
for (int i = 0; i < arrays_list.size(); i++) { for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i); std::string arr_name = "arr_" + std::to_string(i);
@ -426,9 +424,9 @@ void mlx_savez_helper(
} }
// Create python ZipFile object depending on compression // Create python ZipFile object depending on compression
py::module_ zipfile = py::module_::import("zipfile"); nb::module_ zipfile = nb::module_::import_("zipfile");
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>() int compression = nb::cast<int>(
: zipfile.attr("ZIP_STORED").cast<int>(); compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED"));
char mode = 'w'; char mode = 'w';
ZipFileWrapper zipfile_object(zipfile, file, mode, compression); ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
@ -438,7 +436,7 @@ void mlx_savez_helper(
auto py_ostream = zipfile_object.open(fname, 'w'); auto py_ostream = zipfile_object.open(fname, 'w');
auto writer = std::make_shared<PyFileWriter>(py_ostream); auto writer = std::make_shared<PyFileWriter>(py_ostream);
{ {
py::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save(writer, a); save(writer, a);
} }
} }
@ -447,31 +445,31 @@ void mlx_savez_helper(
} }
void mlx_save_safetensor_helper( void mlx_save_safetensor_helper(
py::object file, nb::object file,
py::dict d, nb::dict d,
std::optional<py::dict> m) { std::optional<nb::dict> m) {
std::unordered_map<std::string, std::string> metadata_map; std::unordered_map<std::string, std::string> metadata_map;
if (m) { if (m) {
try { try {
metadata_map = metadata_map =
m.value().cast<std::unordered_map<std::string, std::string>>(); nb::cast<std::unordered_map<std::string, std::string>>(m.value());
} catch (const py::cast_error& e) { } catch (const nb::cast_error& e) {
throw std::invalid_argument( throw std::invalid_argument(
"[save_safetensors] Metadata must be a dictionary with string keys and values"); "[save_safetensors] Metadata must be a dictionary with string keys and values");
} }
} else { } else {
metadata_map = std::unordered_map<std::string, std::string>(); metadata_map = std::unordered_map<std::string, std::string>();
} }
auto arrays_map = d.cast<std::unordered_map<std::string, array>>(); auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
if (py::isinstance<py::str>(file)) { if (nb::isinstance<nb::str>(file)) {
{ {
py::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map); save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
} }
} else if (is_ostream_object(file)) { } else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file); auto writer = std::make_shared<PyFileWriter>(file);
{ {
py::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_safetensors(writer, arrays_map, metadata_map); save_safetensors(writer, arrays_map, metadata_map);
} }
} else { } else {
@ -481,22 +479,22 @@ void mlx_save_safetensor_helper(
} }
void mlx_save_gguf_helper( void mlx_save_gguf_helper(
py::object file, nb::object file,
py::dict a, nb::dict a,
std::optional<py::dict> m) { std::optional<nb::dict> m) {
auto arrays_map = a.cast<std::unordered_map<std::string, array>>(); auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
if (py::isinstance<py::str>(file)) { if (nb::isinstance<nb::str>(file)) {
if (m) { if (m) {
auto metadata_map = auto metadata_map =
m.value().cast<std::unordered_map<std::string, GGUFMetaData>>(); nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
{ {
py::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map); save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
} }
} else { } else {
{ {
py::gil_scoped_release nogil; nb::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map); save_gguf(nb::cast<std::string>(file), arrays_map);
} }
} }
} else { } else {

View File

@ -1,15 +1,20 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <optional> #include <optional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <variant> #include <variant>
#include "mlx/io.h" #include "mlx/io.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace mlx::core; using namespace mlx::core;
using LoadOutputTypes = std::variant< using LoadOutputTypes = std::variant<
@ -18,27 +23,27 @@ using LoadOutputTypes = std::variant<
SafetensorsLoad, SafetensorsLoad,
GGUFLoad>; GGUFLoad>;
SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s); SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s);
void mlx_save_safetensor_helper( void mlx_save_safetensor_helper(
py::object file, nb::object file,
py::dict d, nb::dict d,
std::optional<py::dict> m); std::optional<nb::dict> m);
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s); GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s);
void mlx_save_gguf_helper( void mlx_save_gguf_helper(
py::object file, nb::object file,
py::dict d, nb::dict d,
std::optional<py::dict> m); std::optional<nb::dict> m);
LoadOutputTypes mlx_load_helper( LoadOutputTypes mlx_load_helper(
py::object file, nb::object file,
std::optional<std::string> format, std::optional<std::string> format,
bool return_metadata, bool return_metadata,
StreamOrDevice s); StreamOrDevice s);
void mlx_save_helper(py::object file, array a); void mlx_save_helper(nb::object file, array a);
void mlx_savez_helper( void mlx_savez_helper(
py::object file, nb::object file,
py::args args, nb::args args,
const py::kwargs& kwargs, const nb::kwargs& kwargs,
bool compressed = false); bool compressed = false);

View File

@ -1,16 +1,15 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include <nanobind/nanobind.h>
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
void init_metal(py::module_& m) { void init_metal(nb::module_& m) {
py::module_ metal = m.def_submodule("metal", "mlx.metal"); nb::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def( metal.def(
"is_available", "is_available",
&metal::is_available, &metal::is_available,
@ -48,7 +47,7 @@ void init_metal(py::module_& m) {
"set_memory_limit", "set_memory_limit",
&metal::set_memory_limit, &metal::set_memory_limit,
"limit"_a, "limit"_a,
py::kw_only(), nb::kw_only(),
"relaxed"_a = true, "relaxed"_a = true,
R"pbdoc( R"pbdoc(
Set the memory limit. Set the memory limit.

View File

@ -1,30 +1,30 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#define STRINGIFY(x) #x #define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x) #define TOSTRING(x) STRINGIFY(x)
namespace py = pybind11; namespace nb = nanobind;
void init_array(py::module_&); void init_array(nb::module_&);
void init_device(py::module_&); void init_device(nb::module_&);
void init_stream(py::module_&); void init_stream(nb::module_&);
void init_metal(py::module_&); void init_metal(nb::module_&);
void init_ops(py::module_&); void init_ops(nb::module_&);
void init_transforms(py::module_&); void init_transforms(nb::module_&);
void init_random(py::module_&); void init_random(nb::module_&);
void init_fft(py::module_&); void init_fft(nb::module_&);
void init_linalg(py::module_&); void init_linalg(nb::module_&);
void init_constants(py::module_&); void init_constants(nb::module_&);
void init_extensions(py::module_&); void init_fast(nb::module_&);
void init_utils(py::module_&);
PYBIND11_MODULE(core, m) { NB_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon."; m.doc() = "mlx: A framework for machine learning on Apple silicon.";
auto reprlib_fix = py::module_::import("mlx._reprlib_fix"); auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix");
py::module_::import("mlx._os_warning"); nb::module_::import_("mlx._os_warning");
nb::set_leak_warnings(false);
init_device(m); init_device(m);
init_stream(m); init_stream(m);
@ -36,8 +36,7 @@ PYBIND11_MODULE(core, m) {
init_fft(m); init_fft(m);
init_linalg(m); init_linalg(m);
init_constants(m); init_constants(m);
init_extensions(m); init_fast(m);
init_utils(m);
m.attr("__version__") = TOSTRING(_VERSION_); m.attr("__version__") = TOSTRING(_VERSION_);
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,60 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
// A patch to get float16_t to work with pybind11 numpy arrays
// Derived from:
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679
#include <pybind11/numpy.h>
namespace pybind11::detail {
template <typename T>
struct npy_scalar_caster {
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
using Array = array_t<T>;
bool load(handle src, bool convert) {
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
handle type = dtype::of<T>().attr("type"); // Could make more efficient.
if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
return false;
Array tmp = Array::ensure(src);
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
this->value = *tmp.data();
return true;
}
return false;
}
static handle cast(T src, return_value_policy, handle) {
Array tmp({1});
tmp.mutable_at(0) = src;
tmp.resize({});
// You could also just return the array if you want a scalar array.
object scalar = tmp[tuple()];
return scalar.release();
}
};
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
// Kinda following:
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
template <>
struct npy_format_descriptor<float16_t> {
static constexpr auto name = _("float16");
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
};
template <>
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
static constexpr auto name = _("float16");
};
} // namespace pybind11::detail

View File

@ -1,7 +1,10 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <chrono> #include <chrono>
#include "python/src/utils.h" #include "python/src/utils.h"
@ -9,8 +12,8 @@
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/random.h" #include "mlx/random.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
using namespace mlx::core::random; using namespace mlx::core::random;
@ -25,22 +28,22 @@ class PyKeySequence {
} }
array next() { array next() {
auto out = split(py::cast<array>(state_[0])); auto out = split(nb::cast<array>(state_[0]));
state_[0] = out.first; state_[0] = out.first;
return out.second; return out.second;
} }
py::list state() { nb::list state() {
return state_; return state_;
} }
void release() { void release() {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
state_.release().dec_ref(); state_.release().dec_ref();
} }
private: private:
py::list state_; nb::list state_;
}; };
PyKeySequence& default_key() { PyKeySequence& default_key() {
@ -54,7 +57,7 @@ PyKeySequence& default_key() {
return ks; return ks;
} }
void init_random(py::module_& parent_module) { void init_random(nb::module_& parent_module) {
auto m = parent_module.def_submodule( auto m = parent_module.def_submodule(
"random", "random",
"mlx.core.random: functionality related to random number generation"); "mlx.core.random: functionality related to random number generation");
@ -85,10 +88,10 @@ void init_random(py::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"split", "split",
py::overload_cast<const array&, int, StreamOrDevice>(&random::split), nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
"key"_a, "key"_a,
"num"_a = 2, "num"_a = 2,
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Split a PRNG key into sub keys. Split a PRNG key into sub keys.
@ -119,9 +122,9 @@ void init_random(py::module_& parent_module) {
"low"_a = 0, "low"_a = 0,
"high"_a = 1, "high"_a = 1,
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32}, "dtype"_a.none() = float32,
"key"_a = none, "key"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Generate uniformly distributed random numbers. Generate uniformly distributed random numbers.
@ -151,11 +154,11 @@ void init_random(py::module_& parent_module) {
return normal(shape, type.value_or(float32), loc, scale, key, s); return normal(shape, type.value_or(float32), loc, scale, key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32}, "dtype"_a.none() = float32,
"loc"_a = 0.0, "loc"_a = 0.0,
"scale"_a = 1.0, "scale"_a = 1.0,
"key"_a = none, "key"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Generate normally distributed random numbers. Generate normally distributed random numbers.
@ -184,9 +187,9 @@ void init_random(py::module_& parent_module) {
"low"_a, "low"_a,
"high"_a, "high"_a,
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = int32, "dtype"_a.none() = int32,
"key"_a = none, "key"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Generate random integers from the given interval. Generate random integers from the given interval.
@ -219,9 +222,9 @@ void init_random(py::module_& parent_module) {
} }
}, },
"p"_a = 0.5, "p"_a = 0.5,
"shape"_a = none, "shape"_a = nb::none(),
"key"_a = none, "key"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Generate Bernoulli random values. Generate Bernoulli random values.
@ -259,10 +262,10 @@ void init_random(py::module_& parent_module) {
}, },
"lower"_a, "lower"_a,
"upper"_a, "upper"_a,
"shape"_a = none, "shape"_a = nb::none(),
"dtype"_a = std::optional{float32}, "dtype"_a.none() = float32,
"key"_a = none, "key"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Generate values from a truncated normal distribution. Generate values from a truncated normal distribution.
@ -292,9 +295,9 @@ void init_random(py::module_& parent_module) {
return gumbel(shape, type.value_or(float32), key, s); return gumbel(shape, type.value_or(float32), key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32}, "dtype"_a.none() = float32,
"stream"_a = none, "stream"_a = nb::none(),
"key"_a = none, "key"_a = nb::none(),
R"pbdoc( R"pbdoc(
Sample from the standard Gumbel distribution. Sample from the standard Gumbel distribution.
@ -331,10 +334,10 @@ void init_random(py::module_& parent_module) {
}, },
"logits"_a, "logits"_a,
"axis"_a = -1, "axis"_a = -1,
"shape"_a = none, "shape"_a = nb::none(),
"num_samples"_a = none, "num_samples"_a = nb::none(),
"key"_a = none, "key"_a = nb::none(),
"stream"_a = none, "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
Sample from a categorical distribution. Sample from a categorical distribution.
@ -359,6 +362,6 @@ void init_random(py::module_& parent_module) {
array: The ``shape``-sized output array with type ``uint32``. array: The ``shape``-sized output array with type ``uint32``.
)pbdoc"); )pbdoc");
// Register static Python object cleanup before the interpreter exits // Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit"); auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(py::cpp_function([]() { default_key().release(); })); atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
} }

View File

@ -1,25 +1,54 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <sstream> #include <sstream>
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include "mlx/stream.h" #include "mlx/stream.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
void init_stream(py::module_& m) { // Create the StreamContext on enter and delete on exit.
py::class_<Stream>( class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
StreamOrDevice _s;
StreamContext* _inner;
};
void init_stream(nb::module_& m) {
nb::class_<Stream>(
m, m,
"Stream", "Stream",
R"pbdoc( R"pbdoc(
A stream for running operations on a given device. A stream for running operations on a given device.
)pbdoc") )pbdoc")
.def(py::init<int, Device>(), "index"_a, "device"_a) .def(nb::init<int, Device>(), "index"_a, "device"_a)
.def_readonly("device", &Stream::device) .def_ro("device", &Stream::device)
.def( .def(
"__repr__", "__repr__",
[](const Stream& s) { [](const Stream& s) {
@ -31,7 +60,7 @@ void init_stream(py::module_& m) {
return s1 == s2; return s1 == s2;
}); });
py::implicitly_convertible<Device::DeviceType, Device>(); nb::implicitly_convertible<Device::DeviceType, Device>();
m.def( m.def(
"default_stream", "default_stream",
@ -56,4 +85,48 @@ void init_stream(py::module_& m) {
&new_stream, &new_stream,
"device"_a, "device"_a,
R"pbdoc(Make a new stream on the given device.)pbdoc"); R"pbdoc(Make a new stream on the given device.)pbdoc");
nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.
See :func:`stream` for usage.
Args:
s: The stream or device to set as the default.
)pbdoc")
.def(nb::init<StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def(
"__exit__",
[](PyStreamContext& scm,
const std::optional<nb::type_object>& exc_type,
const std::optional<nb::object>& exc_value,
const std::optional<nb::object>& traceback) { scm.exit(); },
"exc_type"_a = nb::none(),
"exc_value"_a = nb::none(),
"traceback"_a = nb::none());
m.def(
"stream",
[](StreamOrDevice s) { return PyStreamContext(s); },
"s"_a,
R"pbdoc(
Create a context manager to set the default device and stream.
Args:
s: The :obj:`Stream` or :obj:`Device` to set as the default.
Returns:
A context manager that sets the default device and stream.
Example:
.. code-block:: python
import mlx.core as mx
# Create a context manager for the default device and stream.
with mx.stream(mx.cpu):
# Operations here will use mx.cpu by default.
pass
)pbdoc");
} }

View File

@ -1,6 +1,11 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <pybind11/stl.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <numeric> #include <numeric>
@ -13,13 +18,17 @@
#include "mlx/transforms_impl.h" #include "mlx/transforms_impl.h"
#include "python/src/trees.h" #include "python/src/trees.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace py::literals; using namespace nb::literals;
using namespace mlx::core; using namespace mlx::core;
using IntOrVec = std::variant<int, std::vector<int>>; using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>; using StrOrVec = std::variant<std::string, std::vector<std::string>>;
inline std::string type_name_str(const nb::handle& o) {
return nb::cast<std::string>(nb::type_name(o.type()));
}
template <typename T> template <typename T>
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) { std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
std::vector<T> vals; std::vector<T> vals;
@ -49,7 +58,7 @@ auto validate_argnums_argnames(
} }
auto py_value_and_grad( auto py_value_and_grad(
const py::function& fun, const nb::callable& fun,
std::vector<int> argnums, std::vector<int> argnums,
std::vector<std::string> argnames, std::vector<std::string> argnames,
const std::string& error_msg_tag, const std::string& error_msg_tag,
@ -71,7 +80,7 @@ auto py_value_and_grad(
} }
return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
const py::args& args, const py::kwargs& kwargs) { const nb::args& args, const nb::kwargs& kwargs) {
// Sanitize the input // Sanitize the input
if (argnums.size() > 0 && argnums.back() >= args.size()) { if (argnums.size() > 0 && argnums.back() >= args.size()) {
std::ostringstream msg; std::ostringstream msg;
@ -89,7 +98,7 @@ auto py_value_and_grad(
<< "' because the function is called with the " << "' because the function is called with the "
<< "following keyword arguments {"; << "following keyword arguments {";
for (auto item : kwargs) { for (auto item : kwargs) {
msg << item.first.cast<std::string>() << ","; msg << nb::cast<std::string>(item.first) << ",";
} }
msg << "}"; msg << "}";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
@ -115,7 +124,7 @@ auto py_value_and_grad(
// value_out will hold the output of the python function in order to be // value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values // able to reconstruct the python tree of extra return values
py::object py_value_out; nb::object py_value_out;
auto value_and_grads = value_and_grad( auto value_and_grads = value_and_grad(
[&fun, [&fun,
&args, &args,
@ -127,15 +136,15 @@ auto py_value_and_grad(
&error_msg_tag, &error_msg_tag,
scalar_func_only](const std::vector<array>& a) { scalar_func_only](const std::vector<array>& a) {
// Copy the arguments // Copy the arguments
py::args args_cpy = py::tuple(args.size()); nb::list args_cpy;
py::kwargs kwargs_cpy = py::kwargs(); nb::kwargs kwargs_cpy = nb::kwargs();
int j = 0; int j = 0;
for (int i = 0; i < args.size(); ++i) { for (int i = 0; i < args.size(); ++i) {
if (j < argnums.size() && i == argnums[j]) { if (j < argnums.size() && i == argnums[j]) {
args_cpy[i] = tree_unflatten(args[i], a, counts[j]); args_cpy.append(tree_unflatten(args[i], a, counts[j]));
j++; j++;
} else { } else {
args_cpy[i] = args[i]; args_cpy.append(args[i]);
} }
} }
for (auto& key : argnames) { for (auto& key : argnames) {
@ -154,25 +163,25 @@ auto py_value_and_grad(
py_value_out = fun(*args_cpy, **kwargs_cpy); py_value_out = fun(*args_cpy, **kwargs_cpy);
// Validate the return value of the python function // Validate the return value of the python function
if (!py::isinstance<array>(py_value_out)) { if (!nb::isinstance<array>(py_value_out)) {
if (scalar_func_only) { if (scalar_func_only) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be a " << "whose gradient we want to compute should be a "
<< "scalar array; but " << py_value_out.get_type() << "scalar array; but " << type_name_str(py_value_out)
<< " was returned."; << " was returned.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (!py::isinstance<py::tuple>(py_value_out)) { if (!nb::isinstance<nb::tuple>(py_value_out)) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a " << "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a " << "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but " << "scalar array (Union[array, Tuple[array, Any, ...]]); but "
<< py_value_out.get_type() << " was returned."; << type_name_str(py_value_out) << " was returned.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
py::tuple ret = py::cast<py::tuple>(py_value_out); nb::tuple ret = nb::cast<nb::tuple>(py_value_out);
if (ret.size() == 0) { if (ret.size() == 0) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
@ -182,14 +191,14 @@ auto py_value_and_grad(
<< "we got an empty tuple."; << "we got an empty tuple.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (!py::isinstance<array>(ret[0])) { if (!nb::isinstance<array>(ret[0])) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " The return value of the function " msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a " << "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a " << "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it " << "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
<< "was a tuple with the first value being of type " << "was a tuple with the first value being of type "
<< ret[0].get_type() << " ."; << type_name_str(ret[0]) << " .";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
} }
@ -212,61 +221,60 @@ auto py_value_and_grad(
// In case 2 we return a tuple of the above. // In case 2 we return a tuple of the above.
// In case 3 we return a tuple containing a tuple and dict (sth like // In case 3 we return a tuple containing a tuple and dict (sth like
// (tuple(), dict(x=mx.array(5))) ). // (tuple(), dict(x=mx.array(5))) ).
py::object positional_grads; nb::object positional_grads;
py::object keyword_grads; nb::object keyword_grads;
py::object py_grads; nb::object py_grads;
// Collect the gradients for the positional arguments // Collect the gradients for the positional arguments
if (argnums.size() == 1) { if (argnums.size() == 1) {
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]); positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
} else if (argnums.size() > 1) { } else if (argnums.size() > 1) {
py::tuple grads_(argnums.size()); nb::list grads_;
for (int i = 0; i < argnums.size(); i++) { for (int i = 0; i < argnums.size(); i++) {
grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]); grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
} }
positional_grads = py::cast<py::object>(grads_); positional_grads = nb::tuple(grads_);
} else { } else {
positional_grads = py::none(); positional_grads = nb::none();
} }
// No keyword argument gradients so return the tuple of gradients // No keyword argument gradients so return the tuple of gradients
if (argnames.size() == 0) { if (argnames.size() == 0) {
py_grads = positional_grads; py_grads = positional_grads;
} else { } else {
py::dict grads_; nb::dict grads_;
for (int i = 0; i < argnames.size(); i++) { for (int i = 0; i < argnames.size(); i++) {
auto& k = argnames[i]; auto& k = argnames[i];
grads_[k.c_str()] = tree_unflatten( grads_[k.c_str()] = tree_unflatten(
kwargs[k.c_str()], gradients, counts[i + argnums.size()]); kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
} }
keyword_grads = py::cast<py::object>(grads_); keyword_grads = grads_;
py_grads = py_grads = nb::make_tuple(positional_grads, keyword_grads);
py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
} }
// Put the values back in the container // Put the values back in the container
py::object return_value = tree_unflatten(py_value_out, value); nb::object return_value = tree_unflatten(py_value_out, value);
return std::make_pair(return_value, py_grads); return std::make_pair(return_value, py_grads);
}; };
} }
auto py_vmap( auto py_vmap(
const py::function& fun, const nb::callable& fun,
const py::object& in_axes, const nb::object& in_axes,
const py::object& out_axes) { const nb::object& out_axes) {
return [fun, in_axes, out_axes](const py::args& args) { return [fun, in_axes, out_axes](const nb::args& args) {
auto axes_to_flat_tree = [](const py::object& tree, auto axes_to_flat_tree = [](const nb::object& tree,
const py::object& axes) { const nb::object& axes) {
auto tree_axes = tree_map( auto tree_axes = tree_map(
{tree, axes}, {tree, axes},
[](const std::vector<py::object>& inputs) { return inputs[1]; }); [](const std::vector<nb::object>& inputs) { return inputs[1]; });
std::vector<int> flat_axes; std::vector<int> flat_axes;
tree_visit(tree_axes, [&flat_axes](py::handle obj) { tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
if (obj.is_none()) { if (obj.is_none()) {
flat_axes.push_back(-1); flat_axes.push_back(-1);
} else if (py::isinstance<py::int_>(obj)) { } else if (nb::isinstance<nb::int_>(obj)) {
flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj))); flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
} else { } else {
throw std::invalid_argument("[vmap] axis must be int or None."); throw std::invalid_argument("[vmap] axis must be int or None.");
} }
@ -280,7 +288,7 @@ auto py_vmap(
// py_value_out will hold the output of the python function in order to be // py_value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values // able to reconstruct the python tree of extra return values
py::object py_outputs; nb::object py_outputs;
auto vmap_fn = auto vmap_fn =
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) { [&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
@ -305,24 +313,24 @@ auto py_vmap(
}; };
} }
std::unordered_map<size_t, py::object>& tree_cache() { std::unordered_map<size_t, nb::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs // This map is used to Cache the tree structure of the outputs
static std::unordered_map<size_t, py::object> tree_cache_; static std::unordered_map<size_t, nb::object> tree_cache_;
return tree_cache_; return tree_cache_;
} }
struct PyCompiledFun { struct PyCompiledFun {
py::function fun; nb::callable fun;
size_t fun_id; size_t fun_id;
py::object captured_inputs; nb::object captured_inputs;
py::object captured_outputs; nb::object captured_outputs;
bool shapeless; bool shapeless;
size_t num_outputs{0}; mutable size_t num_outputs{0};
PyCompiledFun( PyCompiledFun(
const py::function& fun, const nb::callable& fun,
py::object inputs, nb::object inputs,
py::object outputs, nb::object outputs,
bool shapeless) bool shapeless)
: fun(fun), : fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())), fun_id(reinterpret_cast<size_t>(fun.ptr())),
@ -342,7 +350,7 @@ struct PyCompiledFun {
num_outputs = other.num_outputs; num_outputs = other.num_outputs;
}; };
py::object operator()(const py::args& args, const py::kwargs& kwargs) { nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
// Flat array inputs // Flat array inputs
std::vector<array> inputs; std::vector<array> inputs;
@ -358,45 +366,45 @@ struct PyCompiledFun {
constexpr uint64_t dict_identifier = 18446744073709551521UL; constexpr uint64_t dict_identifier = 18446744073709551521UL;
// Flatten the tree with hashed constants and structure // Flatten the tree with hashed constants and structure
std::function<void(py::handle)> recurse; std::function<void(nb::handle)> recurse;
recurse = [&](py::handle obj) { recurse = [&](nb::handle obj) {
if (py::isinstance<py::list>(obj)) { if (nb::isinstance<nb::list>(obj)) {
auto l = py::cast<py::list>(obj); auto l = nb::cast<nb::list>(obj);
constants.push_back(list_identifier); constants.push_back(list_identifier);
for (int i = 0; i < l.size(); ++i) { for (int i = 0; i < l.size(); ++i) {
recurse(l[i]); recurse(l[i]);
} }
} else if (py::isinstance<py::tuple>(obj)) { } else if (nb::isinstance<nb::tuple>(obj)) {
auto l = py::cast<py::tuple>(obj); auto l = nb::cast<nb::tuple>(obj);
constants.push_back(list_identifier); constants.push_back(list_identifier);
for (auto item : obj) { for (auto item : obj) {
recurse(item); recurse(item);
} }
} else if (py::isinstance<py::dict>(obj)) { } else if (nb::isinstance<nb::dict>(obj)) {
auto d = py::cast<py::dict>(obj); auto d = nb::cast<nb::dict>(obj);
constants.push_back(dict_identifier); constants.push_back(dict_identifier);
for (auto item : d) { for (auto item : d) {
auto r = py::hash(item.first); auto r = item.first.attr("__hash__");
constants.push_back(*reinterpret_cast<uint64_t*>(&r)); constants.push_back(*reinterpret_cast<uint64_t*>(&r));
recurse(item.second); recurse(item.second);
} }
} else if (py::isinstance<array>(obj)) { } else if (nb::isinstance<array>(obj)) {
inputs.push_back(py::cast<array>(obj)); inputs.push_back(nb::cast<array>(obj));
constants.push_back(array_identifier); constants.push_back(array_identifier);
} else if (py::isinstance<py::str>(obj)) { } else if (nb::isinstance<nb::str>(obj)) {
auto r = py::hash(obj); auto r = obj.attr("__hash__");
constants.push_back(*reinterpret_cast<uint64_t*>(&r)); constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::int_>(obj)) { } else if (nb::isinstance<nb::int_>(obj)) {
auto r = obj.cast<int64_t>(); auto r = nb::cast<int64_t>(obj);
constants.push_back(*reinterpret_cast<uint64_t*>(&r)); constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::float_>(obj)) { } else if (nb::isinstance<nb::float_>(obj)) {
auto r = obj.cast<double>(); auto r = nb::cast<double>(obj);
constants.push_back(*reinterpret_cast<uint64_t*>(&r)); constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "[compile] Function arguments must be trees of arrays " msg << "[compile] Function arguments must be trees of arrays "
<< "or constants (floats, ints, or strings), but received " << "or constants (floats, ints, or strings), but received "
<< "type " << obj.get_type() << "."; << "type " << type_name_str(obj) << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
}; };
@ -404,13 +412,12 @@ struct PyCompiledFun {
recurse(args); recurse(args);
int num_args = inputs.size(); int num_args = inputs.size();
recurse(kwargs); recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args]( auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) { const std::vector<array>& a) {
// Put tracers into captured inputs // Put tracers into captured inputs
std::vector<array> flat_in_captures; std::vector<array> flat_in_captures;
std::vector<array> trace_captures; std::vector<array> trace_captures;
if (!py::isinstance<py::none>(captured_inputs)) { if (!captured_inputs.is_none()) {
flat_in_captures = tree_flatten(captured_inputs, false); flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert( trace_captures.insert(
trace_captures.end(), a.end() - flat_in_captures.size(), a.end()); trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
@ -425,7 +432,7 @@ struct PyCompiledFun {
tree_cache().insert({fun_id, py_outputs}); tree_cache().insert({fun_id, py_outputs});
num_outputs = outputs.size(); num_outputs = outputs.size();
if (!py::isinstance<py::none>(captured_outputs)) { if (!captured_outputs.is_none()) {
auto flat_out_captures = tree_flatten(captured_outputs, false); auto flat_out_captures = tree_flatten(captured_outputs, false);
outputs.insert( outputs.insert(
outputs.end(), outputs.end(),
@ -434,13 +441,13 @@ struct PyCompiledFun {
} }
// Replace tracers with originals in captured inputs // Replace tracers with originals in captured inputs
if (!py::isinstance<py::none>(captured_inputs)) { if (!captured_inputs.is_none()) {
tree_replace(captured_inputs, trace_captures, flat_in_captures); tree_replace(captured_inputs, trace_captures, flat_in_captures);
} }
return outputs; return outputs;
}; };
if (!py::isinstance<py::none>(captured_inputs)) { if (!captured_inputs.is_none()) {
auto flat_in_captures = tree_flatten(captured_inputs, false); auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert( inputs.insert(
inputs.end(), inputs.end(),
@ -451,7 +458,7 @@ struct PyCompiledFun {
// Compile and call // Compile and call
auto outputs = auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) { if (!captured_outputs.is_none()) {
std::vector<array> captures( std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs), std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end())); std::make_move_iterator(outputs.end()));
@ -459,12 +466,16 @@ struct PyCompiledFun {
} }
// Put the outputs back in the container // Put the outputs back in the container
py::object py_outputs = tree_cache().at(fun_id); nb::object py_outputs = tree_cache().at(fun_id);
return tree_unflatten_from_structure(py_outputs, outputs); return tree_unflatten_from_structure(py_outputs, outputs);
}
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
return const_cast<PyCompiledFun*>(this)->call_impl(args, kwargs);
}; };
~PyCompiledFun() { ~PyCompiledFun() {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
tree_cache().erase(fun_id); tree_cache().erase(fun_id);
detail::compile_erase(fun_id); detail::compile_erase(fun_id);
@ -476,35 +487,35 @@ struct PyCompiledFun {
class PyCheckpointedFun { class PyCheckpointedFun {
public: public:
PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {} PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
~PyCheckpointedFun() { ~PyCheckpointedFun() {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
fun_.release().dec_ref(); fun_.release().dec_ref();
} }
struct InnerFunction { struct InnerFunction {
py::object fun_; nb::object fun_;
py::object args_structure_; nb::object args_structure_;
std::weak_ptr<py::object> output_structure_; std::weak_ptr<nb::object> output_structure_;
InnerFunction( InnerFunction(
py::object fun, nb::object fun,
py::object args_structure, nb::object args_structure,
std::weak_ptr<py::object> output_structure) std::weak_ptr<nb::object> output_structure)
: fun_(std::move(fun)), : fun_(std::move(fun)),
args_structure_(std::move(args_structure)), args_structure_(std::move(args_structure)),
output_structure_(output_structure) {} output_structure_(output_structure) {}
~InnerFunction() { ~InnerFunction() {
py::gil_scoped_acquire gil; nb::gil_scoped_acquire gil;
fun_.release().dec_ref(); fun_.release().dec_ref();
args_structure_.release().dec_ref(); args_structure_.release().dec_ref();
} }
std::vector<array> operator()(const std::vector<array>& inputs) { std::vector<array> operator()(const std::vector<array>& inputs) {
auto args = py::cast<py::tuple>( auto args = nb::cast<nb::tuple>(
tree_unflatten_from_structure(args_structure_, inputs)); tree_unflatten_from_structure(args_structure_, inputs));
auto [outputs, output_structure] = auto [outputs, output_structure] =
tree_flatten_with_structure(fun_(*args[0], **args[1]), false); tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
@ -515,9 +526,9 @@ class PyCheckpointedFun {
} }
}; };
py::object operator()(const py::args& args, const py::kwargs& kwargs) { nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
auto output_structure = std::make_shared<py::object>(); auto output_structure = std::make_shared<nb::object>();
auto full_args = py::make_tuple(args, kwargs); auto full_args = nb::make_tuple(args, kwargs);
auto [inputs, args_structure] = auto [inputs, args_structure] =
tree_flatten_with_structure(full_args, false); tree_flatten_with_structure(full_args, false);
@ -527,26 +538,27 @@ class PyCheckpointedFun {
return tree_unflatten_from_structure(*output_structure, outputs); return tree_unflatten_from_structure(*output_structure, outputs);
} }
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
return const_cast<PyCheckpointedFun*>(this)->call_impl(args, kwargs);
}
private: private:
py::function fun_; nb::callable fun_;
}; };
void init_transforms(py::module_& m) { void init_transforms(nb::module_& m) {
py::options options;
options.disable_function_signatures();
m.def( m.def(
"eval", "eval",
[](const py::args& args) { [](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false); std::vector<array> arrays = tree_flatten(args, false);
{ {
py::gil_scoped_release nogil; nb::gil_scoped_release nogil;
eval(arrays); eval(arrays);
} }
}, },
nb::arg(),
nb::sig("def eval(*args) -> None"),
R"pbdoc( R"pbdoc(
eval(*args) -> None
Evaluate an :class:`array` or tree of :class:`array`. Evaluate an :class:`array` or tree of :class:`array`.
Args: Args:
@ -557,19 +569,15 @@ void init_transforms(py::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"jvp", "jvp",
[](const py::function& fun, [](const nb::callable& fun,
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& tangents) { const std::vector<array>& tangents) {
auto vfun = [&fun](const std::vector<array>& primals) { auto vfun = [&fun](const std::vector<array>& primals) {
py::args args = py::tuple(primals.size()); auto out = fun(*nb::cast(primals));
for (int i = 0; i < primals.size(); ++i) { if (nb::isinstance<array>(out)) {
args[i] = primals[i]; return std::vector<array>{nb::cast<array>(out)};
}
auto out = fun(*args);
if (py::isinstance<array>(out)) {
return std::vector<array>{py::cast<array>(out)};
} else { } else {
return py::cast<std::vector<array>>(out); return nb::cast<std::vector<array>>(out);
} }
}; };
return jvp(vfun, primals, tangents); return jvp(vfun, primals, tangents);
@ -577,17 +585,16 @@ void init_transforms(py::module_& m) {
"fun"_a, "fun"_a,
"primals"_a, "primals"_a,
"tangents"_a, "tangents"_a,
nb::sig(
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
R"pbdoc( R"pbdoc(
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
Compute the Jacobian-vector product. Compute the Jacobian-vector product.
This computes the product of the Jacobian of a function ``fun`` evaluated This computes the product of the Jacobian of a function ``fun`` evaluated
at ``primals`` with the ``tangents``. at ``primals`` with the ``tangents``.
Args: Args:
fun (function): A function which takes a variable number of :class:`array` fun (callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`. and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian. evaluate the Jacobian.
@ -601,19 +608,15 @@ void init_transforms(py::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"vjp", "vjp",
[](const py::function& fun, [](const nb::callable& fun,
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents) { const std::vector<array>& cotangents) {
auto vfun = [&fun](const std::vector<array>& primals) { auto vfun = [&fun](const std::vector<array>& primals) {
py::args args = py::tuple(primals.size()); auto out = fun(*nb::cast(primals));
for (int i = 0; i < primals.size(); ++i) { if (nb::isinstance<array>(out)) {
args[i] = primals[i]; return std::vector<array>{nb::cast<array>(out)};
}
auto out = fun(*args);
if (py::isinstance<array>(out)) {
return std::vector<array>{py::cast<array>(out)};
} else { } else {
return py::cast<std::vector<array>>(out); return nb::cast<std::vector<array>>(out);
} }
}; };
return vjp(vfun, primals, cotangents); return vjp(vfun, primals, cotangents);
@ -621,16 +624,16 @@ void init_transforms(py::module_& m) {
"fun"_a, "fun"_a,
"primals"_a, "primals"_a,
"cotangents"_a, "cotangents"_a,
nb::sig(
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
R"pbdoc( R"pbdoc(
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
Compute the vector-Jacobian product. Compute the vector-Jacobian product.
Computes the product of the ``cotangents`` with the Jacobian of a Computes the product of the ``cotangents`` with the Jacobian of a
function ``fun`` evaluated at ``primals``. function ``fun`` evaluated at ``primals``.
Args: Args:
fun (function): A function which takes a variable number of :class:`array` fun (callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`. and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian. evaluate the Jacobian.
@ -644,20 +647,20 @@ void init_transforms(py::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"value_and_grad", "value_and_grad",
[](const py::function& fun, [](const nb::callable& fun,
const std::optional<IntOrVec>& argnums, const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) { const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] = auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames); validate_argnums_argnames(argnums, argnames);
return py::cpp_function(py_value_and_grad( return nb::cpp_function(py_value_and_grad(
fun, argnums_vec, argnames_vec, "[value_and_grad]", false)); fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
}, },
"fun"_a, "fun"_a,
"argnums"_a = std::nullopt, "argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{}, "argnames"_a = std::vector<std::string>{},
nb::sig(
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
R"pbdoc( R"pbdoc(
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
Returns a function which computes the value and gradient of ``fun``. Returns a function which computes the value and gradient of ``fun``.
The function passed to :func:`value_and_grad` should return either The function passed to :func:`value_and_grad` should return either
@ -688,7 +691,7 @@ void init_transforms(py::module_& m) {
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
Args: Args:
fun (function): A function which takes a variable number of fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns :class:`array` or trees of :class:`array` and returns
a scalar output :class:`array` or a tuple the first element a scalar output :class:`array` or a tuple the first element
of which should be a scalar :class:`array`. of which should be a scalar :class:`array`.
@ -702,34 +705,34 @@ void init_transforms(py::module_& m) {
no gradients for keyword arguments by default. no gradients for keyword arguments by default.
Returns: Returns:
function: A function which returns a tuple where the first element callable: A function which returns a tuple where the first element
is the output of `fun` and the second element is the gradients w.r.t. is the output of `fun` and the second element is the gradients w.r.t.
the loss. the loss.
)pbdoc"); )pbdoc");
m.def( m.def(
"grad", "grad",
[](const py::function& fun, [](const nb::callable& fun,
const std::optional<IntOrVec>& argnums, const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) { const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] = auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames); validate_argnums_argnames(argnums, argnames);
auto fn = auto fn =
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true); py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
return py::cpp_function( return nb::cpp_function(
[fn](const py::args& args, const py::kwargs& kwargs) { [fn](const nb::args& args, const nb::kwargs& kwargs) {
return fn(args, kwargs).second; return fn(args, kwargs).second;
}); });
}, },
"fun"_a, "fun"_a,
"argnums"_a = std::nullopt, "argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{}, "argnames"_a = std::vector<std::string>{},
nb::sig(
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
R"pbdoc( R"pbdoc(
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
Returns a function which computes the gradient of ``fun``. Returns a function which computes the gradient of ``fun``.
Args: Args:
fun (function): A function which takes a variable number of fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns :class:`array` or trees of :class:`array` and returns
a scalar output :class:`array`. a scalar output :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices) argnums (int or list(int), optional): Specify the index (or indices)
@ -742,26 +745,26 @@ void init_transforms(py::module_& m) {
no gradients for keyword arguments by default. no gradients for keyword arguments by default.
Returns: Returns:
function: A function which has the same input arguments as ``fun`` and callable: A function which has the same input arguments as ``fun`` and
returns the gradient(s). returns the gradient(s).
)pbdoc"); )pbdoc");
m.def( m.def(
"vmap", "vmap",
[](const py::function& fun, [](const nb::callable& fun,
const py::object& in_axes, const nb::object& in_axes,
const py::object& out_axes) { const nb::object& out_axes) {
return py::cpp_function(py_vmap(fun, in_axes, out_axes)); return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
}, },
"fun"_a, "fun"_a,
"in_axes"_a = 0, "in_axes"_a = 0,
"out_axes"_a = 0, "out_axes"_a = 0,
nb::sig(
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
R"pbdoc( R"pbdoc(
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
Returns a vectorized version of ``fun``. Returns a vectorized version of ``fun``.
Args: Args:
fun (function): A function which takes a variable number of fun (callable): A function which takes a variable number of
:class:`array` or a tree of :class:`array` and returns :class:`array` or a tree of :class:`array` and returns
a variable number of :class:`array` or a tree of :class:`array`. a variable number of :class:`array` or a tree of :class:`array`.
in_axes (int, optional): An integer or a valid prefix tree of the in_axes (int, optional): An integer or a valid prefix tree of the
@ -774,16 +777,16 @@ void init_transforms(py::module_& m) {
Defaults to ``0``. Defaults to ``0``.
Returns: Returns:
function: The vectorized function. callable: The vectorized function.
)pbdoc"); )pbdoc");
m.def( m.def(
"export_to_dot", "export_to_dot",
[](py::object file, const py::args& args) { [](nb::object file, const nb::args& args) {
std::vector<array> arrays = tree_flatten(args); std::vector<array> arrays = tree_flatten(args);
if (py::isinstance<py::str>(file)) { if (nb::isinstance<nb::str>(file)) {
std::ofstream out(py::cast<std::string>(file)); std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays); export_to_dot(out, arrays);
} else if (py::hasattr(file, "write")) { } else if (nb::hasattr(file, "write")) {
std::ostringstream out; std::ostringstream out;
export_to_dot(out, arrays); export_to_dot(out, arrays);
auto write = file.attr("write"); auto write = file.attr("write");
@ -793,57 +796,50 @@ void init_transforms(py::module_& m) {
"export_to_dot accepts file-like objects or strings to be used as filenames"); "export_to_dot accepts file-like objects or strings to be used as filenames");
} }
}, },
"file"_a); "file"_a,
"args"_a);
m.def( m.def(
"compile", "compile",
[](const py::function& fun, [](const nb::callable& fun,
const py::object& inputs, const nb::object& inputs,
const py::object& outputs, const nb::object& outputs,
bool shapeless) { bool shapeless) {
py::options options; // Try to get the name
options.disable_function_signatures(); auto n = fun.attr("__name__");
auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);
std::ostringstream doc;
auto name = fun.attr("__name__").cast<std::string>();
doc << name;
// Try to get the signature // Try to get the signature
auto inspect = py::module::import("inspect"); std::ostringstream sig;
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) { sig << "def " << name;
doc << inspect.attr("signature")(fun) auto inspect = nb::module_::import_("inspect");
.attr("__str__")() if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
.cast<std::string>(); sig << nb::cast<std::string>(
inspect.attr("signature")(fun).attr("__str__")());
} else {
sig << "(*args, **kwargs)";
} }
// Try to get the doc string // Try to get the doc string
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) { auto d = inspect.attr("getdoc")(fun);
doc << "\n\n"; std::string doc =
auto dstr = d.cast<std::string>(); d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
// Add spaces to match first line indentation with remainder of
// docstring auto sig_str = sig.str();
int i = 0; return nb::cpp_function(
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
doc << ' ';
}
doc << dstr;
}
auto doc_str = doc.str();
return py::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless}, PyCompiledFun{fun, inputs, outputs, shapeless},
py::name(name.c_str()), nb::name(name.c_str()),
py::doc(doc_str.c_str())); nb::sig(sig_str.c_str()),
doc.c_str());
}, },
"fun"_a, "fun"_a,
"inputs"_a = std::nullopt, "inputs"_a = nb::none(),
"outputs"_a = std::nullopt, "outputs"_a = nb::none(),
"shapeless"_a = false, "shapeless"_a = false,
R"pbdoc( R"pbdoc(
compile(fun: function) -> function
Returns a compiled function which produces the same output as ``fun``. Returns a compiled function which produces the same output as ``fun``.
Args: Args:
fun (function): A function which takes a variable number of fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns :class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`. a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during inputs (list or dict, optional): These inputs will be captured during
@ -864,15 +860,13 @@ void init_transforms(py::module_& m) {
``shapeless`` set to ``True``. Default: ``False`` ``shapeless`` set to ``True``. Default: ``False``
Returns: Returns:
function: A compiled function which has the same input arguments callable: A compiled function which has the same input arguments
as ``fun`` and returns the the same output(s). as ``fun`` and returns the the same output(s).
)pbdoc"); )pbdoc");
m.def( m.def(
"disable_compile", "disable_compile",
&disable_compile, &disable_compile,
R"pbdoc( R"pbdoc(
disable_compile() -> None
Globally disable compilation. Setting the environment variable Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation. ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc"); )pbdoc");
@ -880,17 +874,15 @@ void init_transforms(py::module_& m) {
"enable_compile", "enable_compile",
&enable_compile, &enable_compile,
R"pbdoc( R"pbdoc(
enable_compile() -> None
Globally enable compilation. This will override the environment Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set. variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc"); )pbdoc");
m.def( m.def(
"checkpoint", "checkpoint",
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); }, [](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
"fun"_a); "fun"_a);
// Register static Python object cleanup before the interpreter exits // Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit"); auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); })); atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
} }

View File

@ -2,16 +2,16 @@
#include "python/src/trees.h" #include "python/src/trees.h"
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) { void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
std::function<void(py::handle)> recurse; std::function<void(nb::handle)> recurse;
recurse = [&](py::handle subtree) { recurse = [&](nb::handle subtree) {
if (py::isinstance<py::list>(subtree) || if (nb::isinstance<nb::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) { nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) { for (auto item : subtree) {
recurse(item); recurse(item);
} }
} else if (py::isinstance<py::dict>(subtree)) { } else if (nb::isinstance<nb::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) { for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second); recurse(item.second);
} }
} else { } else {
@ -23,63 +23,63 @@ void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
} }
template <typename T, typename U, typename V> template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) { void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size(); int len = nb::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) { for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) || if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) { nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
throw std::invalid_argument( throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree."); "[tree_map] Additional input tree is not a valid prefix of the first tree.");
} }
} }
} }
py::object tree_map( nb::object tree_map(
const std::vector<py::object>& trees, const std::vector<nb::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) { std::function<nb::object(const std::vector<nb::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse; std::function<nb::object(const std::vector<nb::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) { recurse = [&](const std::vector<nb::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) { if (nb::isinstance<nb::list>(subtrees[0])) {
py::list l; nb::list l;
std::vector<py::object> items(subtrees.size()); std::vector<nb::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees); validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) { for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) { if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i]; items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else { } else {
items[j] = subtrees[j]; items[j] = subtrees[j];
} }
} }
l.append(recurse(items)); l.append(recurse(items));
} }
return py::cast<py::object>(l); return nb::cast<nb::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) { } else if (nb::isinstance<nb::tuple>(subtrees[0])) {
// Check the rest of the subtrees // Check the rest of the subtrees
std::vector<py::object> items(subtrees.size()); std::vector<nb::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size(); int len = nb::cast<nb::tuple>(subtrees[0]).size();
py::tuple l(len); nb::list l;
validate_subtrees<py::tuple, py::list, py::dict>(subtrees); validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) { if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i]; items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else { } else {
items[j] = subtrees[j]; items[j] = subtrees[j];
} }
} }
l[i] = recurse(items); l.append(recurse(items));
} }
return py::cast<py::object>(l); return nb::cast<nb::object>(nb::tuple(l));
} else if (py::isinstance<py::dict>(subtrees[0])) { } else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size()); std::vector<nb::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees); validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
py::dict d; nb::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) { for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) { if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]); auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) { if (!subdict.contains(item.first)) {
throw std::invalid_argument( throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree."); "[tree_map] Tree is not a valid prefix tree of the first tree.");
@ -91,7 +91,7 @@ py::object tree_map(
} }
d[item.first] = recurse(items); d[item.first] = recurse(items);
} }
return py::cast<py::object>(d); return nb::cast<nb::object>(d);
} else { } else {
return transform(subtrees); return transform(subtrees);
} }
@ -99,40 +99,40 @@ py::object tree_map(
return recurse(trees); return recurse(trees);
} }
py::object tree_map( nb::object tree_map(
py::object tree, nb::object tree,
std::function<py::object(py::handle)> transform) { std::function<nb::object(nb::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) { return tree_map({tree}, [&](std::vector<nb::object> inputs) {
return transform(inputs[0]); return transform(inputs[0]);
}); });
} }
void tree_visit_update( void tree_visit_update(
py::object tree, nb::object tree,
std::function<py::object(py::handle)> visitor) { std::function<nb::object(nb::handle)> visitor) {
std::function<py::object(py::handle)> recurse; std::function<nb::object(nb::handle)> recurse;
recurse = [&](py::handle subtree) { recurse = [&](nb::handle subtree) {
if (py::isinstance<py::list>(subtree)) { if (nb::isinstance<nb::list>(subtree)) {
auto l = py::cast<py::list>(subtree); auto l = nb::cast<nb::list>(subtree);
for (int i = 0; i < l.size(); ++i) { for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]); l[i] = recurse(l[i]);
} }
return py::cast<py::object>(l); return nb::cast<nb::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) { } else if (nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) { for (auto item : subtree) {
recurse(item); recurse(item);
} }
return py::cast<py::object>(subtree); return nb::cast<nb::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) { } else if (nb::isinstance<nb::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree); auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) { for (auto item : d) {
d[item.first] = recurse(item.second); d[item.first] = recurse(item.second);
} }
return py::cast<py::object>(d); return nb::cast<nb::object>(d);
} else if (py::isinstance<array>(subtree)) { } else if (nb::isinstance<array>(subtree)) {
return visitor(subtree); return visitor(subtree);
} else { } else {
return py::cast<py::object>(subtree); return nb::cast<nb::object>(subtree);
} }
}; };
recurse(tree); recurse(tree);
@ -141,36 +141,36 @@ void tree_visit_update(
// Fill a pytree (recursive dict or list of dict or list) // Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays // in place with the given arrays
// Non dict or list nodes are ignored // Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) { void tree_fill(nb::object& tree, const std::vector<array>& values) {
size_t index = 0; size_t index = 0;
tree_visit_update( tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); }); tree, [&](nb::handle node) { return nb::cast(values[index++]); });
} }
// Replace all the arrays from the src values with the dst values in the tree // Replace all the arrays from the src values with the dst values in the tree
void tree_replace( void tree_replace(
py::object& tree, nb::object& tree,
const std::vector<array>& src, const std::vector<array>& src,
const std::vector<array>& dst) { const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst; std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) { for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]}); src_to_dst.insert({src[i].id(), dst[i]});
} }
tree_visit_update(tree, [&](py::handle node) { tree_visit_update(tree, [&](nb::handle node) {
auto arr = py::cast<array>(node); auto arr = nb::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second); return nb::cast(it->second);
} }
return py::cast(arr); return nb::cast(arr);
}); });
} }
std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) { std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<array> flat_tree; std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) { tree_visit(tree, [&](nb::handle obj) {
if (py::isinstance<array>(obj)) { if (nb::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj)); flat_tree.push_back(nb::cast<array>(obj));
} else if (strict) { } else if (strict) {
throw std::invalid_argument( throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays"); "[tree_flatten] The argument should contain only arrays");
@ -180,24 +180,24 @@ std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
return flat_tree; return flat_tree;
} }
py::object tree_unflatten( nb::object tree_unflatten(
py::object tree, nb::object tree,
const std::vector<array>& values, const std::vector<array>& values,
int index /* = 0 */) { int index /* = 0 */) {
return tree_map(tree, [&](py::handle obj) { return tree_map(tree, [&](nb::handle obj) {
if (py::isinstance<array>(obj)) { if (nb::isinstance<array>(obj)) {
return py::cast(values[index++]); return nb::cast(values[index++]);
} else { } else {
return py::cast<py::object>(obj); return nb::cast<nb::object>(obj);
} }
}); });
} }
py::object structure_sentinel() { nb::object structure_sentinel() {
static py::object sentinel; static nb::object sentinel;
if (sentinel.ptr() == nullptr) { if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel); sentinel = nb::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever // probably not needed but this should make certain that we won't ever
// delete the sentinel // delete the sentinel
sentinel.inc_ref(); sentinel.inc_ref();
@ -206,19 +206,19 @@ py::object structure_sentinel() {
return sentinel; return sentinel;
} }
std::pair<std::vector<array>, py::object> tree_flatten_with_structure( std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
py::object tree, nb::object tree,
bool strict /* = true */) { bool strict /* = true */) {
auto sentinel = structure_sentinel(); auto sentinel = structure_sentinel();
std::vector<array> flat_tree; std::vector<array> flat_tree;
auto structure = tree_map( auto structure = tree_map(
tree, tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) { [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
if (py::isinstance<array>(obj)) { if (nb::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj)); flat_tree.push_back(nb::cast<array>(obj));
return sentinel; return sentinel;
} else if (!strict) { } else if (!strict) {
return py::cast<py::object>(obj); return nb::cast<nb::object>(obj);
} else { } else {
throw std::invalid_argument( throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays"); "[tree_flatten] The argument should contain only arrays");
@ -228,16 +228,16 @@ std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
return {flat_tree, structure}; return {flat_tree, structure};
} }
py::object tree_unflatten_from_structure( nb::object tree_unflatten_from_structure(
py::object structure, nb::object structure,
const std::vector<array>& values, const std::vector<array>& values,
int index /* = 0 */) { int index /* = 0 */) {
auto sentinel = structure_sentinel(); auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) { return tree_map(structure, [&](nb::handle obj) {
if (obj.is(sentinel)) { if (obj.is(sentinel)) {
return py::cast(values[index++]); return nb::cast(values[index++]);
} else { } else {
return py::cast<py::object>(obj); return nb::cast<nb::object>(obj);
} }
}); });
} }

View File

@ -1,38 +1,37 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
#include <pybind11/pybind11.h> #include <nanobind/nanobind.h>
#include <pybind11/stl.h>
#include "mlx/array.h" #include "mlx/array.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace mlx::core; using namespace mlx::core;
void tree_visit(py::object tree, std::function<void(py::handle)> visitor); void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
py::object tree_map( nb::object tree_map(
const std::vector<py::object>& trees, const std::vector<nb::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform); std::function<nb::object(const std::vector<nb::object>&)> transform);
py::object tree_map( nb::object tree_map(
py::object tree, nb::object tree,
std::function<py::object(py::handle)> transform); std::function<nb::object(nb::handle)> transform);
void tree_visit_update( void tree_visit_update(
py::object tree, nb::object tree,
std::function<py::object(py::handle)> visitor); std::function<nb::object(nb::handle)> visitor);
/** /**
* Fill a pytree (recursive dict or list of dict or list) in place with the * Fill a pytree (recursive dict or list of dict or list) in place with the
* given arrays. */ * given arrays. */
void tree_fill(py::object& tree, const std::vector<array>& values); void tree_fill(nb::object& tree, const std::vector<array>& values);
/** /**
* Replace all the arrays from the src values with the dst values in the * Replace all the arrays from the src values with the dst values in the
* tree. * tree.
*/ */
void tree_replace( void tree_replace(
py::object& tree, nb::object& tree,
const std::vector<array>& src, const std::vector<array>& src,
const std::vector<array>& dst); const std::vector<array>& dst);
@ -40,21 +39,21 @@ void tree_replace(
* Flatten a tree into a vector of arrays. If strict is true, then the * Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array. * function will throw if the tree contains a leaf which is not an array.
*/ */
std::vector<array> tree_flatten(py::object tree, bool strict = true); std::vector<array> tree_flatten(nb::object tree, bool strict = true);
/** /**
* Unflatten a tree from a vector of arrays. * Unflatten a tree from a vector of arrays.
*/ */
py::object tree_unflatten( nb::object tree_unflatten(
py::object tree, nb::object tree,
const std::vector<array>& values, const std::vector<array>& values,
int index = 0); int index = 0);
std::pair<std::vector<array>, py::object> tree_flatten_with_structure( std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
py::object tree, nb::object tree,
bool strict = true); bool strict = true);
py::object tree_unflatten_from_structure( nb::object tree_unflatten_from_structure(
py::object structure, nb::object structure,
const std::vector<array>& values, const std::vector<array>& values,
int index = 0); int index = 0);

View File

@ -1,81 +0,0 @@
#include "mlx/utils.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <optional>
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
// Slightly different from the original, with python context on init we are not
// in the context yet. Only create the inner context on enter then delete on
// exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
StreamOrDevice _s;
StreamContext* _inner;
};
void init_utils(py::module_& m) {
py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.
See :func:`stream` for usage.
Args:
s: The stream or device to set as the default.
)pbdoc")
.def(py::init<StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def(
"__exit__",
[](PyStreamContext& scm,
const std::optional<py::type>& exc_type,
const std::optional<py::object>& exc_value,
const std::optional<py::object>& traceback) { scm.exit(); });
m.def(
"stream",
[](StreamOrDevice s) { return PyStreamContext(s); },
"s"_a,
R"pbdoc(
Create a context manager to set the default device and stream.
Args:
s: The :obj:`Stream` or :obj:`Device` to set as the default.
Returns:
A context manager that sets the default device and stream.
Example:
.. code-block::python
import mlx.core as mx
# Create a context manager for the default device and stream.
with mx.stream(mx.cpu):
# Operations here will use mx.cpu by default.
pass
)pbdoc");
}

View File

@ -1,23 +1,22 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
#include <numeric> #include <numeric>
#include <optional>
#include <variant> #include <variant>
#include <pybind11/complex.h> #include <nanobind/nanobind.h>
#include <pybind11/pybind11.h> #include <nanobind/stl/complex.h>
#include <pybind11/stl.h> #include <nanobind/stl/variant.h>
#include "mlx/array.h" #include "mlx/array.h"
namespace py = pybind11; namespace nb = nanobind;
using namespace mlx::core; using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>; using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std:: using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>; variant<nb::bool_, nb::int_, nb::float_, std::complex<float>, nb::object>;
static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) { inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes; std::vector<int> axes;
@ -32,31 +31,36 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
return axes; return axes;
} }
inline array to_array_with_accessor(py::object obj) { inline array to_array_with_accessor(nb::object obj) {
if (py::hasattr(obj, "__mlx_array__")) { if (nb::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>(); return nb::cast<array>(obj.attr("__mlx_array__")());
} else if (nb::isinstance<array>(obj)) {
return nb::cast<array>(obj);
} else { } else {
return obj.cast<array>(); std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
} }
} }
inline array to_array( inline array to_array(
const ScalarOrArray& v, const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) { std::optional<Dtype> dtype = std::nullopt) {
if (auto pv = std::get_if<py::bool_>(&v); pv) { if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), dtype.value_or(bool_)); return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) { } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(int32); auto out_t = dtype.value_or(int32);
// bool_ is an exception and is always promoted // bool_ is an exception and is always promoted
return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t); return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<py::float_>(&v); pv) { } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32); auto out_t = dtype.value_or(float32);
return array( return array(
py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32); nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64); return array(static_cast<complex64_t>(*pv), complex64);
} else { } else {
return to_array_with_accessor(std::get<py::object>(v)); return to_array_with_accessor(std::get<nb::object>(v));
} }
} }
@ -68,14 +72,14 @@ inline std::pair<array, array> to_arrays(
// - If a is an array but b is not, treat b as a weak python type // - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type // - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone // - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<py::object>(&a); pa) { if (auto pa = std::get_if<nb::object>(&a); pa) {
auto arr_a = to_array_with_accessor(*pa); auto arr_a = to_array_with_accessor(*pa);
if (auto pb = std::get_if<py::object>(&b); pb) { if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb); auto arr_b = to_array_with_accessor(*pb);
return {arr_a, arr_b}; return {arr_a, arr_b};
} }
return {arr_a, to_array(b, arr_a.dtype())}; return {arr_a, to_array(b, arr_a.dtype())};
} else if (auto pb = std::get_if<py::object>(&b); pb) { } else if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb); auto arr_b = to_array_with_accessor(*pb);
return {to_array(a, arr_b.dtype()), arr_b}; return {to_array(a, arr_b.dtype()), arr_b};
} else { } else {

View File

@ -308,9 +308,9 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.dtype, mx.bool_) self.assertEqual(y.dtype, mx.bool_)
self.assertEqual(y.item(), True) self.assertEqual(y.item(), True)
# y = mx.array(x, mx.complex64) y = mx.array(x, mx.complex64)
# self.assertEqual(y.dtype, mx.complex64) self.assertEqual(y.dtype, mx.complex64)
# self.assertEqual(y.item(), 3.0+0j) self.assertEqual(y.item(), 3.0 + 0j)
def test_array_repr(self): def test_array_repr(self):
x = mx.array(True) x = mx.array(True)
@ -682,7 +682,7 @@ class TestArray(mlx_tests.MLXTestCase):
# check if it throws an error when dtype is not supported (bfloat16) # check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16) x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(RuntimeError): with self.assertRaises(TypeError):
pickle.dumps(x) pickle.dumps(x)
def test_array_copy(self): def test_array_copy(self):
@ -711,6 +711,11 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqualArray(y, x - 1) self.assertEqualArray(y, x - 1)
def test_indexing(self): def test_indexing(self):
# Only ellipsis is a no-op
a_mlx = mx.array([1])[...]
self.assertEqual(a_mlx.shape, (1,))
self.assertEqual(a_mlx.item(), 1)
# Basic content check, slice indexing # Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32) a_npy = np.arange(64, dtype=np.float32)
a_mlx = mx.array(a_npy) a_mlx = mx.array(a_npy)
@ -1360,7 +1365,7 @@ class TestArray(mlx_tests.MLXTestCase):
for mlx_dtype, tf_dtype, np_dtype in dtypes_list: for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype) a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
a_tf = tf.constant(a_np, dtype=tf_dtype) a_tf = tf.constant(a_np, dtype=tf_dtype)
a_mx = mx.array(a_tf) a_mx = mx.array(np.array(a_tf))
for f in [ for f in [
lambda x: x, lambda x: x,
lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T, lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,

View File

@ -134,7 +134,9 @@ class GenerateStubs(Command):
pass pass
def run(self) -> None: def run(self) -> None:
subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"]) subprocess.run(
["python", "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"]
)
# Read the content of README.md # Read the content of README.md
@ -165,7 +167,7 @@ if __name__ == "__main__":
include_package_data=True, include_package_data=True,
extras_require={ extras_require={
"testing": ["numpy", "torch"], "testing": ["numpy", "torch"],
"dev": ["pre-commit", "pybind11-stubgen"], "dev": ["pre-commit"],
}, },
ext_modules=[CMakeExtension("mlx.core")], ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},