Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
Awni Hannun 2024-03-18 20:12:25 -07:00 committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions


@ -31,8 +31,7 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@ -44,7 +43,8 @@ jobs:
- run:
name: Generate package stubs
command: |
python3 setup.py generate_stubs
echo "stubs"
python -m nanobind.stubgen -m mlx.core -r -O python
- run:
name: Run Python tests
command: |
@ -80,8 +80,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install numpy
pip install torch
pip install tensorflow
@ -95,7 +94,7 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
python -m nanobind.stubgen -m mlx.core -r -O python
- run:
name: Run Python tests
command: |
@ -144,9 +143,8 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
pip install build
@ -161,7 +159,7 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
python -m nanobind.stubgen -m mlx.core -r -O python
- run:
name: Build Python package
command: |
@ -209,9 +207,8 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install auditwheel
pip install patchelf
@ -219,7 +216,7 @@ jobs:
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
python setup.py generate_stubs
python -m nanobind.stubgen -m mlx.core -r -O python
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
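If the same stubs need to be regenerated outside CI, a minimal sketch (assuming mlx.core and nanobind are installed in the active environment) is to invoke the same command that replaces python setup.py generate_stubs above:

# Hypothetical local equivalent of the CI stub-generation step;
# it shells out to nanobind's stub generator with the same flags.
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"],
    check=True,
)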


@ -146,8 +146,12 @@ target_include_directories(
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
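The execute_process call above asks the installed nanobind package for the location of its CMake config, so that find_package(nanobind CONFIG REQUIRED) can succeed. A hedged Python sketch of that same query:

# Mirrors the CMake execute_process step: print nanobind's CMake package dir,
# which the build appends to CMAKE_PREFIX_PATH.
import subprocess
import sys

nb_cmake_dir = subprocess.check_output(
    [sys.executable, "-m", "nanobind", "--cmake_dir"], text=True
).strip()
print(nb_cmake_dir)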


@ -29,8 +29,8 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
}
templates_path = ["_templates"]
@ -59,3 +59,14 @@ html_theme_options = {
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
wrapped = app.registry.documenters["function"].can_document_member
def nanobind_function_patch(member: Any, *args, **kwargs) -> bool:
return "nanobind.nb_func" in str(type(member)) or wrapped(
member, *args, **kwargs
)
app.registry.documenters["function"].can_document_member = nanobind_function_patch
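The patch is needed because callables exported by nanobind are instances of nanobind.nb_func, which the stock autodoc function documenter would otherwise reject. A quick interactive check, assuming the new bindings are importable:

import mlx.core as mx

# Functions bound with nanobind report a type whose name contains
# "nanobind.nb_func", which is exactly what the conf.py patch tests for.
print(type(mx.add))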


@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
.. code-block:: shell
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
pip install git+https://github.com/wjakob/nanobind.git
Then simply build and install it using pip:
Then simply build and install MLX using pip:
.. code-block:: shell


@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cstring>
#include <fstream>
@ -122,8 +121,6 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
out_stream->write(header.str().c_str(), header.str().length());
out_stream->write(a.data<char>(), a.nbytes());
return;
}
/** Save array to file in .npy format */


@ -7,6 +7,25 @@
namespace mlx::core {
struct complex64_t;
struct complex128_t;
template <typename T>
static constexpr bool can_convert_to_complex128 =
!std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;
struct complex128_t : public std::complex<double> {
complex128_t(double v, double u) : std::complex<double>(v, u){};
complex128_t(std::complex<double> v) : std::complex<double>(v){};
template <
typename T,
typename = typename std::enable_if<can_convert_to_complex128<T>>::type>
complex128_t(T x) : std::complex<double>(x){};
operator float() const {
return real();
};
};
template <typename T>
static constexpr bool can_convert_to_complex64 =


@ -1,3 +1,7 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"]
requires = [
"setuptools>=42",
"nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
"cmake>=3.24",
]
build-backend = "setuptools.build_meta"


@ -1,7 +1,10 @@
pybind11_add_module(
nanobind_add_module(
core
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@ -15,7 +18,6 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)

File diff suppressed because it is too large.

python/src/buffer.h (new file, 122 lines)

@ -0,0 +1,122 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <optional>
#include <nanobind/nanobind.h>
#include "mlx/array.h"
#include "mlx/utils.h"
// Only defined in >= Python 3.9
// https://github.com/python/cpython/blob/f6cdc6b4a191b75027de342aa8b5d344fb31313e/Include/typeslots.h#L2-L3
#ifndef Py_bf_getbuffer
#define Py_bf_getbuffer 1
#define Py_bf_releasebuffer 2
#endif
namespace nb = nanobind;
using namespace mlx::core;
std::string buffer_format(const array& a) {
// https://docs.python.org/3.10/library/struct.html#format-characters
switch (a.dtype()) {
case bool_:
return "?";
case uint8:
return "B";
case uint16:
return "H";
case uint32:
return "I";
case uint64:
return "Q";
case int8:
return "b";
case int16:
return "h";
case int32:
return "i";
case int64:
return "q";
case float16:
return "e";
case float32:
return "f";
case bfloat16:
return "B";
case complex64:
return "Zf\0";
default: {
std::ostringstream os;
os << "bad dtype: " << a.dtype();
throw std::runtime_error(os.str());
}
}
}
struct buffer_info {
std::string format;
std::vector<ssize_t> shape;
std::vector<ssize_t> strides;
buffer_info(
const std::string& format,
std::vector<ssize_t> shape_in,
std::vector<ssize_t> strides_in)
: format(format),
shape(std::move(shape_in)),
strides(std::move(strides_in)) {}
buffer_info(const buffer_info&) = delete;
buffer_info& operator=(const buffer_info&) = delete;
buffer_info(buffer_info&& other) noexcept {
(*this) = std::move(other);
}
buffer_info& operator=(buffer_info&& rhs) noexcept {
format = std::move(rhs.format);
shape = std::move(rhs.shape);
strides = std::move(rhs.strides);
return *this;
}
};
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
std::memset(view, 0, sizeof(Py_buffer));
auto a = nb::cast<array>(nb::handle(obj));
if (!a.is_evaled()) {
nb::gil_scoped_release nogil;
a.eval();
}
std::vector<ssize_t> shape(a.shape().begin(), a.shape().end());
std::vector<ssize_t> strides(a.strides().begin(), a.strides().end());
for (auto& s : strides) {
s *= a.itemsize();
}
buffer_info* info =
new buffer_info(buffer_format(a), std::move(shape), std::move(strides));
view->obj = obj;
view->ndim = a.ndim();
view->internal = info;
view->buf = a.data<void>();
view->itemsize = a.itemsize();
view->len = a.size();
view->readonly = false;
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
view->format = const_cast<char*>(info->format.c_str());
}
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
view->strides = info->strides.data();
view->shape = info->shape.data();
}
Py_INCREF(view->obj);
return 0;
}
extern "C" inline void releasebuffer(PyObject*, Py_buffer* view) {
delete (buffer_info*)view->internal;
}
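A minimal usage sketch of the buffer protocol these slots restore (assuming getbuffer/releasebuffer are registered on mlx.core.array, as done in the array bindings of this PR):

import mlx.core as mx

a = mx.ones((2, 3))              # float32 by default
view = memoryview(a)             # invokes getbuffer() above
print(view.format)               # "f", per buffer_format()
print(view.shape, view.strides)  # strides reported in bytes
del view                         # invokes releasebuffer(), freeing the buffer_info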


@ -1,11 +1,11 @@
// init_constants.cpp
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include <nanobind/nanobind.h>
#include <limits>
namespace py = pybind11;
namespace nb = nanobind;
void init_constants(py::module_& m) {
void init_constants(nb::module_& m) {
m.attr("Inf") = std::numeric_limits<double>::infinity();
m.attr("Infinity") = std::numeric_limits<double>::infinity();
m.attr("NAN") = NAN;
@ -19,6 +19,6 @@ void init_constants(py::module_& m) {
m.attr("inf") = std::numeric_limits<double>::infinity();
m.attr("infty") = std::numeric_limits<double>::infinity();
m.attr("nan") = NAN;
m.attr("newaxis") = pybind11::none();
m.attr("newaxis") = nb::none();
m.attr("pi") = 3.1415926535897932384626433;
}

python/src/convert.cpp (new file, 155 lines)

@ -0,0 +1,155 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/stl/complex.h>
#include "python/src/convert.h"
namespace nanobind {
template <>
struct ndarray_traits<float16_t> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
template <>
struct ndarray_traits<bfloat16_t> {
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
template <typename T>
array nd_array_to_mlx_contiguous(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
const std::vector<int>& shape,
Dtype dtype) {
// Make a copy of the numpy buffer
// Get buffer ptr pass to array constructor
auto data_ptr = nd_array.data();
return array(static_cast<const T*>(data_ptr), shape, dtype);
}
array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype) {
// Compute the shape and size
std::vector<int> shape;
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(nd_array.shape(i));
}
auto type = nd_array.dtype();
// Copy data and make array
if (type == nb::dtype<bool>()) {
return nd_array_to_mlx_contiguous<bool>(
nd_array, shape, dtype.value_or(bool_));
} else if (type == nb::dtype<uint8_t>()) {
return nd_array_to_mlx_contiguous<uint8_t>(
nd_array, shape, dtype.value_or(uint8));
} else if (type == nb::dtype<uint16_t>()) {
return nd_array_to_mlx_contiguous<uint16_t>(
nd_array, shape, dtype.value_or(uint16));
} else if (type == nb::dtype<uint32_t>()) {
return nd_array_to_mlx_contiguous<uint32_t>(
nd_array, shape, dtype.value_or(uint32));
} else if (type == nb::dtype<uint64_t>()) {
return nd_array_to_mlx_contiguous<uint64_t>(
nd_array, shape, dtype.value_or(uint64));
} else if (type == nb::dtype<int8_t>()) {
return nd_array_to_mlx_contiguous<int8_t>(
nd_array, shape, dtype.value_or(int8));
} else if (type == nb::dtype<int16_t>()) {
return nd_array_to_mlx_contiguous<int16_t>(
nd_array, shape, dtype.value_or(int16));
} else if (type == nb::dtype<int32_t>()) {
return nd_array_to_mlx_contiguous<int32_t>(
nd_array, shape, dtype.value_or(int32));
} else if (type == nb::dtype<int64_t>()) {
return nd_array_to_mlx_contiguous<int64_t>(
nd_array, shape, dtype.value_or(int64));
} else if (type == nb::dtype<float16_t>()) {
return nd_array_to_mlx_contiguous<float16_t>(
nd_array, shape, dtype.value_or(float16));
} else if (type == nb::dtype<bfloat16_t>()) {
return nd_array_to_mlx_contiguous<bfloat16_t>(
nd_array, shape, dtype.value_or(bfloat16));
} else if (type == nb::dtype<float>()) {
return nd_array_to_mlx_contiguous<float>(
nd_array, shape, dtype.value_or(float32));
} else if (type == nb::dtype<double>()) {
return nd_array_to_mlx_contiguous<double>(
nd_array, shape, dtype.value_or(float32));
} else if (type == nb::dtype<std::complex<float>>()) {
return nd_array_to_mlx_contiguous<complex64_t>(
nd_array, shape, dtype.value_or(complex64));
} else if (type == nb::dtype<std::complex<double>>()) {
return nd_array_to_mlx_contiguous<complex128_t>(
nd_array, shape, dtype.value_or(complex64));
} else {
throw std::invalid_argument("Cannot convert numpy array to mlx array.");
}
}
template <typename Lib, typename T>
nb::ndarray<Lib> mlx_to_nd_array(
array a,
std::optional<nb::dlpack::dtype> t = {}) {
// Eval if not already evaled
if (!a.is_evaled()) {
nb::gil_scoped_release nogil;
a.eval();
}
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
return nb::ndarray<Lib>(
a.data<T>(),
a.ndim(),
shape.data(),
nb::handle(),
strides.data(),
t.value_or(nb::dtype<T>()));
}
template <typename Lib>
nb::ndarray<Lib> mlx_to_nd_array(const array& a) {
switch (a.dtype()) {
case bool_:
return mlx_to_nd_array<Lib, bool>(a);
case uint8:
return mlx_to_nd_array<Lib, uint8_t>(a);
case uint16:
return mlx_to_nd_array<Lib, uint16_t>(a);
case uint32:
return mlx_to_nd_array<Lib, uint32_t>(a);
case uint64:
return mlx_to_nd_array<Lib, uint64_t>(a);
case int8:
return mlx_to_nd_array<Lib, int8_t>(a);
case int16:
return mlx_to_nd_array<Lib, int16_t>(a);
case int32:
return mlx_to_nd_array<Lib, int32_t>(a);
case int64:
return mlx_to_nd_array<Lib, int64_t>(a);
case float16:
return mlx_to_nd_array<Lib, float16_t>(a);
case bfloat16:
return mlx_to_nd_array<Lib, bfloat16_t>(a, nb::bfloat16);
case float32:
return mlx_to_nd_array<Lib, float>(a);
case complex64:
return mlx_to_nd_array<Lib, std::complex<float>>(a);
}
}
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
return mlx_to_nd_array<nb::numpy>(a);
}
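A hedged round-trip sketch of the conversion paths above: constructing an mlx array from a NumPy array routes through nd_array_to_mlx, and converting back routes through mlx_to_np_array (note that float64 input is narrowed to float32 by default):

import numpy as np

import mlx.core as mx

x_np = np.arange(6, dtype=np.float32).reshape(2, 3)
x_mx = mx.array(x_np)     # NumPy -> MLX via nd_array_to_mlx()
y_np = np.array(x_mx)     # MLX -> NumPy via mlx_to_np_array()
assert y_np.dtype == np.float32 and y_np.shape == (2, 3)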

python/src/convert.h (new file, 16 lines)

@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
#include <optional>
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include "mlx/array.h"
namespace nb = nanobind;
using namespace mlx::core;
array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype);
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);


@ -1,32 +1,34 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include <pybind11/pybind11.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include "mlx/device.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_device(py::module_& m) {
auto device_class = py::class_<Device>(
void init_device(nb::module_& m) {
auto device_class = nb::class_<Device>(
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
py::enum_<Device::DeviceType>(m, "DeviceType")
nb::enum_<Device::DeviceType>(m, "DeviceType")
.value("cpu", Device::DeviceType::cpu)
.value("gpu", Device::DeviceType::gpu)
.export_values()
.def(
"__eq__",
[](const Device::DeviceType& d1, const Device& d2) {
return d1 == d2;
},
py::prepend());
.def("__eq__", [](const Device::DeviceType& d, const nb::object& other) {
if (!nb::isinstance<Device>(other) &&
!nb::isinstance<Device::DeviceType>(other)) {
return false;
}
return d == nb::cast<Device>(other);
});
device_class.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_readonly("type", &Device::type)
device_class.def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_ro("type", &Device::type)
.def(
"__repr__",
[](const Device& d) {
@ -34,11 +36,15 @@ void init_device(py::module_& m) {
os << d;
return os.str();
})
.def("__eq__", [](const Device& d1, const Device& d2) {
return d1 == d2;
.def("__eq__", [](const Device& d, const nb::object& other) {
if (!nb::isinstance<Device>(other) &&
!nb::isinstance<Device::DeviceType>(other)) {
return false;
}
return d == nb::cast<Device>(other);
});
py::implicitly_convertible<Device::DeviceType, Device>();
nb::implicitly_convertible<Device::DeviceType, Device>();
m.def(
"default_device",


@ -1,20 +1,17 @@
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include "mlx/fast.h"
#include "mlx/ops.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_extensions(py::module_& parent_module) {
py::options options;
options.disable_function_signatures();
void init_fast(nb::module_& parent_module) {
auto m =
parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@ -31,15 +28,15 @@ void init_extensions(py::module_& parent_module) {
},
"a"_a,
"dims"_a,
py::kw_only(),
nb::kw_only(),
"traditional"_a,
"base"_a,
"scale"_a,
"offset"_a,
"stream"_a = none,
"stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array
Apply rotary positional encoding to the input.
Args:
@ -70,20 +67,25 @@ void init_extensions(py::module_& parent_module) {
"q"_a,
"k"_a,
"v"_a,
py::kw_only(),
nb::kw_only(),
"scale"_a,
"mask"_a = none,
"stream"_a = none,
"mask"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V.
Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
Supports:
* [Multi-Head Attention](https://arxiv.org/abs/1706.03762)
* [Grouped Query Attention](https://arxiv.org/abs/2305.13245)
* [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations.
Note: The softmax operation is performed in ``float32`` regardless of
input precision.
Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array.
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
and ``v`` inputs should not be pre-tiled to match ``q``.
Args:
q (array): Input query array.
@ -94,6 +96,5 @@ void init_extensions(py::module_& parent_module) {
Returns:
array: The output array.
)pbdoc");
}
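A hedged usage sketch of the two mlx.core.fast ops bound above; the (batch, heads, sequence, head_dim) layout and the shapes are illustrative:

import mlx.core as mx

q = mx.random.normal((1, 8, 1, 64))
k = mx.random.normal((1, 8, 32, 64))
v = mx.random.normal((1, 8, 32, 64))
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=64**-0.5)

x = mx.random.normal((1, 1, 8, 64))
y = mx.fast.rope(x, 64, traditional=False, base=10000.0, scale=1.0, offset=0)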


@ -1,19 +1,20 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "python/src/utils.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <numeric>
#include "mlx/fft.h"
#include "mlx/ops.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_fft(py::module_& parent_module) {
void init_fft(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"fft", "mlx.core.fft: Fast Fourier Transforms.");
m.def(
@ -29,9 +30,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"n"_a = none,
"n"_a = nb::none(),
"axis"_a = -1,
"stream"_a = none,
"stream"_a = nb::none(),
R"pbdoc(
One dimensional discrete Fourier Transform.
@ -59,9 +60,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"n"_a = none,
"n"_a = nb::none(),
"axis"_a = -1,
"stream"_a = none,
"stream"_a = nb::none(),
R"pbdoc(
One dimensional inverse discrete Fourier Transform.
@ -95,9 +96,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = nb::none(),
R"pbdoc(
Two dimensional discrete Fourier Transform.
@ -132,9 +133,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = nb::none(),
R"pbdoc(
Two dimensional inverse discrete Fourier Transform.
@ -169,9 +170,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
n-dimensional discrete Fourier Transform.
@ -207,9 +208,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
n-dimensional inverse discrete Fourier Transform.
@ -239,9 +240,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"n"_a = none,
"n"_a = nb::none(),
"axis"_a = -1,
"stream"_a = none,
"stream"_a = nb::none(),
R"pbdoc(
One dimensional discrete Fourier Transform on a real input.
@ -274,9 +275,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"n"_a = none,
"n"_a = nb::none(),
"axis"_a = -1,
"stream"_a = none,
"stream"_a = nb::none(),
R"pbdoc(
The inverse of :func:`rfft`.
@ -314,9 +315,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = nb::none(),
R"pbdoc(
Two dimensional real discrete Fourier Transform.
@ -357,9 +358,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = std::vector<int>{-2, -1},
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a.none() = std::vector<int>{-2, -1},
"stream"_a = nb::none(),
R"pbdoc(
The inverse of :func:`rfft2`.
@ -400,9 +401,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
n-dimensional real discrete Fourier Transform.
@ -443,9 +444,9 @@ void init_fft(py::module_& parent_module) {
}
},
"a"_a,
"s"_a = none,
"axes"_a = none,
"stream"_a = none,
"s"_a = nb::none(),
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
The inverse of :func:`rfftn`.
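A hedged sketch exercising a few of the mlx.core.fft bindings shown above (input values are arbitrary):

import mlx.core as mx

x = mx.random.normal((4, 8))
y = mx.fft.rfft(x)                   # real FFT along the last axis
x_rec = mx.fft.irfft(y)              # inverse of rfft
z = mx.fft.fft2(x, axes=[-2, -1])    # 2-D transform with the default axes spelled out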


@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <numeric>
#include <sstream>
@ -7,19 +7,19 @@
#include "mlx/ops.h"
bool is_none_slice(const py::slice& in_slice) {
bool is_none_slice(const nb::slice& in_slice) {
return (
py::getattr(in_slice, "start").is_none() &&
py::getattr(in_slice, "stop").is_none() &&
py::getattr(in_slice, "step").is_none());
nb::getattr(in_slice, "start").is_none() &&
nb::getattr(in_slice, "stop").is_none() &&
nb::getattr(in_slice, "step").is_none());
}
int get_slice_int(py::object obj, int default_val) {
int get_slice_int(nb::object obj, int default_val) {
if (!obj.is_none()) {
if (!py::isinstance<py::int_>(obj)) {
if (!nb::isinstance<nb::int_>(obj)) {
throw std::invalid_argument("Slice indices must be integers or None.");
}
return py::cast<int>(py::cast<py::int_>(obj));
return nb::cast<int>(nb::cast<nb::int_>(obj));
}
return default_val;
}
@ -28,7 +28,7 @@ void get_slice_params(
int& starts,
int& ends,
int& strides,
const py::slice& in_slice,
const nb::slice& in_slice,
int axis_size) {
// Following numpy's convention
// Assume n is the number of elements in the dimension being sliced.
@ -36,26 +36,26 @@ void get_slice_params(
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
// k < 0 . If k is not given it defaults to 1
strides = get_slice_int(py::getattr(in_slice, "step"), 1);
strides = get_slice_int(nb::getattr(in_slice, "step"), 1);
starts = get_slice_int(
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
nb::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
ends = get_slice_int(
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
}
array get_int_index(py::object idx, int axis_size) {
int idx_ = py::cast<int>(idx);
array get_int_index(nb::object idx, int axis_size) {
int idx_ = nb::cast<int>(idx);
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
return array(idx_, uint32);
}
bool is_valid_index_type(const py::object& obj) {
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) ||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj);
bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
}
array mlx_get_item_slice(const array& src, const py::slice& in_slice) {
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
@ -92,7 +92,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
return take(src, indices, 0);
}
array mlx_get_item_int(const array& src, const py::int_& idx) {
array mlx_get_item_int(const array& src, const nb::int_& idx) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
@ -106,7 +106,7 @@ array mlx_get_item_int(const array& src, const py::int_& idx) {
array mlx_gather_nd(
array src,
const std::vector<py::object>& indices,
const std::vector<nb::object>& indices,
bool gather_first,
int& max_dims) {
max_dims = 0;
@ -117,9 +117,10 @@ array mlx_gather_nd(
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (py::isinstance<py::slice>(idx)) {
if (nb::isinstance<nb::slice>(idx)) {
int start, end, stride;
get_slice_params(start, end, stride, idx, src.shape(i));
get_slice_params(
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
// Handle negative indices
start = (start < 0) ? start + src.shape(i) : start;
@ -128,10 +129,10 @@ array mlx_gather_nd(
gather_indices.push_back(arange(start, end, stride, uint32));
num_slices++;
is_slice[i] = true;
} else if (py::isinstance<py::int_>(idx)) {
} else if (nb::isinstance<nb::int_>(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (py::isinstance<array>(idx)) {
auto arr = py::cast<array>(idx);
} else if (nb::isinstance<array>(idx)) {
auto arr = nb::cast<array>(idx);
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
gather_indices.push_back(arr);
}
@ -185,7 +186,7 @@ array mlx_gather_nd(
return src;
}
array mlx_get_item_nd(array src, const py::tuple& entries) {
array mlx_get_item_nd(array src, const nb::tuple& entries) {
// No indices make this a noop
if (entries.size() == 0) {
return src;
@ -197,11 +198,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
// 3. Calculate the remaining slices and reshapes
// Ellipsis handling
std::vector<py::object> indices;
std::vector<nb::object> indices;
{
int non_none_indices_before = 0;
int non_none_indices_after = 0;
std::vector<py::object> r_indices;
std::vector<nb::object> r_indices;
int i = 0;
for (; i < entries.size(); i++) {
auto idx = entries[i];
@ -209,7 +210,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
if (!py::ellipsis().is(idx)) {
if (!nb::ellipsis().is(idx)) {
indices.push_back(idx);
non_none_indices_before += !idx.is_none();
} else {
@ -222,7 +223,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
if (py::ellipsis().is(idx)) {
if (nb::ellipsis().is(idx)) {
throw std::invalid_argument(
"An index can only have a single ellipsis (...)");
}
@ -232,7 +233,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
for (int axis = non_none_indices_before;
axis < src.ndim() - non_none_indices_after;
axis++) {
indices.push_back(py::slice(0, src.shape(axis), 1));
indices.push_back(nb::slice(0, src.shape(axis), 1));
}
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
}
@ -256,7 +257,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
//
// Check whether we have arrays or integer indices and delegate to gather_nd
// after removing the slices at the end and all Nones.
std::vector<py::object> remaining_indices;
std::vector<nb::object> remaining_indices;
bool have_array = false;
{
// First check whether the results of gather are going to be 1st or
@ -264,7 +265,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
bool have_non_array = false;
bool gather_first = false;
for (auto& idx : indices) {
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (have_array && have_non_array) {
gather_first = true;
break;
@ -280,12 +281,12 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
// Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
}
}
std::vector<py::object> gather_indices;
std::vector<nb::object> gather_indices;
for (int i = 0; i <= last_array; i++) {
auto& idx = indices[i];
if (!idx.is_none()) {
@ -299,15 +300,15 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
if (gather_first) {
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
for (int i = 0; i < last_array; i++) {
auto& idx = indices[i];
if (idx.is_none()) {
remaining_indices.push_back(indices[i]);
} else if (py::isinstance<py::slice>(idx)) {
} else if (nb::isinstance<nb::slice>(idx)) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
}
for (int i = last_array + 1; i < indices.size(); i++) {
@ -316,18 +317,18 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
} else {
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
} else if (idx.is_none()) {
remaining_indices.push_back(idx);
} else {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
}
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
for (int i = last_array + 1; i < indices.size(); i++) {
remaining_indices.push_back(indices[i]);
@ -351,7 +352,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
get_slice_params(
starts[axis], ends[axis], strides[axis], idx, ends[axis]);
starts[axis],
ends[axis],
strides[axis],
nb::cast<nb::slice>(idx),
ends[axis]);
axis++;
}
}
@ -375,15 +380,17 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
return src;
}
array mlx_get_item(const array& src, const py::object& obj) {
if (py::isinstance<py::slice>(obj)) {
return mlx_get_item_slice(src, obj);
} else if (py::isinstance<array>(obj)) {
return mlx_get_item_array(src, py::cast<array>(obj));
} else if (py::isinstance<py::int_>(obj)) {
return mlx_get_item_int(src, obj);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_get_item_nd(src, obj);
array mlx_get_item(const array& src, const nb::object& obj) {
if (nb::isinstance<nb::slice>(obj)) {
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
} else if (nb::isinstance<array>(obj)) {
return mlx_get_item_array(src, nb::cast<array>(obj));
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
} else if (nb::isinstance<nb::ellipsis>(obj)) {
return src;
} else if (obj.is_none()) {
std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end());
@ -394,7 +401,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
const array& src,
const py::int_& idx,
const nb::int_& idx,
const array& update) {
if (src.ndim() == 0) {
throw std::invalid_argument(
@ -446,7 +453,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
const array& src,
const py::slice& in_slice,
const nb::slice& in_slice,
const array& update) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
@ -478,9 +485,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
const array& src,
const py::tuple& entries,
const nb::tuple& entries,
const array& update) {
std::vector<py::object> indices;
std::vector<nb::object> indices;
int non_none_indices = 0;
// Expand ellipses into a series of ':' slices
@ -494,7 +501,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
if (!is_valid_index_type(idx)) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
} else if (!py::ellipsis().is(idx)) {
} else if (!nb::ellipsis().is(idx)) {
if (!has_ellipsis) {
indices_before++;
non_none_indices_before += !idx.is_none();
@ -514,7 +521,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
axis < src.ndim() - non_none_indices_after;
axis++) {
indices.insert(
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1));
indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1));
}
non_none_indices = src.ndim();
} else {
@ -549,15 +556,15 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
bool have_array = false;
bool have_non_array = false;
for (auto& idx : indices) {
if (py::isinstance<py::slice>(idx) || idx.is_none()) {
if (nb::isinstance<nb::slice>(idx) || idx.is_none()) {
have_non_array = have_array;
num_slices++;
} else if (py::isinstance<array>(idx)) {
} else if (nb::isinstance<array>(idx)) {
have_array = true;
if (have_array && have_non_array) {
arrays_first = true;
}
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim);
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
num_arrays++;
}
}
@ -569,10 +576,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
int ax = 0;
for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i];
if (py::isinstance<py::slice>(pyidx)) {
if (nb::isinstance<nb::slice>(pyidx)) {
int start, end, stride;
auto axis_size = src.shape(ax++);
get_slice_params(start, end, stride, pyidx, axis_size);
get_slice_params(
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
// Handle negative indices
start = (start < 0) ? start + axis_size : start;
@ -584,13 +592,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
slice_num++;
idx_shape[loc] = idx.size();
arr_indices.push_back(reshape(idx, idx_shape));
} else if (py::isinstance<py::int_>(pyidx)) {
} else if (nb::isinstance<nb::int_>(pyidx)) {
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
} else if (pyidx.is_none()) {
slice_num++;
} else if (py::isinstance<array>(pyidx)) {
} else if (nb::isinstance<array>(pyidx)) {
ax++;
auto idx = py::cast<array>(pyidx);
auto idx = nb::cast<array>(pyidx);
std::vector<int> idx_shape;
if (!arrays_first) {
idx_shape.insert(idx_shape.end(), slice_num, 1);
@ -629,24 +637,24 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
std::tuple<std::vector<array>, array, std::vector<int>>
mlx_compute_scatter_args(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype());
if (py::isinstance<py::slice>(obj)) {
return mlx_scatter_args_slice(src, obj, vals);
} else if (py::isinstance<array>(obj)) {
return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
} else if (py::isinstance<py::int_>(obj)) {
return mlx_scatter_args_int(src, obj, vals);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_scatter_args_nd(src, obj, vals);
if (nb::isinstance<nb::slice>(obj)) {
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
} else if (nb::isinstance<array>(obj)) {
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
} else if (obj.is_none()) {
return {{}, broadcast_to(vals, src.shape()), {}};
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
auto out = scatter(src, indices, updates, axes);
@ -658,7 +666,7 @@ void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
array mlx_add_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@ -670,7 +678,7 @@ array mlx_add_item(
array mlx_subtract_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@ -682,7 +690,7 @@ array mlx_subtract_item(
array mlx_multiply_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@ -694,7 +702,7 @@ array mlx_multiply_item(
array mlx_divide_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@ -706,7 +714,7 @@ array mlx_divide_item(
array mlx_maximum_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@ -718,7 +726,7 @@ array mlx_maximum_item(
array mlx_minimum_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
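A hedged sketch of the indexing paths implemented above: integer, slice, array, and mixed indices for __getitem__, and the scatter path for __setitem__:

import mlx.core as mx

a = mx.arange(12).reshape(3, 4)
print(a[1])                   # integer index -> mlx_get_item_int
print(a[::2, 1:3])            # tuple of slices -> mlx_get_item_nd
idx = mx.array([0, 2])
print(a[idx])                 # array index -> mlx_get_item_array / gather
a[0, :] = 7                   # assignment goes through mlx_set_item / scatter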


@ -1,38 +1,38 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <pybind11/pybind11.h>
#include <nanobind/nanobind.h>
#include "mlx/array.h"
#include "python/src/utils.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlx::core;
array mlx_get_item(const array& src, const py::object& obj);
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
array mlx_get_item(const array& src, const nb::object& obj);
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v);
array mlx_add_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_subtract_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_multiply_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_divide_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_maximum_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v);
array mlx_minimum_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v);


@ -1,32 +1,29 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <variant>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/linalg.h"
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
namespace {
py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
const auto result = svd(a, s);
return py::make_tuple(result.at(0), result.at(1), result.at(2));
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
}
} // namespace
void init_linalg(py::module_& parent_module) {
py::options options;
options.disable_function_signatures();
void init_linalg(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"linalg", "mlx.core.linalg: linear algebra routines.");
@ -59,16 +56,15 @@ void init_linalg(py::module_& parent_module) {
return norm(a, ord, axis, keepdims, stream);
}
},
"a"_a,
py::pos_only(),
"ord"_a = none,
"axis"_a = none,
nb::arg(),
"ord"_a = nb::none(),
"axis"_a = nb::none(),
"keepdims"_a = false,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Matrix or vector norm.
This function computes vector or matrix norms depending on the value of
@ -188,11 +184,11 @@ void init_linalg(py::module_& parent_module) {
"qr",
&qr,
"a"_a,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
R"pbdoc(
qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
The QR factorization of the input matrix.
This function supports arrays with at least 2 dimensions. The matrices
@ -221,11 +217,11 @@ void init_linalg(py::module_& parent_module) {
"svd",
&svd_helper,
"a"_a,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
R"pbdoc(
svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)
The Singular Value Decomposition (SVD) of the input matrix.
This function supports arrays with at least 2 dimensions. When the input
@ -245,11 +241,11 @@ void init_linalg(py::module_& parent_module) {
"inv",
&inv,
"a"_a,
py::kw_only(),
"stream"_a = none,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array
Compute the inverse of a square matrix.
This function supports arrays with at least 2 dimensions. When the input
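A hedged sketch of the mlx.core.linalg routines bound above; at this point qr, svd, and inv run on the CPU, so a CPU stream is passed explicitly:

import mlx.core as mx

a = mx.array([[2.0, 0.0], [0.0, 3.0]])
print(mx.linalg.norm(a))                   # Frobenius norm by default
q, r = mx.linalg.qr(a, stream=mx.cpu)
u, s, vt = mx.linalg.svd(a, stream=mx.cpu)
a_inv = mx.linalg.inv(a, stream=mx.cpu)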


@ -1,8 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/stl/vector.h>
#include <cstring>
#include <fstream>
#include <stdexcept>
@ -16,39 +14,39 @@
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
///////////////////////////////////////////////////////////////////////////////
// Helpers
///////////////////////////////////////////////////////////////////////////////
bool is_istream_object(const py::object& file) {
return py::hasattr(file, "readinto") && py::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed");
bool is_istream_object(const nb::object& file) {
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
}
bool is_ostream_object(const py::object& file) {
return py::hasattr(file, "write") && py::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed");
bool is_ostream_object(const nb::object& file) {
return nb::hasattr(file, "write") && nb::hasattr(file, "seek") &&
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
}
bool is_zip_file(const py::module_& zipfile, const py::object& file) {
bool is_zip_file(const nb::module_& zipfile, const nb::object& file) {
if (is_istream_object(file)) {
auto st_pos = file.attr("tell")();
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>();
bool r = nb::cast<bool>(zipfile.attr("is_zipfile")(file));
file.attr("seek")(st_pos, 0);
return r;
}
return zipfile.attr("is_zipfile")(file).cast<bool>();
return nb::cast<bool>(zipfile.attr("is_zipfile")(file));
}
class ZipFileWrapper {
public:
ZipFileWrapper(
const py::module_& zipfile,
const py::object& file,
const nb::module_& zipfile,
const nb::object& file,
char mode = 'r',
int compression = 0)
: zipfile_module_(zipfile),
@ -63,10 +61,10 @@ class ZipFileWrapper {
close_func_(zipfile_object_.attr("close")) {}
std::vector<std::string> namelist() const {
return files_list_.cast<std::vector<std::string>>();
return nb::cast<std::vector<std::string>>(files_list_);
}
py::object open(const std::string& key, char mode = 'r') {
nb::object open(const std::string& key, char mode = 'r') {
// Following numpy :
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
if (mode == 'w') {
@ -76,12 +74,12 @@ class ZipFileWrapper {
}
private:
py::module_ zipfile_module_;
py::object zipfile_object_;
py::list files_list_;
py::object open_func_;
py::object read_func_;
py::object close_func_;
nb::module_ zipfile_module_;
nb::object zipfile_object_;
nb::list files_list_;
nb::object open_func_;
nb::object read_func_;
nb::object close_func_;
};
///////////////////////////////////////////////////////////////////////////////
@ -90,14 +88,14 @@ class ZipFileWrapper {
class PyFileReader : public io::Reader {
public:
PyFileReader(py::object file)
PyFileReader(nb::object file)
: pyistream_(file),
readinto_func_(file.attr("readinto")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileReader() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
pyistream_.release().dec_ref();
readinto_func_.release().dec_ref();
@ -108,8 +106,8 @@ class PyFileReader : public io::Reader {
bool is_open() const override {
bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.attr("closed").cast<bool>();
nb::gil_scoped_acquire gil;
out = !nb::cast<bool>(pyistream_.attr("closed"));
}
return out;
}
@ -117,7 +115,7 @@ class PyFileReader : public io::Reader {
bool good() const override {
bool out;
{
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
out = !pyistream_.is_none();
}
return out;
@ -126,25 +124,24 @@ class PyFileReader : public io::Reader {
size_t tell() const override {
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
nb::gil_scoped_acquire gil;
out = nb::cast<size_t>(tell_func_());
}
return out;
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}
void read(char* data, size_t n) override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
nb::object bytes_read = readinto_func_(nb::handle(memview));
py::object bytes_read =
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {
throw std::runtime_error("[load] Failed to read from python stream");
}
}
@ -154,23 +151,23 @@ class PyFileReader : public io::Reader {
}
private:
py::object pyistream_;
py::object readinto_func_;
py::object seek_func_;
py::object tell_func_;
nb::object pyistream_;
nb::object readinto_func_;
nb::object seek_func_;
nb::object tell_func_;
};
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
return load_safetensors(py::cast<std::string>(file), s);
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
{
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) {
arr.eval();
}
@ -182,20 +179,20 @@ mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
"[load_safetensors] Input must be a file-like object, or string");
}
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
return load_gguf(py::cast<std::string>(file), s);
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s);
}
throw std::invalid_argument("[load_gguf] Input must be a string");
}
std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file,
nb::object file,
StreamOrDevice s) {
bool own_file = py::isinstance<py::str>(file);
bool own_file = nb::isinstance<nb::str>(file);
py::module_ zipfile = py::module_::import("zipfile");
nb::module_ zipfile = nb::module_::import_("zipfile");
if (!is_zip_file(zipfile, file)) {
throw std::invalid_argument(
"[load_npz] Input must be a zip file or a file-like object that can be "
@ -208,7 +205,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st);
nb::object sub_file = zipfile_object.open(st);
// Create array from python file stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
@ -224,7 +221,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
// If we don't own the stream and it was passed to us, eval immediately
if (!own_file) {
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
for (auto& [key, arr] : array_dict) {
arr.eval();
}
@ -233,14 +230,14 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
return array_dict;
}
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
return load(py::cast<std::string>(file), s);
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
return load(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
{
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
arr.eval();
}
return arr;
@ -250,16 +247,16 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
}
LoadOutputTypes mlx_load_helper(
py::object file,
nb::object file,
std::optional<std::string> format,
bool return_metadata,
StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
if (py::isinstance<py::str>(file)) {
fname = py::cast<std::string>(file);
if (nb::isinstance<nb::str>(file)) {
fname = nb::cast<std::string>(file);
} else if (is_istream_object(file)) {
fname = file.attr("name").cast<std::string>();
fname = nb::cast<std::string>(file.attr("name"));
} else {
throw std::invalid_argument(
"[load] Input must be a file-like object opened in binary mode, or string");
@ -304,14 +301,14 @@ LoadOutputTypes mlx_load_helper(
class PyFileWriter : public io::Writer {
public:
PyFileWriter(py::object file)
PyFileWriter(nb::object file)
: pyostream_(file),
write_func_(file.attr("write")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileWriter() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
pyostream_.release().dec_ref();
write_func_.release().dec_ref();
@ -322,8 +319,8 @@ class PyFileWriter : public io::Writer {
bool is_open() const override {
bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.attr("closed").cast<bool>();
nb::gil_scoped_acquire gil;
out = !nb::cast<bool>(pyostream_.attr("closed"));
}
return out;
}
@ -331,7 +328,7 @@ class PyFileWriter : public io::Writer {
bool good() const override {
bool out;
{
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
out = !pyostream_.is_none();
}
return out;
@ -340,25 +337,26 @@ class PyFileWriter : public io::Writer {
size_t tell() const override {
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
nb::gil_scoped_acquire gil;
out = nb::cast<size_t>(tell_func_());
}
return out;
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}
void write(const char* data, size_t n) override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
py::object bytes_written =
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
auto memview =
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ);
nb::object bytes_written = write_func_(nb::handle(memview));
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {
throw std::runtime_error("[load] Failed to write to python stream");
}
}
@ -368,20 +366,20 @@ class PyFileWriter : public io::Writer {
}
private:
py::object pyostream_;
py::object write_func_;
py::object seek_func_;
py::object tell_func_;
nb::object pyostream_;
nb::object write_func_;
nb::object seek_func_;
nb::object tell_func_;
};
void mlx_save_helper(py::object file, array a) {
if (py::isinstance<py::str>(file)) {
save(py::cast<std::string>(file), a);
void mlx_save_helper(nb::object file, array a) {
if (nb::isinstance<nb::str>(file)) {
save(nb::cast<std::string>(file), a);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
save(writer, a);
}
@ -393,26 +391,26 @@ void mlx_save_helper(py::object file, array a) {
}
void mlx_savez_helper(
py::object file_,
py::args args,
const py::kwargs& kwargs,
nb::object file_,
nb::args args,
const nb::kwargs& kwargs,
bool compressed) {
// Add .npz to the end of the filename if not already there
py::object file = file_;
nb::object file = file_;
if (py::isinstance<py::str>(file_)) {
std::string fname = file_.cast<std::string>();
if (nb::isinstance<nb::str>(file_)) {
std::string fname = nb::cast<std::string>(file_);
// Add .npz to file name if it is not there
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
fname += ".npz";
file = py::str(fname);
file = nb::cast(fname);
}
// Collect args and kwargs
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>();
auto arrays_list = args.cast<std::vector<array>>();
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
auto arrays_list = nb::cast<std::vector<array>>(args);
for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i);
@ -426,9 +424,9 @@ void mlx_savez_helper(
}
// Create python ZipFile object depending on compression
py::module_ zipfile = py::module_::import("zipfile");
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>()
: zipfile.attr("ZIP_STORED").cast<int>();
nb::module_ zipfile = nb::module_::import_("zipfile");
int compression = nb::cast<int>(
compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED"));
char mode = 'w';
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
@ -438,7 +436,7 @@ void mlx_savez_helper(
auto py_ostream = zipfile_object.open(fname, 'w');
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release nogil;
nb::gil_scoped_release nogil;
save(writer, a);
}
}
@ -447,31 +445,31 @@ void mlx_savez_helper(
}
void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<py::dict> m) {
nb::object file,
nb::dict d,
std::optional<nb::dict> m) {
std::unordered_map<std::string, std::string> metadata_map;
if (m) {
try {
metadata_map =
m.value().cast<std::unordered_map<std::string, std::string>>();
} catch (const py::cast_error& e) {
nb::cast<std::unordered_map<std::string, std::string>>(m.value());
} catch (const nb::cast_error& e) {
throw std::invalid_argument(
"[save_safetensors] Metadata must be a dictionary with string keys and values");
}
} else {
metadata_map = std::unordered_map<std::string, std::string>();
}
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
if (nb::isinstance<nb::str>(file)) {
{
py::gil_scoped_release nogil;
save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
nb::gil_scoped_release nogil;
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release nogil;
nb::gil_scoped_release nogil;
save_safetensors(writer, arrays_map, metadata_map);
}
} else {
@ -481,22 +479,22 @@ void mlx_save_safetensor_helper(
}
void mlx_save_gguf_helper(
py::object file,
py::dict a,
std::optional<py::dict> m) {
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
nb::object file,
nb::dict a,
std::optional<nb::dict> m) {
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
if (nb::isinstance<nb::str>(file)) {
if (m) {
auto metadata_map =
m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
{
py::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else {
{
py::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map);
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map);
}
}
} else {

View File

@ -1,15 +1,20 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <pybind11/pybind11.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <optional>
#include <string>
#include <unordered_map>
#include <variant>
#include "mlx/io.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlx::core;
using LoadOutputTypes = std::variant<
@ -18,27 +23,27 @@ using LoadOutputTypes = std::variant<
SafetensorsLoad,
GGUFLoad>;
SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s);
SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s);
void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<py::dict> m);
nb::object file,
nb::dict d,
std::optional<nb::dict> m);
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s);
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s);
void mlx_save_gguf_helper(
py::object file,
py::dict d,
std::optional<py::dict> m);
nb::object file,
nb::dict d,
std::optional<nb::dict> m);
LoadOutputTypes mlx_load_helper(
py::object file,
nb::object file,
std::optional<std::string> format,
bool return_metadata,
StreamOrDevice s);
void mlx_save_helper(py::object file, array a);
void mlx_save_helper(nb::object file, array a);
void mlx_savez_helper(
py::object file,
py::args args,
const py::kwargs& kwargs,
nb::object file,
nb::args args,
const nb::kwargs& kwargs,
bool compressed = false);
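The helpers declared above back the user-facing save/load functions. A minimal usage sketch from Python, assuming the usual mlx.core entry points these helpers are bound to (the def calls themselves are not in the visible hunks):

import mlx.core as mx

a = mx.arange(8)
mx.save("a.npy", a)                                 # mlx_save_helper: path string or binary file object
mx.savez("bundle", a, extra=a + 1)                  # mlx_savez_helper: appends .npz and writes a zip archive
mx.save_safetensors("a.safetensors", {"a": a}, {"note": "demo"})
arrays = mx.load("bundle.npz")                      # format inferred from the extension when not given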

View File

@ -1,16 +1,15 @@
// Copyright © 2023 Apple Inc.
#include <pybind11/pybind11.h>
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/metal.h"
#include <nanobind/nanobind.h>
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_metal(py::module_& m) {
py::module_ metal = m.def_submodule("metal", "mlx.metal");
void init_metal(nb::module_& m) {
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def(
"is_available",
&metal::is_available,
@ -48,7 +47,7 @@ void init_metal(py::module_& m) {
"set_memory_limit",
&metal::set_memory_limit,
"limit"_a,
py::kw_only(),
nb::kw_only(),
"relaxed"_a = true,
R"pbdoc(
Set the memory limit.
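With relaxed now keyword-only, a call sketch (the limit value is illustrative and assumed to be in bytes):

import mlx.core as mx

mx.metal.set_memory_limit(4 << 30, relaxed=False)   # relaxed must now be passed by keyword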

View File

@ -1,30 +1,30 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include <nanobind/nanobind.h>
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
namespace py = pybind11;
namespace nb = nanobind;
void init_array(py::module_&);
void init_device(py::module_&);
void init_stream(py::module_&);
void init_metal(py::module_&);
void init_ops(py::module_&);
void init_transforms(py::module_&);
void init_random(py::module_&);
void init_fft(py::module_&);
void init_linalg(py::module_&);
void init_constants(py::module_&);
void init_extensions(py::module_&);
void init_utils(py::module_&);
void init_array(nb::module_&);
void init_device(nb::module_&);
void init_stream(nb::module_&);
void init_metal(nb::module_&);
void init_ops(nb::module_&);
void init_transforms(nb::module_&);
void init_random(nb::module_&);
void init_fft(nb::module_&);
void init_linalg(nb::module_&);
void init_constants(nb::module_&);
void init_fast(nb::module_&);
PYBIND11_MODULE(core, m) {
NB_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
auto reprlib_fix = py::module_::import("mlx._reprlib_fix");
py::module_::import("mlx._os_warning");
auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix");
nb::module_::import_("mlx._os_warning");
nb::set_leak_warnings(false);
init_device(m);
init_stream(m);
@ -36,8 +36,7 @@ PYBIND11_MODULE(core, m) {
init_fft(m);
init_linalg(m);
init_constants(m);
init_extensions(m);
init_utils(m);
init_fast(m);
m.attr("__version__") = TOSTRING(_VERSION_);
}
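A quick sketch of what the registration above exposes on the Python side (fast replaces the removed extensions/utils entries; submodule names follow the init_* calls):

import mlx.core as mx

print(mx.__version__)
mx.metal, mx.random, mx.fft, mx.linalg, mx.fast     # submodules created by the init_* calls above
with mx.stream(mx.cpu):                             # stream() now lives in the stream bindings, not a utils module
    pass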

File diff suppressed because it is too large

View File

@ -1,60 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
// A patch to get float16_t to work with pybind11 numpy arrays
// Derived from:
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679
#include <pybind11/numpy.h>
namespace pybind11::detail {
template <typename T>
struct npy_scalar_caster {
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
using Array = array_t<T>;
bool load(handle src, bool convert) {
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
handle type = dtype::of<T>().attr("type"); // Could make more efficient.
if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
return false;
Array tmp = Array::ensure(src);
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
this->value = *tmp.data();
return true;
}
return false;
}
static handle cast(T src, return_value_policy, handle) {
Array tmp({1});
tmp.mutable_at(0) = src;
tmp.resize({});
// You could also just return the array if you want a scalar array.
object scalar = tmp[tuple()];
return scalar.release();
}
};
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
// Kinda following:
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
template <>
struct npy_format_descriptor<float16_t> {
static constexpr auto name = _("float16");
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
};
template <>
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
static constexpr auto name = _("float16");
};
} // namespace pybind11::detail

View File

@ -1,7 +1,10 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <chrono>
#include "python/src/utils.h"
@ -9,8 +12,8 @@
#include "mlx/ops.h"
#include "mlx/random.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using namespace mlx::core::random;
@ -25,22 +28,22 @@ class PyKeySequence {
}
array next() {
auto out = split(py::cast<array>(state_[0]));
auto out = split(nb::cast<array>(state_[0]));
state_[0] = out.first;
return out.second;
}
py::list state() {
nb::list state() {
return state_;
}
void release() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
state_.release().dec_ref();
}
private:
py::list state_;
nb::list state_;
};
PyKeySequence& default_key() {
@ -54,7 +57,7 @@ PyKeySequence& default_key() {
return ks;
}
void init_random(py::module_& parent_module) {
void init_random(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"random",
"mlx.core.random: functionality related to random number generation");
@ -85,10 +88,10 @@ void init_random(py::module_& parent_module) {
)pbdoc");
m.def(
"split",
py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
"key"_a,
"num"_a = 2,
"stream"_a = none,
"stream"_a = nb::none(),
R"pbdoc(
Split a PRNG key into sub keys.
@ -119,9 +122,9 @@ void init_random(py::module_& parent_module) {
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate uniformly distributed random numbers.
@ -151,11 +154,11 @@ void init_random(py::module_& parent_module) {
return normal(shape, type.value_or(float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"dtype"_a.none() = float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
"key"_a = none,
"stream"_a = none,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate normally distributed random numbers.
@ -184,9 +187,9 @@ void init_random(py::module_& parent_module) {
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"dtype"_a = int32,
"key"_a = none,
"stream"_a = none,
"dtype"_a.none() = int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate random integers from the given interval.
@ -219,9 +222,9 @@ void init_random(py::module_& parent_module) {
}
},
"p"_a = 0.5,
"shape"_a = none,
"key"_a = none,
"stream"_a = none,
"shape"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate Bernoulli random values.
@ -259,10 +262,10 @@ void init_random(py::module_& parent_module) {
},
"lower"_a,
"upper"_a,
"shape"_a = none,
"dtype"_a = std::optional{float32},
"key"_a = none,
"stream"_a = none,
"shape"_a = nb::none(),
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Generate values from a truncated normal distribution.
@ -292,9 +295,9 @@ void init_random(py::module_& parent_module) {
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"stream"_a = none,
"key"_a = none,
"dtype"_a.none() = float32,
"stream"_a = nb::none(),
"key"_a = nb::none(),
R"pbdoc(
Sample from the standard Gumbel distribution.
@ -331,10 +334,10 @@ void init_random(py::module_& parent_module) {
},
"logits"_a,
"axis"_a = -1,
"shape"_a = none,
"num_samples"_a = none,
"key"_a = none,
"stream"_a = none,
"shape"_a = nb::none(),
"num_samples"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Sample from a categorical distribution.
@ -359,6 +362,6 @@ void init_random(py::module_& parent_module) {
array: The ``shape``-sized output array with type ``uint32``.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
}
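A short usage sketch of the bindings above; mx.random.key is assumed to be bound elsewhere in this file, since only split, uniform, normal, and friends appear in the visible hunks:

import mlx.core as mx

key = mx.random.key(0)                              # assumed: defined outside the visible hunks
subkeys = mx.random.split(key, num=2)               # array of sub keys, one per row
x = mx.random.uniform(low=0, high=1, shape=(2, 3), key=subkeys[0])
y = mx.random.normal(shape=(2, 3), dtype=mx.float32, key=subkeys[1])
b = mx.random.bernoulli(p=0.25, shape=(4,), key=key)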

View File

@ -1,25 +1,54 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include <pybind11/pybind11.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/variant.h>
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_stream(py::module_& m) {
py::class_<Stream>(
// Create the StreamContext on enter and delete on exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
StreamOrDevice _s;
StreamContext* _inner;
};
void init_stream(nb::module_& m) {
nb::class_<Stream>(
m,
"Stream",
R"pbdoc(
A stream for running operations on a given device.
)pbdoc")
.def(py::init<int, Device>(), "index"_a, "device"_a)
.def_readonly("device", &Stream::device)
.def(nb::init<int, Device>(), "index"_a, "device"_a)
.def_ro("device", &Stream::device)
.def(
"__repr__",
[](const Stream& s) {
@ -31,7 +60,7 @@ void init_stream(py::module_& m) {
return s1 == s2;
});
py::implicitly_convertible<Device::DeviceType, Device>();
nb::implicitly_convertible<Device::DeviceType, Device>();
m.def(
"default_stream",
@ -56,4 +85,48 @@ void init_stream(py::module_& m) {
&new_stream,
"device"_a,
R"pbdoc(Make a new stream on the given device.)pbdoc");
nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.
See :func:`stream` for usage.
Args:
s: The stream or device to set as the default.
)pbdoc")
.def(nb::init<StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def(
"__exit__",
[](PyStreamContext& scm,
const std::optional<nb::type_object>& exc_type,
const std::optional<nb::object>& exc_value,
const std::optional<nb::object>& traceback) { scm.exit(); },
"exc_type"_a = nb::none(),
"exc_value"_a = nb::none(),
"traceback"_a = nb::none());
m.def(
"stream",
[](StreamOrDevice s) { return PyStreamContext(s); },
"s"_a,
R"pbdoc(
Create a context manager to set the default device and stream.
Args:
s: The :obj:`Stream` or :obj:`Device` to set as the default.
Returns:
A context manager that sets the default device and stream.
Example:
.. code-block::python
import mlx.core as mx
# Create a context manager for the default device and stream.
with mx.stream(mx.cpu):
# Operations here will use mx.cpu by default.
pass
)pbdoc");
}
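Beyond the docstring example, the same machinery composes with explicit streams; a small sketch:

import mlx.core as mx

s = mx.new_stream(mx.cpu)            # a dedicated stream on the CPU device
with mx.stream(s):                   # PyStreamContext builds the StreamContext on __enter__ and frees it on __exit__
    a = mx.ones((4, 4))
    b = a @ a                        # dispatched to stream s by default
mx.eval(b)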

View File

@ -1,6 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <algorithm>
#include <fstream>
#include <numeric>
@ -13,13 +18,17 @@
#include "mlx/transforms_impl.h"
#include "python/src/trees.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
inline std::string type_name_str(const nb::handle& o) {
return nb::cast<std::string>(nb::type_name(o.type()));
}
template <typename T>
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
std::vector<T> vals;
@ -49,7 +58,7 @@ auto validate_argnums_argnames(
}
auto py_value_and_grad(
const py::function& fun,
const nb::callable& fun,
std::vector<int> argnums,
std::vector<std::string> argnames,
const std::string& error_msg_tag,
@ -71,7 +80,7 @@ auto py_value_and_grad(
}
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
const py::args& args, const py::kwargs& kwargs) {
const nb::args& args, const nb::kwargs& kwargs) {
// Sanitize the input
if (argnums.size() > 0 && argnums.back() >= args.size()) {
std::ostringstream msg;
@ -89,7 +98,7 @@ auto py_value_and_grad(
<< "' because the function is called with the "
<< "following keyword arguments {";
for (auto item : kwargs) {
msg << item.first.cast<std::string>() << ",";
msg << nb::cast<std::string>(item.first) << ",";
}
msg << "}";
throw std::invalid_argument(msg.str());
@ -115,7 +124,7 @@ auto py_value_and_grad(
// value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
py::object py_value_out;
nb::object py_value_out;
auto value_and_grads = value_and_grad(
[&fun,
&args,
@ -127,15 +136,15 @@ auto py_value_and_grad(
&error_msg_tag,
scalar_func_only](const std::vector<array>& a) {
// Copy the arguments
py::args args_cpy = py::tuple(args.size());
py::kwargs kwargs_cpy = py::kwargs();
nb::list args_cpy;
nb::kwargs kwargs_cpy = nb::kwargs();
int j = 0;
for (int i = 0; i < args.size(); ++i) {
if (j < argnums.size() && i == argnums[j]) {
args_cpy[i] = tree_unflatten(args[i], a, counts[j]);
args_cpy.append(tree_unflatten(args[i], a, counts[j]));
j++;
} else {
args_cpy[i] = args[i];
args_cpy.append(args[i]);
}
}
for (auto& key : argnames) {
@ -154,25 +163,25 @@ auto py_value_and_grad(
py_value_out = fun(*args_cpy, **kwargs_cpy);
// Validate the return value of the python function
if (!py::isinstance<array>(py_value_out)) {
if (!nb::isinstance<array>(py_value_out)) {
if (scalar_func_only) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be a "
<< "scalar array; but " << py_value_out.get_type()
<< "scalar array; but " << type_name_str(py_value_out)
<< " was returned.";
throw std::invalid_argument(msg.str());
}
if (!py::isinstance<py::tuple>(py_value_out)) {
if (!nb::isinstance<nb::tuple>(py_value_out)) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
<< py_value_out.get_type() << " was returned.";
<< type_name_str(py_value_out) << " was returned.";
throw std::invalid_argument(msg.str());
}
py::tuple ret = py::cast<py::tuple>(py_value_out);
nb::tuple ret = nb::cast<nb::tuple>(py_value_out);
if (ret.size() == 0) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
@ -182,14 +191,14 @@ auto py_value_and_grad(
<< "we got an empty tuple.";
throw std::invalid_argument(msg.str());
}
if (!py::isinstance<array>(ret[0])) {
if (!nb::isinstance<array>(ret[0])) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
<< "was a tuple with the first value being of type "
<< ret[0].get_type() << " .";
<< type_name_str(ret[0]) << " .";
throw std::invalid_argument(msg.str());
}
}
@ -212,61 +221,60 @@ auto py_value_and_grad(
// In case 2 we return a tuple of the above.
// In case 3 we return a tuple containing a tuple and dict (sth like
// (tuple(), dict(x=mx.array(5))) ).
py::object positional_grads;
py::object keyword_grads;
py::object py_grads;
nb::object positional_grads;
nb::object keyword_grads;
nb::object py_grads;
// Collect the gradients for the positional arguments
if (argnums.size() == 1) {
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
} else if (argnums.size() > 1) {
py::tuple grads_(argnums.size());
nb::list grads_;
for (int i = 0; i < argnums.size(); i++) {
grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]);
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
}
positional_grads = py::cast<py::object>(grads_);
positional_grads = nb::tuple(grads_);
} else {
positional_grads = py::none();
positional_grads = nb::none();
}
// No keyword argument gradients so return the tuple of gradients
if (argnames.size() == 0) {
py_grads = positional_grads;
} else {
py::dict grads_;
nb::dict grads_;
for (int i = 0; i < argnames.size(); i++) {
auto& k = argnames[i];
grads_[k.c_str()] = tree_unflatten(
kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
}
keyword_grads = py::cast<py::object>(grads_);
keyword_grads = grads_;
py_grads =
py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
py_grads = nb::make_tuple(positional_grads, keyword_grads);
}
// Put the values back in the container
py::object return_value = tree_unflatten(py_value_out, value);
nb::object return_value = tree_unflatten(py_value_out, value);
return std::make_pair(return_value, py_grads);
};
}
auto py_vmap(
const py::function& fun,
const py::object& in_axes,
const py::object& out_axes) {
return [fun, in_axes, out_axes](const py::args& args) {
auto axes_to_flat_tree = [](const py::object& tree,
const py::object& axes) {
const nb::callable& fun,
const nb::object& in_axes,
const nb::object& out_axes) {
return [fun, in_axes, out_axes](const nb::args& args) {
auto axes_to_flat_tree = [](const nb::object& tree,
const nb::object& axes) {
auto tree_axes = tree_map(
{tree, axes},
[](const std::vector<py::object>& inputs) { return inputs[1]; });
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
std::vector<int> flat_axes;
tree_visit(tree_axes, [&flat_axes](py::handle obj) {
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
if (obj.is_none()) {
flat_axes.push_back(-1);
} else if (py::isinstance<py::int_>(obj)) {
flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj)));
} else if (nb::isinstance<nb::int_>(obj)) {
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
} else {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
@ -280,7 +288,7 @@ auto py_vmap(
// py_value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
py::object py_outputs;
nb::object py_outputs;
auto vmap_fn =
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
@ -305,24 +313,24 @@ auto py_vmap(
};
}
std::unordered_map<size_t, py::object>& tree_cache() {
std::unordered_map<size_t, nb::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs
static std::unordered_map<size_t, py::object> tree_cache_;
static std::unordered_map<size_t, nb::object> tree_cache_;
return tree_cache_;
}
struct PyCompiledFun {
py::function fun;
nb::callable fun;
size_t fun_id;
py::object captured_inputs;
py::object captured_outputs;
nb::object captured_inputs;
nb::object captured_outputs;
bool shapeless;
size_t num_outputs{0};
mutable size_t num_outputs{0};
PyCompiledFun(
const py::function& fun,
py::object inputs,
py::object outputs,
const nb::callable& fun,
nb::object inputs,
nb::object outputs,
bool shapeless)
: fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())),
@ -342,7 +350,7 @@ struct PyCompiledFun {
num_outputs = other.num_outputs;
};
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
// Flat array inputs
std::vector<array> inputs;
@ -358,45 +366,45 @@ struct PyCompiledFun {
constexpr uint64_t dict_identifier = 18446744073709551521UL;
// Flatten the tree with hashed constants and structure
std::function<void(py::handle)> recurse;
recurse = [&](py::handle obj) {
if (py::isinstance<py::list>(obj)) {
auto l = py::cast<py::list>(obj);
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle obj) {
if (nb::isinstance<nb::list>(obj)) {
auto l = nb::cast<nb::list>(obj);
constants.push_back(list_identifier);
for (int i = 0; i < l.size(); ++i) {
recurse(l[i]);
}
} else if (py::isinstance<py::tuple>(obj)) {
auto l = py::cast<py::tuple>(obj);
} else if (nb::isinstance<nb::tuple>(obj)) {
auto l = nb::cast<nb::tuple>(obj);
constants.push_back(list_identifier);
for (auto item : obj) {
recurse(item);
}
} else if (py::isinstance<py::dict>(obj)) {
auto d = py::cast<py::dict>(obj);
} else if (nb::isinstance<nb::dict>(obj)) {
auto d = nb::cast<nb::dict>(obj);
constants.push_back(dict_identifier);
for (auto item : d) {
auto r = py::hash(item.first);
auto r = item.first.attr("__hash__");
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
recurse(item.second);
}
} else if (py::isinstance<array>(obj)) {
inputs.push_back(py::cast<array>(obj));
} else if (nb::isinstance<array>(obj)) {
inputs.push_back(nb::cast<array>(obj));
constants.push_back(array_identifier);
} else if (py::isinstance<py::str>(obj)) {
auto r = py::hash(obj);
} else if (nb::isinstance<nb::str>(obj)) {
auto r = obj.attr("__hash__");
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::int_>(obj)) {
auto r = obj.cast<int64_t>();
} else if (nb::isinstance<nb::int_>(obj)) {
auto r = nb::cast<int64_t>(obj);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::float_>(obj)) {
auto r = obj.cast<double>();
} else if (nb::isinstance<nb::float_>(obj)) {
auto r = nb::cast<double>(obj);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else {
std::ostringstream msg;
msg << "[compile] Function arguments must be trees of arrays "
<< "or constants (floats, ints, or strings), but received "
<< "type " << obj.get_type() << ".";
<< "type " << type_name_str(obj) << ".";
throw std::invalid_argument(msg.str());
}
};
@ -404,13 +412,12 @@ struct PyCompiledFun {
recurse(args);
int num_args = inputs.size();
recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
std::vector<array> trace_captures;
if (!py::isinstance<py::none>(captured_inputs)) {
if (!captured_inputs.is_none()) {
flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert(
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
@ -425,7 +432,7 @@ struct PyCompiledFun {
tree_cache().insert({fun_id, py_outputs});
num_outputs = outputs.size();
if (!py::isinstance<py::none>(captured_outputs)) {
if (!captured_outputs.is_none()) {
auto flat_out_captures = tree_flatten(captured_outputs, false);
outputs.insert(
outputs.end(),
@ -434,13 +441,13 @@ struct PyCompiledFun {
}
// Replace tracers with originals in captured inputs
if (!py::isinstance<py::none>(captured_inputs)) {
if (!captured_inputs.is_none()) {
tree_replace(captured_inputs, trace_captures, flat_in_captures);
}
return outputs;
};
if (!py::isinstance<py::none>(captured_inputs)) {
if (!captured_inputs.is_none()) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
inputs.end(),
@ -451,7 +458,7 @@ struct PyCompiledFun {
// Compile and call
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) {
if (!captured_outputs.is_none()) {
std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end()));
@ -459,12 +466,16 @@ struct PyCompiledFun {
}
// Put the outputs back in the container
py::object py_outputs = tree_cache().at(fun_id);
nb::object py_outputs = tree_cache().at(fun_id);
return tree_unflatten_from_structure(py_outputs, outputs);
}
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
return const_cast<PyCompiledFun*>(this)->call_impl(args, kwargs);
};
~PyCompiledFun() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
@ -476,35 +487,35 @@ struct PyCompiledFun {
class PyCheckpointedFun {
public:
PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
~PyCheckpointedFun() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
fun_.release().dec_ref();
}
struct InnerFunction {
py::object fun_;
py::object args_structure_;
std::weak_ptr<py::object> output_structure_;
nb::object fun_;
nb::object args_structure_;
std::weak_ptr<nb::object> output_structure_;
InnerFunction(
py::object fun,
py::object args_structure,
std::weak_ptr<py::object> output_structure)
nb::object fun,
nb::object args_structure,
std::weak_ptr<nb::object> output_structure)
: fun_(std::move(fun)),
args_structure_(std::move(args_structure)),
output_structure_(output_structure) {}
~InnerFunction() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
fun_.release().dec_ref();
args_structure_.release().dec_ref();
}
std::vector<array> operator()(const std::vector<array>& inputs) {
auto args = py::cast<py::tuple>(
auto args = nb::cast<nb::tuple>(
tree_unflatten_from_structure(args_structure_, inputs));
auto [outputs, output_structure] =
tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
@ -515,9 +526,9 @@ class PyCheckpointedFun {
}
};
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto output_structure = std::make_shared<py::object>();
auto full_args = py::make_tuple(args, kwargs);
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
auto output_structure = std::make_shared<nb::object>();
auto full_args = nb::make_tuple(args, kwargs);
auto [inputs, args_structure] =
tree_flatten_with_structure(full_args, false);
@ -527,26 +538,27 @@ class PyCheckpointedFun {
return tree_unflatten_from_structure(*output_structure, outputs);
}
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
return const_cast<PyCheckpointedFun*>(this)->call_impl(args, kwargs);
}
private:
py::function fun_;
nb::callable fun_;
};
void init_transforms(py::module_& m) {
py::options options;
options.disable_function_signatures();
void init_transforms(nb::module_& m) {
m.def(
"eval",
[](const py::args& args) {
[](const nb::args& args) {
std::vector<array> arrays = tree_flatten(args, false);
{
py::gil_scoped_release nogil;
nb::gil_scoped_release nogil;
eval(arrays);
}
},
nb::arg(),
nb::sig("def eval(*args) -> None"),
R"pbdoc(
eval(*args) -> None
Evaluate an :class:`array` or tree of :class:`array`.
Args:
@ -557,19 +569,15 @@ void init_transforms(py::module_& m) {
)pbdoc");
m.def(
"jvp",
[](const py::function& fun,
[](const nb::callable& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
py::args args = py::tuple(primals.size());
for (int i = 0; i < primals.size(); ++i) {
args[i] = primals[i];
}
auto out = fun(*args);
if (py::isinstance<array>(out)) {
return std::vector<array>{py::cast<array>(out)};
auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) {
return std::vector<array>{nb::cast<array>(out)};
} else {
return py::cast<std::vector<array>>(out);
return nb::cast<std::vector<array>>(out);
}
};
return jvp(vfun, primals, tangents);
@ -577,17 +585,16 @@ void init_transforms(py::module_& m) {
"fun"_a,
"primals"_a,
"tangents"_a,
nb::sig(
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
R"pbdoc(
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
Compute the Jacobian-vector product.
This computes the product of the Jacobian of a function ``fun`` evaluated
at ``primals`` with the ``tangents``.
Args:
fun (function): A function which takes a variable number of :class:`array`
fun (callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian.
@ -601,19 +608,15 @@ void init_transforms(py::module_& m) {
)pbdoc");
m.def(
"vjp",
[](const py::function& fun,
[](const nb::callable& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
py::args args = py::tuple(primals.size());
for (int i = 0; i < primals.size(); ++i) {
args[i] = primals[i];
}
auto out = fun(*args);
if (py::isinstance<array>(out)) {
return std::vector<array>{py::cast<array>(out)};
auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) {
return std::vector<array>{nb::cast<array>(out)};
} else {
return py::cast<std::vector<array>>(out);
return nb::cast<std::vector<array>>(out);
}
};
return vjp(vfun, primals, cotangents);
@ -621,16 +624,16 @@ void init_transforms(py::module_& m) {
"fun"_a,
"primals"_a,
"cotangents"_a,
nb::sig(
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
R"pbdoc(
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
Compute the vector-Jacobian product.
Computes the product of the ``cotangents`` with the Jacobian of a
function ``fun`` evaluated at ``primals``.
Args:
fun (function): A function which takes a variable number of :class:`array`
fun (callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian.
@ -644,20 +647,20 @@ void init_transforms(py::module_& m) {
)pbdoc");
m.def(
"value_and_grad",
[](const py::function& fun,
[](const nb::callable& fun,
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames);
return py::cpp_function(py_value_and_grad(
return nb::cpp_function(py_value_and_grad(
fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
},
"fun"_a,
"argnums"_a = std::nullopt,
"argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{},
nb::sig(
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
R"pbdoc(
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
Returns a function which computes the value and gradient of ``fun``.
The function passed to :func:`value_and_grad` should return either
@ -688,7 +691,7 @@ void init_transforms(py::module_& m) {
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array` or a tuple the first element
of which should be a scalar :class:`array`.
@ -702,34 +705,34 @@ void init_transforms(py::module_& m) {
no gradients for keyword arguments by default.
Returns:
function: A function which returns a tuple where the first element
callable: A function which returns a tuple where the first element
is the output of `fun` and the second element is the gradients w.r.t.
the loss.
)pbdoc");
m.def(
"grad",
[](const py::function& fun,
[](const nb::callable& fun,
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames);
auto fn =
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
return py::cpp_function(
[fn](const py::args& args, const py::kwargs& kwargs) {
return nb::cpp_function(
[fn](const nb::args& args, const nb::kwargs& kwargs) {
return fn(args, kwargs).second;
});
},
"fun"_a,
"argnums"_a = std::nullopt,
"argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{},
nb::sig(
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
R"pbdoc(
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
Returns a function which computes the gradient of ``fun``.
Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices)
@ -742,26 +745,26 @@ void init_transforms(py::module_& m) {
no gradients for keyword arguments by default.
Returns:
function: A function which has the same input arguments as ``fun`` and
callable: A function which has the same input arguments as ``fun`` and
returns the gradient(s).
)pbdoc");
m.def(
"vmap",
[](const py::function& fun,
const py::object& in_axes,
const py::object& out_axes) {
return py::cpp_function(py_vmap(fun, in_axes, out_axes));
[](const nb::callable& fun,
const nb::object& in_axes,
const nb::object& out_axes) {
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
},
"fun"_a,
"in_axes"_a = 0,
"out_axes"_a = 0,
nb::sig(
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
R"pbdoc(
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
Returns a vectorized version of ``fun``.
Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or a tree of :class:`array` and returns
a variable number of :class:`array` or a tree of :class:`array`.
in_axes (int, optional): An integer or a valid prefix tree of the
@ -774,16 +777,16 @@ void init_transforms(py::module_& m) {
Defaults to ``0``.
Returns:
function: The vectorized function.
callable: The vectorized function.
)pbdoc");
m.def(
"export_to_dot",
[](py::object file, const py::args& args) {
[](nb::object file, const nb::args& args) {
std::vector<array> arrays = tree_flatten(args);
if (py::isinstance<py::str>(file)) {
std::ofstream out(py::cast<std::string>(file));
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
} else if (py::hasattr(file, "write")) {
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
export_to_dot(out, arrays);
auto write = file.attr("write");
@ -793,57 +796,50 @@ void init_transforms(py::module_& m) {
"export_to_dot accepts file-like objects or strings to be used as filenames");
}
},
"file"_a);
"file"_a,
"args"_a);
m.def(
"compile",
[](const py::function& fun,
const py::object& inputs,
const py::object& outputs,
[](const nb::callable& fun,
const nb::object& inputs,
const nb::object& outputs,
bool shapeless) {
py::options options;
options.disable_function_signatures();
std::ostringstream doc;
auto name = fun.attr("__name__").cast<std::string>();
doc << name;
// Try to get the name
auto n = fun.attr("__name__");
auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);
// Try to get the signature
auto inspect = py::module::import("inspect");
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
doc << inspect.attr("signature")(fun)
.attr("__str__")()
.cast<std::string>();
std::ostringstream sig;
sig << "def " << name;
auto inspect = nb::module_::import_("inspect");
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
sig << nb::cast<std::string>(
inspect.attr("signature")(fun).attr("__str__")());
} else {
sig << "(*args, **kwargs)";
}
// Try to get the doc string
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
doc << "\n\n";
auto dstr = d.cast<std::string>();
// Add spaces to match first line indentation with remainder of
// docstring
int i = 0;
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
doc << ' ';
}
doc << dstr;
}
auto doc_str = doc.str();
return py::cpp_function(
auto d = inspect.attr("getdoc")(fun);
std::string doc =
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
auto sig_str = sig.str();
return nb::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
py::name(name.c_str()),
py::doc(doc_str.c_str()));
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str());
},
"fun"_a,
"inputs"_a = std::nullopt,
"outputs"_a = std::nullopt,
"inputs"_a = nb::none(),
"outputs"_a = nb::none(),
"shapeless"_a = false,
R"pbdoc(
compile(fun: function) -> function
Returns a compiled function which produces the same output as ``fun``.
Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during
@ -864,15 +860,13 @@ void init_transforms(py::module_& m) {
``shapeless`` set to ``True``. Default: ``False``
Returns:
function: A compiled function which has the same input arguments
callable: A compiled function which has the same input arguments
as ``fun`` and returns the same output(s).
)pbdoc");
m.def(
"disable_compile",
&disable_compile,
R"pbdoc(
disable_compile() -> None
Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
@ -880,17 +874,15 @@ void init_transforms(py::module_& m) {
"enable_compile",
&enable_compile,
R"pbdoc(
enable_compile() -> None
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
m.def(
"checkpoint",
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
"fun"_a);
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
}
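Taken together, the transform bindings above are used from Python roughly as follows; a minimal sketch with illustrative values:

import mlx.core as mx

def loss(w, x):
    return mx.mean((w * x - 1.0) ** 2)

w = mx.array(2.0)
x = mx.array([1.0, 2.0, 3.0])

dw = mx.grad(loss)(w, x)                        # gradient w.r.t. the first positional argument
val, dw = mx.value_and_grad(loss)(w, x)         # value and gradient together
fast_loss = mx.compile(loss)                    # name, signature, and docstring carried over as above
mx.eval(fast_loss(w, x))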

View File

@ -2,16 +2,16 @@
#include "python/src/trees.h"
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) ||
nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
} else if (nb::isinstance<nb::dict>(subtree)) {
for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second);
}
} else {
@ -23,63 +23,63 @@ void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = nb::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) ||
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
nb::object tree_map(
const std::vector<nb::object>& trees,
std::function<nb::object(const std::vector<nb::object>&)> transform) {
std::function<nb::object(const std::vector<nb::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
recurse = [&](const std::vector<nb::object>& subtrees) {
if (nb::isinstance<nb::list>(subtrees[0])) {
nb::list l;
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
std::vector<nb::object> items(subtrees.size());
int len = nb::cast<nb::tuple>(subtrees[0]).size();
nb::list l;
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
return nb::cast<nb::object>(nb::tuple(l));
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
nb::dict d;
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
@ -91,7 +91,7 @@ py::object tree_map(
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
return nb::cast<nb::object>(d);
} else {
return transform(subtrees);
}
@ -99,40 +99,40 @@ py::object tree_map(
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
nb::object tree_map(
nb::object tree,
std::function<nb::object(nb::handle)> transform) {
return tree_map({tree}, [&](std::vector<nb::object> inputs) {
return transform(inputs[0]);
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
nb::object tree,
std::function<nb::object(nb::handle)> visitor) {
std::function<nb::object(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree)) {
auto l = nb::cast<nb::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
return nb::cast<nb::object>(subtree);
} else if (nb::isinstance<nb::dict>(subtree)) {
auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return nb::cast<nb::object>(d);
} else if (nb::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
return nb::cast<nb::object>(subtree);
}
};
recurse(tree);
@ -141,36 +141,36 @@ void tree_visit_update(
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
void tree_fill(nb::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
nb::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
tree_visit_update(tree, [&](nb::handle node) {
auto arr = nb::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
return nb::cast(it->second);
}
return py::cast(arr);
return nb::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
@ -180,24 +180,24 @@ std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
nb::object tree_unflatten(
nb::object tree,
const std::vector<array>& values,
int index /* = 0 */) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
return tree_map(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
return nb::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
return nb::cast<nb::object>(obj);
}
});
}
py::object structure_sentinel() {
static py::object sentinel;
nb::object structure_sentinel() {
static nb::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
sentinel = nb::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
@ -206,19 +206,19 @@ py::object structure_sentinel() {
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
nb::object tree,
bool strict /* = true */) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
return nb::cast<nb::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
@ -228,16 +228,16 @@ std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
nb::object tree_unflatten_from_structure(
nb::object structure,
const std::vector<array>& values,
int index /* = 0 */) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
return tree_map(structure, [&](nb::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
return nb::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
return nb::cast<nb::object>(obj);
}
});
}
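From the Python side, the trees these helpers walk are just nested lists, tuples, and dicts with arrays at the leaves; a sketch of the observable behavior:

import mlx.core as mx

params = {"w": mx.zeros((3, 3)), "layers": [{"b": mx.ones(3)}]}
mx.eval(params)                                  # tree_flatten(..., strict=false) gathers every array leaf

def loss(p):
    return mx.sum(p["w"]) + mx.sum(p["layers"][0]["b"])

grads = mx.grad(loss)(params)                    # gradients are unflattened back into the same structure
print(grads["layers"][0]["b"].shape)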

View File

@ -1,38 +1,37 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include "mlx/array.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlx::core;
void tree_visit(py::object tree, std::function<void(py::handle)> visitor);
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform);
nb::object tree_map(
const std::vector<nb::object>& trees,
std::function<nb::object(const std::vector<nb::object>&)> transform);
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform);
nb::object tree_map(
nb::object tree,
std::function<nb::object(nb::handle)> transform);
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor);
nb::object tree,
std::function<nb::object(nb::handle)> visitor);
/**
* Fill a pytree (recursive dict or list of dict or list) in place with the
* given arrays. */
void tree_fill(py::object& tree, const std::vector<array>& values);
void tree_fill(nb::object& tree, const std::vector<array>& values);
/**
* Replace all the arrays from the src values with the dst values in the
* tree.
*/
void tree_replace(
py::object& tree,
nb::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst);
@ -40,21 +39,21 @@ void tree_replace(
* Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array.
*/
std::vector<array> tree_flatten(py::object tree, bool strict = true);
std::vector<array> tree_flatten(nb::object tree, bool strict = true);
/**
* Unflatten a tree from a vector of arrays.
*/
py::object tree_unflatten(
py::object tree,
nb::object tree_unflatten(
nb::object tree,
const std::vector<array>& values,
int index = 0);
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
nb::object tree,
bool strict = true);
py::object tree_unflatten_from_structure(
py::object structure,
nb::object tree_unflatten_from_structure(
nb::object structure,
const std::vector<array>& values,
int index = 0);
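
These declarations exist to serve the function transforms: the Python argument tree is flattened into a vector of arrays, the transform runs over the flat arrays, and the result is rebuilt with the captured structure. A short Python-level check of that round trip (a sketch assuming a working MLX build; the loss function and shapes are made up for the example):

import mlx.core as mx


def loss(params, x):
    return ((x @ params["w"] + params["b"]) ** 2).sum()


params = {"w": mx.ones((3, 2)), "b": mx.zeros((2,))}
x = mx.ones((4, 3))

# The gradients come back as a dict with the same structure as `params`,
# which is exactly what the flatten/unflatten pair guarantees.
grads = mx.grad(loss)(params, x)
print(grads["w"].shape, grads["b"].shape)  # (3, 2) (2,)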

View File

@@ -1,81 +0,0 @@
#include "mlx/utils.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <optional>
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
// Slightly different from the original: with a Python context manager we are
// not yet in the context at construction time, so only create the inner
// context on enter and delete it on exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}
void enter() {
_inner = new StreamContext(_s);
}
void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}
private:
StreamOrDevice _s;
StreamContext* _inner;
};
void init_utils(py::module_& m) {
py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.
See :func:`stream` for usage.
Args:
s: The stream or device to set as the default.
)pbdoc")
.def(py::init<StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def(
"__exit__",
[](PyStreamContext& scm,
const std::optional<py::type>& exc_type,
const std::optional<py::object>& exc_value,
const std::optional<py::object>& traceback) { scm.exit(); });
m.def(
"stream",
[](StreamOrDevice s) { return PyStreamContext(s); },
"s"_a,
R"pbdoc(
Create a context manager to set the default device and stream.
Args:
s: The :obj:`Stream` or :obj:`Device` to set as the default.
Returns:
A context manager that sets the default device and stream.
Example:
.. code-block:: python
import mlx.core as mx
# Create a context manager for the default device and stream.
with mx.stream(mx.cpu):
# Operations here will use mx.cpu by default.
pass
)pbdoc");
}
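
At the Python level the context manager reads the same as before; a slightly fuller usage sketch than the docstring example (assumes a standard MLX install, where the default device is typically the GPU on Apple silicon):

import mlx.core as mx

print(mx.default_device())         # e.g. Device(gpu, 0)

with mx.stream(mx.cpu):
    # Work launched in this block targets the default CPU stream.
    a = mx.ones((4, 4))
    mx.eval((a @ a).sum())
    print(mx.default_device())     # Device(cpu, 0)

print(mx.default_device())         # previous default restored on exit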

View File

@@ -1,23 +1,22 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <numeric>
#include <optional>
#include <variant>
#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/variant.h>
#include "mlx/array.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{};
variant<nb::bool_, nb::int_, nb::float_, std::complex<float>, nb::object>;
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
@@ -32,31 +31,36 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
return axes;
}
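
IntOrVec is what lets reductions accept no axis, a single axis, or a list of axes. Seen from Python (illustrative, using mx.sum as one of the reductions that takes this form):

import mlx.core as mx

x = mx.ones((2, 3, 4))
print(mx.sum(x).shape)               # ()       no axis: reduce everything
print(mx.sum(x, axis=1).shape)       # (2, 4)   single int
print(mx.sum(x, axis=[0, 2]).shape)  # (3,)     list of ints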
inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>();
inline array to_array_with_accessor(nb::object obj) {
if (nb::hasattr(obj, "__mlx_array__")) {
return nb::cast<array>(obj.attr("__mlx_array__")());
} else if (nb::isinstance<array>(obj)) {
return nb::cast<array>(obj);
} else {
return obj.cast<array>();
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(int32);
// bool_ is an exception and is always promoted
return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32);
return array(
py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else {
return to_array_with_accessor(std::get<py::object>(v));
return to_array_with_accessor(std::get<nb::object>(v));
}
}
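
The defaults these branches pick for bare Python scalars match what the public constructor does; a quick illustration (treat it as a sketch of the behaviour, not a specification):

import mlx.core as mx

assert mx.array(True).dtype == mx.bool_
assert mx.array(3).dtype == mx.int32
assert mx.array(3.0).dtype == mx.float32
assert mx.array(3 + 0j).dtype == mx.complex64

# An explicit dtype wins when one is supplied.
assert mx.array(7, mx.float16).dtype == mx.float16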
@@ -68,14 +72,14 @@ inline std::pair<array, array> to_arrays(
// - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<py::object>(&a); pa) {
if (auto pa = std::get_if<nb::object>(&a); pa) {
auto arr_a = to_array_with_accessor(*pa);
if (auto pb = std::get_if<py::object>(&b); pb) {
if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {arr_a, arr_b};
}
return {arr_a, to_array(b, arr_a.dtype())};
} else if (auto pb = std::get_if<py::object>(&b); pb) {
} else if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {to_array(a, arr_b.dtype()), arr_b};
} else {
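
The hunk is cut off here, but the comments above describe the useful rule: a bare Python scalar is treated as weak and adopts the dtype of the array it is combined with, while array-with-array combinations go through ordinary type promotion. A small sketch of that behaviour (assuming current MLX semantics):

import mlx.core as mx

x = mx.array([1.0, 2.0], dtype=mx.float16)
assert (x + 2).dtype == mx.float16      # weak int scalar: no upcast
assert (x + 2.0).dtype == mx.float16    # weak float scalar: no upcast

y = mx.array([1.0, 2.0], dtype=mx.float32)
assert (x + y).dtype == mx.float32      # array with array: normal promotion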

View File

@@ -308,9 +308,9 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.dtype, mx.bool_)
self.assertEqual(y.item(), True)
# y = mx.array(x, mx.complex64)
# self.assertEqual(y.dtype, mx.complex64)
# self.assertEqual(y.item(), 3.0+0j)
y = mx.array(x, mx.complex64)
self.assertEqual(y.dtype, mx.complex64)
self.assertEqual(y.item(), 3.0 + 0j)
def test_array_repr(self):
x = mx.array(True)
@@ -682,7 +682,7 @@ class TestArray(mlx_tests.MLXTestCase):
# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(RuntimeError):
with self.assertRaises(TypeError):
pickle.dumps(x)
def test_array_copy(self):
@@ -711,6 +711,11 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqualArray(y, x - 1)
def test_indexing(self):
# Only ellipsis is a no-op
a_mlx = mx.array([1])[...]
self.assertEqual(a_mlx.shape, (1,))
self.assertEqual(a_mlx.item(), 1)
# Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32)
a_mlx = mx.array(a_npy)
@@ -1360,7 +1365,7 @@ class TestArray(mlx_tests.MLXTestCase):
for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
a_tf = tf.constant(a_np, dtype=tf_dtype)
a_mx = mx.array(a_tf)
a_mx = mx.array(np.array(a_tf))
for f in [
lambda x: x,
lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,

View File

@@ -134,7 +134,9 @@ class GenerateStubs(Command):
pass
def run(self) -> None:
subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"])
subprocess.run(
["python", "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"]
)
# Read the content of README.md
@@ -165,7 +167,7 @@ if __name__ == "__main__":
include_package_data=True,
extras_require={
"testing": ["numpy", "torch"],
"dev": ["pre-commit", "pybind11-stubgen"],
"dev": ["pre-commit"],
},
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},