mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Switch to nanobind (#839)
* mostly builds * most tests pass * fix circle build * add back buffer protocol * includes * fix for py38 * limit to cpu device * include * fix stubs * move signatures for docs * stubgen + docs fix * doc for compiled function, comments
This commit is contained in:
parent
d39ed54f8e
commit
9a8ee00246
@ -31,8 +31,7 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install pybind11-stubgen
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
@ -44,7 +43,8 @@ jobs:
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
python3 setup.py generate_stubs
|
||||
echo "stubs"
|
||||
python -m nanobind.stubgen -m mlx.core -r -O python
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@ -80,8 +80,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install pybind11-stubgen
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@ -95,7 +94,7 @@ jobs:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py generate_stubs
|
||||
python -m nanobind.stubgen -m mlx.core -r -O python
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@ -144,9 +143,8 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install --upgrade setuptools
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
pip install twine
|
||||
pip install build
|
||||
@ -161,7 +159,7 @@ jobs:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py generate_stubs
|
||||
python -m nanobind.stubgen -m mlx.core -r -O python
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
@ -209,9 +207,8 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install --upgrade setuptools
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
@ -219,7 +216,7 @@ jobs:
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
pip install . -v
|
||||
python setup.py generate_stubs
|
||||
python -m nanobind.stubgen -m mlx.core -r -O python
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
python -m build --wheel
|
||||
|
@ -146,8 +146,12 @@ target_include_directories(
|
||||
|
||||
if (MLX_BUILD_PYTHON_BINDINGS)
|
||||
message(STATUS "Building Python bindings.")
|
||||
find_package(Python COMPONENTS Interpreter Development)
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||
endif()
|
||||
|
||||
|
@ -29,8 +29,8 @@ autosummary_generate = True
|
||||
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
|
||||
|
||||
intersphinx_mapping = {
|
||||
"https://docs.python.org/3": None,
|
||||
"https://numpy.org/doc/stable/": None,
|
||||
"python": ("https://docs.python.org/3", None),
|
||||
"numpy": ("https://numpy.org/doc/stable/", None),
|
||||
}
|
||||
|
||||
templates_path = ["_templates"]
|
||||
@ -59,3 +59,14 @@ html_theme_options = {
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
htmlhelp_basename = "mlx_doc"
|
||||
|
||||
|
||||
def setup(app):
|
||||
wrapped = app.registry.documenters["function"].can_document_member
|
||||
|
||||
def nanobind_function_patch(member: Any, *args, **kwargs) -> bool:
|
||||
return "nanobind.nb_func" in str(type(member)) or wrapped(
|
||||
member, *args, **kwargs
|
||||
)
|
||||
|
||||
app.registry.documenters["function"].can_document_member = nanobind_function_patch
|
||||
|
@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
|
||||
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
|
||||
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "pybind11[global]"
|
||||
conda install pybind11
|
||||
brew install pybind11
|
||||
pip install git+https://github.com/wjakob/nanobind.git
|
||||
|
||||
Then simply build and install it using pip:
|
||||
Then simply build and install MLX using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
@ -122,8 +121,6 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
|
||||
out_stream->write(header.str().c_str(), header.str().length());
|
||||
out_stream->write(a.data<char>(), a.nbytes());
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
/** Save array to file in .npy format */
|
||||
|
@ -7,6 +7,25 @@
|
||||
namespace mlx::core {
|
||||
|
||||
struct complex64_t;
|
||||
struct complex128_t;
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool can_convert_to_complex128 =
|
||||
!std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;
|
||||
|
||||
struct complex128_t : public std::complex<double> {
|
||||
complex128_t(double v, double u) : std::complex<double>(v, u){};
|
||||
complex128_t(std::complex<double> v) : std::complex<double>(v){};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename = typename std::enable_if<can_convert_to_complex128<T>>::type>
|
||||
complex128_t(T x) : std::complex<double>(x){};
|
||||
|
||||
operator float() const {
|
||||
return real();
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool can_convert_to_complex64 =
|
||||
|
@ -1,3 +1,7 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"]
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
|
||||
"cmake>=3.24",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@ -1,7 +1,10 @@
|
||||
pybind11_add_module(
|
||||
nanobind_add_module(
|
||||
core
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
@ -15,7 +18,6 @@ pybind11_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||
|
File diff suppressed because it is too large
Load Diff
122
python/src/buffer.h
Normal file
122
python/src/buffer.h
Normal file
@ -0,0 +1,122 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#pragma once
|
||||
#include <optional>
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
// Only defined in >= Python 3.9
|
||||
// https://github.com/python/cpython/blob/f6cdc6b4a191b75027de342aa8b5d344fb31313e/Include/typeslots.h#L2-L3
|
||||
#ifndef Py_bf_getbuffer
|
||||
#define Py_bf_getbuffer 1
|
||||
#define Py_bf_releasebuffer 2
|
||||
#endif
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
std::string buffer_format(const array& a) {
|
||||
// https://docs.python.org/3.10/library/struct.html#format-characters
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
return "?";
|
||||
case uint8:
|
||||
return "B";
|
||||
case uint16:
|
||||
return "H";
|
||||
case uint32:
|
||||
return "I";
|
||||
case uint64:
|
||||
return "Q";
|
||||
case int8:
|
||||
return "b";
|
||||
case int16:
|
||||
return "h";
|
||||
case int32:
|
||||
return "i";
|
||||
case int64:
|
||||
return "q";
|
||||
case float16:
|
||||
return "e";
|
||||
case float32:
|
||||
return "f";
|
||||
case bfloat16:
|
||||
return "B";
|
||||
case complex64:
|
||||
return "Zf\0";
|
||||
default: {
|
||||
std::ostringstream os;
|
||||
os << "bad dtype: " << a.dtype();
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct buffer_info {
|
||||
std::string format;
|
||||
std::vector<ssize_t> shape;
|
||||
std::vector<ssize_t> strides;
|
||||
|
||||
buffer_info(
|
||||
const std::string& format,
|
||||
std::vector<ssize_t> shape_in,
|
||||
std::vector<ssize_t> strides_in)
|
||||
: format(format),
|
||||
shape(std::move(shape_in)),
|
||||
strides(std::move(strides_in)) {}
|
||||
|
||||
buffer_info(const buffer_info&) = delete;
|
||||
buffer_info& operator=(const buffer_info&) = delete;
|
||||
|
||||
buffer_info(buffer_info&& other) noexcept {
|
||||
(*this) = std::move(other);
|
||||
}
|
||||
|
||||
buffer_info& operator=(buffer_info&& rhs) noexcept {
|
||||
format = std::move(rhs.format);
|
||||
shape = std::move(rhs.shape);
|
||||
strides = std::move(rhs.strides);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
|
||||
std::memset(view, 0, sizeof(Py_buffer));
|
||||
auto a = nb::cast<array>(nb::handle(obj));
|
||||
|
||||
if (!a.is_evaled()) {
|
||||
nb::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
|
||||
std::vector<ssize_t> shape(a.shape().begin(), a.shape().end());
|
||||
std::vector<ssize_t> strides(a.strides().begin(), a.strides().end());
|
||||
for (auto& s : strides) {
|
||||
s *= a.itemsize();
|
||||
}
|
||||
buffer_info* info =
|
||||
new buffer_info(buffer_format(a), std::move(shape), std::move(strides));
|
||||
|
||||
view->obj = obj;
|
||||
view->ndim = a.ndim();
|
||||
view->internal = info;
|
||||
view->buf = a.data<void>();
|
||||
view->itemsize = a.itemsize();
|
||||
view->len = a.size();
|
||||
view->readonly = false;
|
||||
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
|
||||
view->format = const_cast<char*>(info->format.c_str());
|
||||
}
|
||||
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
|
||||
view->strides = info->strides.data();
|
||||
view->shape = info->shape.data();
|
||||
}
|
||||
Py_INCREF(view->obj);
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C" inline void releasebuffer(PyObject*, Py_buffer* view) {
|
||||
delete (buffer_info*)view->internal;
|
||||
}
|
@ -1,11 +1,11 @@
|
||||
// init_constants.cpp
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <limits>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
|
||||
void init_constants(py::module_& m) {
|
||||
void init_constants(nb::module_& m) {
|
||||
m.attr("Inf") = std::numeric_limits<double>::infinity();
|
||||
m.attr("Infinity") = std::numeric_limits<double>::infinity();
|
||||
m.attr("NAN") = NAN;
|
||||
@ -19,6 +19,6 @@ void init_constants(py::module_& m) {
|
||||
m.attr("inf") = std::numeric_limits<double>::infinity();
|
||||
m.attr("infty") = std::numeric_limits<double>::infinity();
|
||||
m.attr("nan") = NAN;
|
||||
m.attr("newaxis") = pybind11::none();
|
||||
m.attr("newaxis") = nb::none();
|
||||
m.attr("pi") = 3.1415926535897932384626433;
|
||||
}
|
155
python/src/convert.cpp
Normal file
155
python/src/convert.cpp
Normal file
@ -0,0 +1,155 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <nanobind/stl/complex.h>
|
||||
|
||||
#include "python/src/convert.h"
|
||||
|
||||
namespace nanobind {
|
||||
template <>
|
||||
struct ndarray_traits<float16_t> {
|
||||
static constexpr bool is_complex = false;
|
||||
static constexpr bool is_float = true;
|
||||
static constexpr bool is_bool = false;
|
||||
static constexpr bool is_int = false;
|
||||
static constexpr bool is_signed = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ndarray_traits<bfloat16_t> {
|
||||
static constexpr bool is_complex = false;
|
||||
static constexpr bool is_float = true;
|
||||
static constexpr bool is_bool = false;
|
||||
static constexpr bool is_int = false;
|
||||
static constexpr bool is_signed = true;
|
||||
};
|
||||
|
||||
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
||||
}; // namespace nanobind
|
||||
|
||||
template <typename T>
|
||||
array nd_array_to_mlx_contiguous(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype) {
|
||||
// Make a copy of the numpy buffer
|
||||
// Get buffer ptr pass to array constructor
|
||||
auto data_ptr = nd_array.data();
|
||||
return array(static_cast<const T*>(data_ptr), shape, dtype);
|
||||
}
|
||||
|
||||
array nd_array_to_mlx(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
std::optional<Dtype> dtype) {
|
||||
// Compute the shape and size
|
||||
std::vector<int> shape;
|
||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||
shape.push_back(nd_array.shape(i));
|
||||
}
|
||||
auto type = nd_array.dtype();
|
||||
|
||||
// Copy data and make array
|
||||
if (type == nb::dtype<bool>()) {
|
||||
return nd_array_to_mlx_contiguous<bool>(
|
||||
nd_array, shape, dtype.value_or(bool_));
|
||||
} else if (type == nb::dtype<uint8_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint8_t>(
|
||||
nd_array, shape, dtype.value_or(uint8));
|
||||
} else if (type == nb::dtype<uint16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint16_t>(
|
||||
nd_array, shape, dtype.value_or(uint16));
|
||||
} else if (type == nb::dtype<uint32_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint32_t>(
|
||||
nd_array, shape, dtype.value_or(uint32));
|
||||
} else if (type == nb::dtype<uint64_t>()) {
|
||||
return nd_array_to_mlx_contiguous<uint64_t>(
|
||||
nd_array, shape, dtype.value_or(uint64));
|
||||
} else if (type == nb::dtype<int8_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int8_t>(
|
||||
nd_array, shape, dtype.value_or(int8));
|
||||
} else if (type == nb::dtype<int16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int16_t>(
|
||||
nd_array, shape, dtype.value_or(int16));
|
||||
} else if (type == nb::dtype<int32_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int32_t>(
|
||||
nd_array, shape, dtype.value_or(int32));
|
||||
} else if (type == nb::dtype<int64_t>()) {
|
||||
return nd_array_to_mlx_contiguous<int64_t>(
|
||||
nd_array, shape, dtype.value_or(int64));
|
||||
} else if (type == nb::dtype<float16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<float16_t>(
|
||||
nd_array, shape, dtype.value_or(float16));
|
||||
} else if (type == nb::dtype<bfloat16_t>()) {
|
||||
return nd_array_to_mlx_contiguous<bfloat16_t>(
|
||||
nd_array, shape, dtype.value_or(bfloat16));
|
||||
} else if (type == nb::dtype<float>()) {
|
||||
return nd_array_to_mlx_contiguous<float>(
|
||||
nd_array, shape, dtype.value_or(float32));
|
||||
} else if (type == nb::dtype<double>()) {
|
||||
return nd_array_to_mlx_contiguous<double>(
|
||||
nd_array, shape, dtype.value_or(float32));
|
||||
} else if (type == nb::dtype<std::complex<float>>()) {
|
||||
return nd_array_to_mlx_contiguous<complex64_t>(
|
||||
nd_array, shape, dtype.value_or(complex64));
|
||||
} else if (type == nb::dtype<std::complex<double>>()) {
|
||||
return nd_array_to_mlx_contiguous<complex128_t>(
|
||||
nd_array, shape, dtype.value_or(complex64));
|
||||
} else {
|
||||
throw std::invalid_argument("Cannot convert numpy array to mlx array.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Lib, typename T>
|
||||
nb::ndarray<Lib> mlx_to_nd_array(
|
||||
array a,
|
||||
std::optional<nb::dlpack::dtype> t = {}) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
nb::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
|
||||
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
|
||||
return nb::ndarray<Lib>(
|
||||
a.data<T>(),
|
||||
a.ndim(),
|
||||
shape.data(),
|
||||
nb::handle(),
|
||||
strides.data(),
|
||||
t.value_or(nb::dtype<T>()));
|
||||
}
|
||||
|
||||
template <typename Lib>
|
||||
nb::ndarray<Lib> mlx_to_nd_array(const array& a) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
return mlx_to_nd_array<Lib, bool>(a);
|
||||
case uint8:
|
||||
return mlx_to_nd_array<Lib, uint8_t>(a);
|
||||
case uint16:
|
||||
return mlx_to_nd_array<Lib, uint16_t>(a);
|
||||
case uint32:
|
||||
return mlx_to_nd_array<Lib, uint32_t>(a);
|
||||
case uint64:
|
||||
return mlx_to_nd_array<Lib, uint64_t>(a);
|
||||
case int8:
|
||||
return mlx_to_nd_array<Lib, int8_t>(a);
|
||||
case int16:
|
||||
return mlx_to_nd_array<Lib, int16_t>(a);
|
||||
case int32:
|
||||
return mlx_to_nd_array<Lib, int32_t>(a);
|
||||
case int64:
|
||||
return mlx_to_nd_array<Lib, int64_t>(a);
|
||||
case float16:
|
||||
return mlx_to_nd_array<Lib, float16_t>(a);
|
||||
case bfloat16:
|
||||
return mlx_to_nd_array<Lib, bfloat16_t>(a, nb::bfloat16);
|
||||
case float32:
|
||||
return mlx_to_nd_array<Lib, float>(a);
|
||||
case complex64:
|
||||
return mlx_to_nd_array<Lib, std::complex<float>>(a);
|
||||
}
|
||||
}
|
||||
|
||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
|
||||
return mlx_to_nd_array<nb::numpy>(a);
|
||||
}
|
16
python/src/convert.h
Normal file
16
python/src/convert.h
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/ndarray.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
array nd_array_to_mlx(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
std::optional<Dtype> dtype);
|
||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
|
@ -1,32 +1,34 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_device(py::module_& m) {
|
||||
auto device_class = py::class_<Device>(
|
||||
void init_device(nb::module_& m) {
|
||||
auto device_class = nb::class_<Device>(
|
||||
m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
|
||||
py::enum_<Device::DeviceType>(m, "DeviceType")
|
||||
nb::enum_<Device::DeviceType>(m, "DeviceType")
|
||||
.value("cpu", Device::DeviceType::cpu)
|
||||
.value("gpu", Device::DeviceType::gpu)
|
||||
.export_values()
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const Device::DeviceType& d1, const Device& d2) {
|
||||
return d1 == d2;
|
||||
},
|
||||
py::prepend());
|
||||
.def("__eq__", [](const Device::DeviceType& d, const nb::object& other) {
|
||||
if (!nb::isinstance<Device>(other) &&
|
||||
!nb::isinstance<Device::DeviceType>(other)) {
|
||||
return false;
|
||||
}
|
||||
return d == nb::cast<Device>(other);
|
||||
});
|
||||
|
||||
device_class.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||
.def_readonly("type", &Device::type)
|
||||
device_class.def(nb::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
|
||||
.def_ro("type", &Device::type)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const Device& d) {
|
||||
@ -34,11 +36,15 @@ void init_device(py::module_& m) {
|
||||
os << d;
|
||||
return os.str();
|
||||
})
|
||||
.def("__eq__", [](const Device& d1, const Device& d2) {
|
||||
return d1 == d2;
|
||||
.def("__eq__", [](const Device& d, const nb::object& other) {
|
||||
if (!nb::isinstance<Device>(other) &&
|
||||
!nb::isinstance<Device::DeviceType>(other)) {
|
||||
return false;
|
||||
}
|
||||
return d == nb::cast<Device>(other);
|
||||
});
|
||||
|
||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||
nb::implicitly_convertible<Device::DeviceType, Device>();
|
||||
|
||||
m.def(
|
||||
"default_device",
|
||||
|
@ -1,20 +1,17 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_extensions(py::module_& parent_module) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
void init_fast(nb::module_& parent_module) {
|
||||
auto m =
|
||||
parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
|
||||
|
||||
@ -31,15 +28,15 @@ void init_extensions(py::module_& parent_module) {
|
||||
},
|
||||
"a"_a,
|
||||
"dims"_a,
|
||||
py::kw_only(),
|
||||
nb::kw_only(),
|
||||
"traditional"_a,
|
||||
"base"_a,
|
||||
"scale"_a,
|
||||
"offset"_a,
|
||||
"stream"_a = none,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Apply rotary positional encoding to the input.
|
||||
|
||||
Args:
|
||||
@ -70,30 +67,34 @@ void init_extensions(py::module_& parent_module) {
|
||||
"q"_a,
|
||||
"k"_a,
|
||||
"v"_a,
|
||||
py::kw_only(),
|
||||
nb::kw_only(),
|
||||
"scale"_a,
|
||||
"mask"_a = none,
|
||||
"stream"_a = none,
|
||||
"mask"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array
|
||||
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
||||
|
||||
A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V.
|
||||
Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
|
||||
Supports:
|
||||
* [Multi-Head Attention](https://arxiv.org/abs/1706.03762)
|
||||
* [Grouped Query Attention](https://arxiv.org/abs/2305.13245)
|
||||
* [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
|
||||
|
||||
This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations.
|
||||
Note: The softmax operation is performed in ``float32`` regardless of
|
||||
input precision.
|
||||
|
||||
Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
|
||||
Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array.
|
||||
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
|
||||
and ``v`` inputs should not be pre-tiled to match ``q``.
|
||||
|
||||
Args:
|
||||
q (array): Input query array.
|
||||
k (array): Input keys array.
|
||||
v (array): Input values array.
|
||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||
mask (array, optional): An additive mask to apply to the query-key scores.
|
||||
Args:
|
||||
q (array): Input query array.
|
||||
k (array): Input keys array.
|
||||
v (array): Input values array.
|
||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||
mask (array, optional): An additive mask to apply to the query-key scores.
|
||||
|
||||
Returns:
|
||||
array: The output array.
|
||||
|
||||
)pbdoc");
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -1,19 +1,20 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_fft(py::module_& parent_module) {
|
||||
void init_fft(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
||||
m.def(
|
||||
@ -29,9 +30,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"n"_a = nb::none(),
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
One dimensional discrete Fourier Transform.
|
||||
|
||||
@ -59,9 +60,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"n"_a = nb::none(),
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
One dimensional inverse discrete Fourier Transform.
|
||||
|
||||
@ -95,9 +96,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a.none() = std::vector<int>{-2, -1},
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Two dimensional discrete Fourier Transform.
|
||||
|
||||
@ -132,9 +133,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a.none() = std::vector<int>{-2, -1},
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Two dimensional inverse discrete Fourier Transform.
|
||||
|
||||
@ -169,9 +170,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
n-dimensional discrete Fourier Transform.
|
||||
|
||||
@ -207,9 +208,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
n-dimensional inverse discrete Fourier Transform.
|
||||
|
||||
@ -239,9 +240,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"n"_a = nb::none(),
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
One dimensional discrete Fourier Transform on a real input.
|
||||
|
||||
@ -274,9 +275,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"n"_a = none,
|
||||
"n"_a = nb::none(),
|
||||
"axis"_a = -1,
|
||||
"stream"_a = none,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfft`.
|
||||
|
||||
@ -314,9 +315,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a.none() = std::vector<int>{-2, -1},
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Two dimensional real discrete Fourier Transform.
|
||||
|
||||
@ -357,9 +358,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = std::vector<int>{-2, -1},
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a.none() = std::vector<int>{-2, -1},
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfft2`.
|
||||
|
||||
@ -400,9 +401,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
n-dimensional real discrete Fourier Transform.
|
||||
|
||||
@ -443,9 +444,9 @@ void init_fft(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"s"_a = none,
|
||||
"axes"_a = none,
|
||||
"stream"_a = none,
|
||||
"s"_a = nb::none(),
|
||||
"axes"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
The inverse of :func:`rfftn`.
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
@ -7,19 +7,19 @@
|
||||
|
||||
#include "mlx/ops.h"
|
||||
|
||||
bool is_none_slice(const py::slice& in_slice) {
|
||||
bool is_none_slice(const nb::slice& in_slice) {
|
||||
return (
|
||||
py::getattr(in_slice, "start").is_none() &&
|
||||
py::getattr(in_slice, "stop").is_none() &&
|
||||
py::getattr(in_slice, "step").is_none());
|
||||
nb::getattr(in_slice, "start").is_none() &&
|
||||
nb::getattr(in_slice, "stop").is_none() &&
|
||||
nb::getattr(in_slice, "step").is_none());
|
||||
}
|
||||
|
||||
int get_slice_int(py::object obj, int default_val) {
|
||||
int get_slice_int(nb::object obj, int default_val) {
|
||||
if (!obj.is_none()) {
|
||||
if (!py::isinstance<py::int_>(obj)) {
|
||||
if (!nb::isinstance<nb::int_>(obj)) {
|
||||
throw std::invalid_argument("Slice indices must be integers or None.");
|
||||
}
|
||||
return py::cast<int>(py::cast<py::int_>(obj));
|
||||
return nb::cast<int>(nb::cast<nb::int_>(obj));
|
||||
}
|
||||
return default_val;
|
||||
}
|
||||
@ -28,7 +28,7 @@ void get_slice_params(
|
||||
int& starts,
|
||||
int& ends,
|
||||
int& strides,
|
||||
const py::slice& in_slice,
|
||||
const nb::slice& in_slice,
|
||||
int axis_size) {
|
||||
// Following numpy's convention
|
||||
// Assume n is the number of elements in the dimension being sliced.
|
||||
@ -36,26 +36,26 @@ void get_slice_params(
|
||||
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
|
||||
// k < 0 . If k is not given it defaults to 1
|
||||
|
||||
strides = get_slice_int(py::getattr(in_slice, "step"), 1);
|
||||
strides = get_slice_int(nb::getattr(in_slice, "step"), 1);
|
||||
starts = get_slice_int(
|
||||
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||
nb::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||
ends = get_slice_int(
|
||||
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
}
|
||||
|
||||
array get_int_index(py::object idx, int axis_size) {
|
||||
int idx_ = py::cast<int>(idx);
|
||||
array get_int_index(nb::object idx, int axis_size) {
|
||||
int idx_ = nb::cast<int>(idx);
|
||||
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
|
||||
|
||||
return array(idx_, uint32);
|
||||
}
|
||||
|
||||
bool is_valid_index_type(const py::object& obj) {
|
||||
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) ||
|
||||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj);
|
||||
bool is_valid_index_type(const nb::object& obj) {
|
||||
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
||||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
|
||||
}
|
||||
|
||||
array mlx_get_item_slice(const array& src, const py::slice& in_slice) {
|
||||
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@ -92,7 +92,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
|
||||
return take(src, indices, 0);
|
||||
}
|
||||
|
||||
array mlx_get_item_int(const array& src, const py::int_& idx) {
|
||||
array mlx_get_item_int(const array& src, const nb::int_& idx) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@ -106,7 +106,7 @@ array mlx_get_item_int(const array& src, const py::int_& idx) {
|
||||
|
||||
array mlx_gather_nd(
|
||||
array src,
|
||||
const std::vector<py::object>& indices,
|
||||
const std::vector<nb::object>& indices,
|
||||
bool gather_first,
|
||||
int& max_dims) {
|
||||
max_dims = 0;
|
||||
@ -117,9 +117,10 @@ array mlx_gather_nd(
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
|
||||
if (py::isinstance<py::slice>(idx)) {
|
||||
if (nb::isinstance<nb::slice>(idx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, idx, src.shape(i));
|
||||
get_slice_params(
|
||||
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
|
||||
|
||||
// Handle negative indices
|
||||
start = (start < 0) ? start + src.shape(i) : start;
|
||||
@ -128,10 +129,10 @@ array mlx_gather_nd(
|
||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
||||
num_slices++;
|
||||
is_slice[i] = true;
|
||||
} else if (py::isinstance<py::int_>(idx)) {
|
||||
} else if (nb::isinstance<nb::int_>(idx)) {
|
||||
gather_indices.push_back(get_int_index(idx, src.shape(i)));
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
auto arr = py::cast<array>(idx);
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
auto arr = nb::cast<array>(idx);
|
||||
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
|
||||
gather_indices.push_back(arr);
|
||||
}
|
||||
@ -185,7 +186,7 @@ array mlx_gather_nd(
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
// No indices make this a noop
|
||||
if (entries.size() == 0) {
|
||||
return src;
|
||||
@ -197,11 +198,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
// 3. Calculate the remaining slices and reshapes
|
||||
|
||||
// Ellipsis handling
|
||||
std::vector<py::object> indices;
|
||||
std::vector<nb::object> indices;
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<py::object> r_indices;
|
||||
std::vector<nb::object> r_indices;
|
||||
int i = 0;
|
||||
for (; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
@ -209,7 +210,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (!py::ellipsis().is(idx)) {
|
||||
if (!nb::ellipsis().is(idx)) {
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
@ -222,7 +223,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (py::ellipsis().is(idx)) {
|
||||
if (nb::ellipsis().is(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
}
|
||||
@ -232,7 +233,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(py::slice(0, src.shape(axis), 1));
|
||||
indices.push_back(nb::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
|
||||
}
|
||||
@ -256,7 +257,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
//
|
||||
// Check whether we have arrays or integer indices and delegate to gather_nd
|
||||
// after removing the slices at the end and all Nones.
|
||||
std::vector<py::object> remaining_indices;
|
||||
std::vector<nb::object> remaining_indices;
|
||||
bool have_array = false;
|
||||
{
|
||||
// First check whether the results of gather are going to be 1st or
|
||||
@ -264,7 +265,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
bool have_non_array = false;
|
||||
bool gather_first = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
if (have_array && have_non_array) {
|
||||
gather_first = true;
|
||||
break;
|
||||
@ -280,12 +281,12 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
// Then find the last array
|
||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
||||
auto& idx = indices[last_array];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<py::object> gather_indices;
|
||||
std::vector<nb::object> gather_indices;
|
||||
for (int i = 0; i <= last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (!idx.is_none()) {
|
||||
@ -299,15 +300,15 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
if (gather_first) {
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
for (int i = 0; i < last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (idx.is_none()) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
} else if (py::isinstance<py::slice>(idx)) {
|
||||
} else if (nb::isinstance<nb::slice>(idx)) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
@ -316,18 +317,18 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
} else {
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
} else if (idx.is_none()) {
|
||||
remaining_indices.push_back(idx);
|
||||
} else {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
@ -351,7 +352,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (!idx.is_none()) {
|
||||
get_slice_params(
|
||||
starts[axis], ends[axis], strides[axis], idx, ends[axis]);
|
||||
starts[axis],
|
||||
ends[axis],
|
||||
strides[axis],
|
||||
nb::cast<nb::slice>(idx),
|
||||
ends[axis]);
|
||||
axis++;
|
||||
}
|
||||
}
|
||||
@ -375,15 +380,17 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj) {
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_get_item_slice(src, obj);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_get_item_array(src, py::cast<array>(obj));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_get_item_int(src, obj);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_get_item_nd(src, obj);
|
||||
array mlx_get_item(const array& src, const nb::object& obj) {
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
return mlx_get_item_array(src, nb::cast<array>(obj));
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
|
||||
} else if (nb::isinstance<nb::ellipsis>(obj)) {
|
||||
return src;
|
||||
} else if (obj.is_none()) {
|
||||
std::vector<int> s(1, 1);
|
||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||
@ -394,7 +401,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
||||
const array& src,
|
||||
const py::int_& idx,
|
||||
const nb::int_& idx,
|
||||
const array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@ -446,7 +453,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
const array& src,
|
||||
const py::slice& in_slice,
|
||||
const nb::slice& in_slice,
|
||||
const array& update) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
@ -478,9 +485,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
const array& src,
|
||||
const py::tuple& entries,
|
||||
const nb::tuple& entries,
|
||||
const array& update) {
|
||||
std::vector<py::object> indices;
|
||||
std::vector<nb::object> indices;
|
||||
int non_none_indices = 0;
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
@ -494,7 +501,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
} else if (!py::ellipsis().is(idx)) {
|
||||
} else if (!nb::ellipsis().is(idx)) {
|
||||
if (!has_ellipsis) {
|
||||
indices_before++;
|
||||
non_none_indices_before += !idx.is_none();
|
||||
@ -514,7 +521,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.insert(
|
||||
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1));
|
||||
indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
non_none_indices = src.ndim();
|
||||
} else {
|
||||
@ -549,15 +556,15 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
bool have_array = false;
|
||||
bool have_non_array = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<py::slice>(idx) || idx.is_none()) {
|
||||
if (nb::isinstance<nb::slice>(idx) || idx.is_none()) {
|
||||
have_non_array = have_array;
|
||||
num_slices++;
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
if (have_array && have_non_array) {
|
||||
arrays_first = true;
|
||||
}
|
||||
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim);
|
||||
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
|
||||
num_arrays++;
|
||||
}
|
||||
}
|
||||
@ -569,10 +576,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
int ax = 0;
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (py::isinstance<py::slice>(pyidx)) {
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
int start, end, stride;
|
||||
auto axis_size = src.shape(ax++);
|
||||
get_slice_params(start, end, stride, pyidx, axis_size);
|
||||
get_slice_params(
|
||||
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
|
||||
|
||||
// Handle negative indices
|
||||
start = (start < 0) ? start + axis_size : start;
|
||||
@ -584,13 +592,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
} else if (py::isinstance<py::int_>(pyidx)) {
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
|
||||
} else if (pyidx.is_none()) {
|
||||
slice_num++;
|
||||
} else if (py::isinstance<array>(pyidx)) {
|
||||
} else if (nb::isinstance<array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = py::cast<array>(pyidx);
|
||||
auto idx = nb::cast<array>(pyidx);
|
||||
std::vector<int> idx_shape;
|
||||
if (!arrays_first) {
|
||||
idx_shape.insert(idx_shape.end(), slice_num, 1);
|
||||
@ -629,24 +637,24 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
std::tuple<std::vector<array>, array, std::vector<int>>
|
||||
mlx_compute_scatter_args(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto vals = to_array(v, src.dtype());
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_scatter_args_slice(src, obj, vals);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_scatter_args_int(src, obj, vals);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_scatter_args_nd(src, obj, vals);
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
|
||||
} else if (obj.is_none()) {
|
||||
return {{}, broadcast_to(vals, src.shape()), {}};
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
auto out = scatter(src, indices, updates, axes);
|
||||
@ -658,7 +666,7 @@ void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
||||
|
||||
array mlx_add_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@ -670,7 +678,7 @@ array mlx_add_item(
|
||||
|
||||
array mlx_subtract_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@ -682,7 +690,7 @@ array mlx_subtract_item(
|
||||
|
||||
array mlx_multiply_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@ -694,7 +702,7 @@ array mlx_multiply_item(
|
||||
|
||||
array mlx_divide_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@ -706,7 +714,7 @@ array mlx_divide_item(
|
||||
|
||||
array mlx_maximum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@ -718,7 +726,7 @@ array mlx_maximum_item(
|
||||
|
||||
array mlx_minimum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
|
@ -1,38 +1,38 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj);
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
|
||||
array mlx_get_item(const array& src, const nb::object& obj);
|
||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v);
|
||||
array mlx_add_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_subtract_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_multiply_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_divide_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_maximum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_minimum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
|
@ -1,32 +1,29 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/linalg.h"
|
||||
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
namespace {
|
||||
py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
const auto result = svd(a, s);
|
||||
return py::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void init_linalg(py::module_& parent_module) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
void init_linalg(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"linalg", "mlx.core.linalg: linear algebra routines.");
|
||||
|
||||
@ -59,16 +56,15 @@ void init_linalg(py::module_& parent_module) {
|
||||
return norm(a, ord, axis, keepdims, stream);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"ord"_a = none,
|
||||
"axis"_a = none,
|
||||
nb::arg(),
|
||||
"ord"_a = nb::none(),
|
||||
"axis"_a = nb::none(),
|
||||
"keepdims"_a = false,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Matrix or vector norm.
|
||||
|
||||
This function computes vector or matrix norms depending on the value of
|
||||
@ -188,11 +184,11 @@ void init_linalg(py::module_& parent_module) {
|
||||
"qr",
|
||||
&qr,
|
||||
"a"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
|
||||
R"pbdoc(
|
||||
qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
|
||||
|
||||
The QR factorization of the input matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. The matrices
|
||||
@ -221,11 +217,11 @@ void init_linalg(py::module_& parent_module) {
|
||||
"svd",
|
||||
&svd_helper,
|
||||
"a"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
|
||||
R"pbdoc(
|
||||
svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)
|
||||
|
||||
The Singular Value Decomposition (SVD) of the input matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. When the input
|
||||
@ -245,11 +241,11 @@ void init_linalg(py::module_& parent_module) {
|
||||
"inv",
|
||||
&inv,
|
||||
"a"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
inv(a: array, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Compute the inverse of a square matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. When the input
|
||||
|
@ -1,8 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/stl/vector.h>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
@ -16,39 +14,39 @@
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool is_istream_object(const py::object& file) {
|
||||
return py::hasattr(file, "readinto") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
bool is_istream_object(const nb::object& file) {
|
||||
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
|
||||
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_ostream_object(const py::object& file) {
|
||||
return py::hasattr(file, "write") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
bool is_ostream_object(const nb::object& file) {
|
||||
return nb::hasattr(file, "write") && nb::hasattr(file, "seek") &&
|
||||
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_zip_file(const py::module_& zipfile, const py::object& file) {
|
||||
bool is_zip_file(const nb::module_& zipfile, const nb::object& file) {
|
||||
if (is_istream_object(file)) {
|
||||
auto st_pos = file.attr("tell")();
|
||||
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>();
|
||||
bool r = nb::cast<bool>(zipfile.attr("is_zipfile")(file));
|
||||
file.attr("seek")(st_pos, 0);
|
||||
return r;
|
||||
}
|
||||
return zipfile.attr("is_zipfile")(file).cast<bool>();
|
||||
return nb::cast<bool>(zipfile.attr("is_zipfile")(file));
|
||||
}
|
||||
|
||||
class ZipFileWrapper {
|
||||
public:
|
||||
ZipFileWrapper(
|
||||
const py::module_& zipfile,
|
||||
const py::object& file,
|
||||
const nb::module_& zipfile,
|
||||
const nb::object& file,
|
||||
char mode = 'r',
|
||||
int compression = 0)
|
||||
: zipfile_module_(zipfile),
|
||||
@ -63,10 +61,10 @@ class ZipFileWrapper {
|
||||
close_func_(zipfile_object_.attr("close")) {}
|
||||
|
||||
std::vector<std::string> namelist() const {
|
||||
return files_list_.cast<std::vector<std::string>>();
|
||||
return nb::cast<std::vector<std::string>>(files_list_);
|
||||
}
|
||||
|
||||
py::object open(const std::string& key, char mode = 'r') {
|
||||
nb::object open(const std::string& key, char mode = 'r') {
|
||||
// Following numpy :
|
||||
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
|
||||
if (mode == 'w') {
|
||||
@ -76,12 +74,12 @@ class ZipFileWrapper {
|
||||
}
|
||||
|
||||
private:
|
||||
py::module_ zipfile_module_;
|
||||
py::object zipfile_object_;
|
||||
py::list files_list_;
|
||||
py::object open_func_;
|
||||
py::object read_func_;
|
||||
py::object close_func_;
|
||||
nb::module_ zipfile_module_;
|
||||
nb::object zipfile_object_;
|
||||
nb::list files_list_;
|
||||
nb::object open_func_;
|
||||
nb::object read_func_;
|
||||
nb::object close_func_;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -90,14 +88,14 @@ class ZipFileWrapper {
|
||||
|
||||
class PyFileReader : public io::Reader {
|
||||
public:
|
||||
PyFileReader(py::object file)
|
||||
PyFileReader(nb::object file)
|
||||
: pyistream_(file),
|
||||
readinto_func_(file.attr("readinto")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
~PyFileReader() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
pyistream_.release().dec_ref();
|
||||
readinto_func_.release().dec_ref();
|
||||
@ -108,8 +106,8 @@ class PyFileReader : public io::Reader {
|
||||
bool is_open() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = !pyistream_.attr("closed").cast<bool>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !nb::cast<bool>(pyistream_.attr("closed"));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -117,7 +115,7 @@ class PyFileReader : public io::Reader {
|
||||
bool good() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !pyistream_.is_none();
|
||||
}
|
||||
return out;
|
||||
@ -126,25 +124,24 @@ class PyFileReader : public io::Reader {
|
||||
size_t tell() const override {
|
||||
size_t out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = tell_func_().cast<size_t>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = nb::cast<size_t>(tell_func_());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
void read(char* data, size_t n) override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
|
||||
nb::object bytes_read = readinto_func_(nb::handle(memview));
|
||||
|
||||
py::object bytes_read =
|
||||
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
|
||||
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
|
||||
if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {
|
||||
throw std::runtime_error("[load] Failed to read from python stream");
|
||||
}
|
||||
}
|
||||
@ -154,23 +151,23 @@ class PyFileReader : public io::Reader {
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyistream_;
|
||||
py::object readinto_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
nb::object pyistream_;
|
||||
nb::object readinto_func_;
|
||||
nb::object seek_func_;
|
||||
nb::object tell_func_;
|
||||
};
|
||||
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(py::cast<std::string>(file), s);
|
||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(nb::cast<std::string>(file), s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : std::get<0>(res)) {
|
||||
arr.eval();
|
||||
}
|
||||
@ -182,20 +179,20 @@ mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
|
||||
"[load_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(py::cast<std::string>(file), s);
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(nb::cast<std::string>(file), s);
|
||||
}
|
||||
|
||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
py::object file,
|
||||
nb::object file,
|
||||
StreamOrDevice s) {
|
||||
bool own_file = py::isinstance<py::str>(file);
|
||||
bool own_file = nb::isinstance<nb::str>(file);
|
||||
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
nb::module_ zipfile = nb::module_::import_("zipfile");
|
||||
if (!is_zip_file(zipfile, file)) {
|
||||
throw std::invalid_argument(
|
||||
"[load_npz] Input must be a zip file or a file-like object that can be "
|
||||
@ -208,7 +205,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
ZipFileWrapper zipfile_object(zipfile, file);
|
||||
for (const std::string& st : zipfile_object.namelist()) {
|
||||
// Open zip file as a python file stream
|
||||
py::object sub_file = zipfile_object.open(st);
|
||||
nb::object sub_file = zipfile_object.open(st);
|
||||
|
||||
// Create array from python file stream
|
||||
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
|
||||
@ -224,7 +221,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
if (!own_file) {
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : array_dict) {
|
||||
arr.eval();
|
||||
}
|
||||
@ -233,14 +230,14 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
return array_dict;
|
||||
}
|
||||
|
||||
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
||||
return load(py::cast<std::string>(file), s);
|
||||
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
|
||||
return load(nb::cast<std::string>(file), s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
arr.eval();
|
||||
}
|
||||
return arr;
|
||||
@ -250,16 +247,16 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||
}
|
||||
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
nb::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s) {
|
||||
if (!format.has_value()) {
|
||||
std::string fname;
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
fname = py::cast<std::string>(file);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
fname = nb::cast<std::string>(file);
|
||||
} else if (is_istream_object(file)) {
|
||||
fname = file.attr("name").cast<std::string>();
|
||||
fname = nb::cast<std::string>(file.attr("name"));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[load] Input must be a file-like object opened in binary mode, or string");
|
||||
@ -304,14 +301,14 @@ LoadOutputTypes mlx_load_helper(
|
||||
|
||||
class PyFileWriter : public io::Writer {
|
||||
public:
|
||||
PyFileWriter(py::object file)
|
||||
PyFileWriter(nb::object file)
|
||||
: pyostream_(file),
|
||||
write_func_(file.attr("write")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
~PyFileWriter() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
pyostream_.release().dec_ref();
|
||||
write_func_.release().dec_ref();
|
||||
@ -322,8 +319,8 @@ class PyFileWriter : public io::Writer {
|
||||
bool is_open() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = !pyostream_.attr("closed").cast<bool>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !nb::cast<bool>(pyostream_.attr("closed"));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -331,7 +328,7 @@ class PyFileWriter : public io::Writer {
|
||||
bool good() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !pyostream_.is_none();
|
||||
}
|
||||
return out;
|
||||
@ -340,25 +337,26 @@ class PyFileWriter : public io::Writer {
|
||||
size_t tell() const override {
|
||||
size_t out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = tell_func_().cast<size_t>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = nb::cast<size_t>(tell_func_());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
void write(const char* data, size_t n) override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
py::object bytes_written =
|
||||
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
auto memview =
|
||||
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ);
|
||||
nb::object bytes_written = write_func_(nb::handle(memview));
|
||||
|
||||
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
|
||||
if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {
|
||||
throw std::runtime_error("[load] Failed to write to python stream");
|
||||
}
|
||||
}
|
||||
@ -368,20 +366,20 @@ class PyFileWriter : public io::Writer {
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyostream_;
|
||||
py::object write_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
nb::object pyostream_;
|
||||
nb::object write_func_;
|
||||
nb::object seek_func_;
|
||||
nb::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(py::object file, array a) {
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a);
|
||||
void mlx_save_helper(nb::object file, array a) {
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
save(nb::cast<std::string>(file), a);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
save(writer, a);
|
||||
}
|
||||
|
||||
@ -393,26 +391,26 @@ void mlx_save_helper(py::object file, array a) {
|
||||
}
|
||||
|
||||
void mlx_savez_helper(
|
||||
py::object file_,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
nb::object file_,
|
||||
nb::args args,
|
||||
const nb::kwargs& kwargs,
|
||||
bool compressed) {
|
||||
// Add .npz to the end of the filename if not already there
|
||||
py::object file = file_;
|
||||
nb::object file = file_;
|
||||
|
||||
if (py::isinstance<py::str>(file_)) {
|
||||
std::string fname = file_.cast<std::string>();
|
||||
if (nb::isinstance<nb::str>(file_)) {
|
||||
std::string fname = nb::cast<std::string>(file_);
|
||||
|
||||
// Add .npz to file name if it is not there
|
||||
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
|
||||
fname += ".npz";
|
||||
|
||||
file = py::str(fname);
|
||||
file = nb::cast(fname);
|
||||
}
|
||||
|
||||
// Collect args and kwargs
|
||||
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>();
|
||||
auto arrays_list = args.cast<std::vector<array>>();
|
||||
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
|
||||
auto arrays_list = nb::cast<std::vector<array>>(args);
|
||||
|
||||
for (int i = 0; i < arrays_list.size(); i++) {
|
||||
std::string arr_name = "arr_" + std::to_string(i);
|
||||
@ -426,9 +424,9 @@ void mlx_savez_helper(
|
||||
}
|
||||
|
||||
// Create python ZipFile object depending on compression
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>()
|
||||
: zipfile.attr("ZIP_STORED").cast<int>();
|
||||
nb::module_ zipfile = nb::module_::import_("zipfile");
|
||||
int compression = nb::cast<int>(
|
||||
compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED"));
|
||||
char mode = 'w';
|
||||
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
|
||||
|
||||
@ -438,7 +436,7 @@ void mlx_savez_helper(
|
||||
auto py_ostream = zipfile_object.open(fname, 'w');
|
||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
nb::gil_scoped_release nogil;
|
||||
save(writer, a);
|
||||
}
|
||||
}
|
||||
@ -447,31 +445,31 @@ void mlx_savez_helper(
|
||||
}
|
||||
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<py::dict> m) {
|
||||
nb::object file,
|
||||
nb::dict d,
|
||||
std::optional<nb::dict> m) {
|
||||
std::unordered_map<std::string, std::string> metadata_map;
|
||||
if (m) {
|
||||
try {
|
||||
metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, std::string>>();
|
||||
} catch (const py::cast_error& e) {
|
||||
nb::cast<std::unordered_map<std::string, std::string>>(m.value());
|
||||
} catch (const nb::cast_error& e) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] Metadata must be a dictionary with string keys and values");
|
||||
}
|
||||
} else {
|
||||
metadata_map = std::unordered_map<std::string, std::string>();
|
||||
}
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
nb::gil_scoped_release nogil;
|
||||
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
}
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
nb::gil_scoped_release nogil;
|
||||
save_safetensors(writer, arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
@ -481,22 +479,22 @@ void mlx_save_safetensor_helper(
|
||||
}
|
||||
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
py::dict a,
|
||||
std::optional<py::dict> m) {
|
||||
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
nb::object file,
|
||||
nb::dict a,
|
||||
std::optional<nb::dict> m) {
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
|
||||
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
nb::gil_scoped_release nogil;
|
||||
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
nb::gil_scoped_release nogil;
|
||||
save_gguf(nb::cast<std::string>(file), arrays_map);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -1,15 +1,20 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
#include "mlx/io.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
using LoadOutputTypes = std::variant<
|
||||
@ -18,27 +23,27 @@ using LoadOutputTypes = std::variant<
|
||||
SafetensorsLoad,
|
||||
GGUFLoad>;
|
||||
|
||||
SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s);
|
||||
SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<py::dict> m);
|
||||
nb::object file,
|
||||
nb::dict d,
|
||||
std::optional<nb::dict> m);
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s);
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s);
|
||||
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<py::dict> m);
|
||||
nb::object file,
|
||||
nb::dict d,
|
||||
std::optional<nb::dict> m);
|
||||
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
nb::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_helper(py::object file, array a);
|
||||
void mlx_save_helper(nb::object file, array a);
|
||||
void mlx_savez_helper(
|
||||
py::object file,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
nb::object file,
|
||||
nb::args args,
|
||||
const nb::kwargs& kwargs,
|
||||
bool compressed = false);
|
||||
|
@ -1,16 +1,15 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_metal(py::module_& m) {
|
||||
py::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||
void init_metal(nb::module_& m) {
|
||||
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||
metal.def(
|
||||
"is_available",
|
||||
&metal::is_available,
|
||||
@ -48,7 +47,7 @@ void init_metal(py::module_& m) {
|
||||
"set_memory_limit",
|
||||
&metal::set_memory_limit,
|
||||
"limit"_a,
|
||||
py::kw_only(),
|
||||
nb::kw_only(),
|
||||
"relaxed"_a = true,
|
||||
R"pbdoc(
|
||||
Set the memory limit.
|
||||
|
@ -1,30 +1,30 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#define STRINGIFY(x) #x
|
||||
#define TOSTRING(x) STRINGIFY(x)
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
|
||||
void init_array(py::module_&);
|
||||
void init_device(py::module_&);
|
||||
void init_stream(py::module_&);
|
||||
void init_metal(py::module_&);
|
||||
void init_ops(py::module_&);
|
||||
void init_transforms(py::module_&);
|
||||
void init_random(py::module_&);
|
||||
void init_fft(py::module_&);
|
||||
void init_linalg(py::module_&);
|
||||
void init_constants(py::module_&);
|
||||
void init_extensions(py::module_&);
|
||||
void init_utils(py::module_&);
|
||||
void init_array(nb::module_&);
|
||||
void init_device(nb::module_&);
|
||||
void init_stream(nb::module_&);
|
||||
void init_metal(nb::module_&);
|
||||
void init_ops(nb::module_&);
|
||||
void init_transforms(nb::module_&);
|
||||
void init_random(nb::module_&);
|
||||
void init_fft(nb::module_&);
|
||||
void init_linalg(nb::module_&);
|
||||
void init_constants(nb::module_&);
|
||||
void init_fast(nb::module_&);
|
||||
|
||||
PYBIND11_MODULE(core, m) {
|
||||
NB_MODULE(core, m) {
|
||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||
|
||||
auto reprlib_fix = py::module_::import("mlx._reprlib_fix");
|
||||
py::module_::import("mlx._os_warning");
|
||||
auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix");
|
||||
nb::module_::import_("mlx._os_warning");
|
||||
nb::set_leak_warnings(false);
|
||||
|
||||
init_device(m);
|
||||
init_stream(m);
|
||||
@ -36,8 +36,7 @@ PYBIND11_MODULE(core, m) {
|
||||
init_fft(m);
|
||||
init_linalg(m);
|
||||
init_constants(m);
|
||||
init_extensions(m);
|
||||
init_utils(m);
|
||||
init_fast(m);
|
||||
|
||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||
}
|
||||
|
python/src/ops.cpp (1796 lines): file diff suppressed because it is too large.
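The suppressed ops.cpp changes follow the same mechanical translation seen in the files shown in full: py:: spellings become nb::, and pybind11's member-style obj.cast<T>() becomes the free function nb::cast<T>(obj). A small self-contained sketch of that cast direction, using a made-up function rather than anything from ops.cpp:

#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h>

namespace nb = nanobind;
using namespace nb::literals;

// Sum a Python sequence of ints; nb::cast<T>(obj) replaces obj.cast<T>() from pybind11.
int sum_list(const nb::object& seq) {
  auto values = nb::cast<std::vector<int>>(seq);
  int total = 0;
  for (int v : values)
    total += v;
  return total;
}

NB_MODULE(ops_example, m) {
  m.def("sum_list", &sum_list, "seq"_a);
}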
@ -1,60 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
// A patch to get float16_t to work with pybind11 numpy arrays
|
||||
// Derived from:
|
||||
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679
|
||||
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
namespace pybind11::detail {
|
||||
|
||||
template <typename T>
|
||||
struct npy_scalar_caster {
|
||||
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
|
||||
using Array = array_t<T>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
|
||||
handle type = dtype::of<T>().attr("type"); // Could make more efficient.
|
||||
if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
|
||||
return false;
|
||||
Array tmp = Array::ensure(src);
|
||||
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
|
||||
this->value = *tmp.data();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static handle cast(T src, return_value_policy, handle) {
|
||||
Array tmp({1});
|
||||
tmp.mutable_at(0) = src;
|
||||
tmp.resize({});
|
||||
// You could also just return the array if you want a scalar array.
|
||||
object scalar = tmp[tuple()];
|
||||
return scalar.release();
|
||||
}
|
||||
};
|
||||
|
||||
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
|
||||
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
|
||||
constexpr int NPY_FLOAT16 = 23;
|
||||
|
||||
// Kinda following:
|
||||
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
|
||||
template <>
|
||||
struct npy_format_descriptor<float16_t> {
|
||||
static constexpr auto name = _("float16");
|
||||
static pybind11::dtype dtype() {
|
||||
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
|
||||
return reinterpret_borrow<pybind11::dtype>(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
|
||||
static constexpr auto name = _("float16");
|
||||
};
|
||||
|
||||
} // namespace pybind11::detail
|
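The random bindings that follow keep a process-wide default key in a static nb::object and register an atexit hook so it is released before the interpreter shuts down. A hedged, standalone sketch of that cleanup pattern (the module and names below are illustrative only):

#include <nanobind/nanobind.h>

namespace nb = nanobind;

// A process-wide Python object owned from C++ (stand-in for the default key).
nb::object& default_state() {
  static nb::object state = nb::dict();
  return state;
}

NB_MODULE(cleanup_example, m) {
  m.def("get_state", []() { return default_state(); });

  // Drop the reference before the interpreter exits so the static nb::object
  // never outlives Python itself.
  auto atexit = nb::module_::import_("atexit");
  atexit.attr("register")(
      nb::cpp_function([]() { default_state().release().dec_ref(); }));
}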
@ -1,7 +1,10 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <chrono>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
@ -9,8 +12,8 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
@ -25,22 +28,22 @@ class PyKeySequence {
|
||||
}
|
||||
|
||||
array next() {
|
||||
auto out = split(py::cast<array>(state_[0]));
|
||||
auto out = split(nb::cast<array>(state_[0]));
|
||||
state_[0] = out.first;
|
||||
return out.second;
|
||||
}
|
||||
|
||||
py::list state() {
|
||||
nb::list state() {
|
||||
return state_;
|
||||
}
|
||||
|
||||
void release() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
state_.release().dec_ref();
|
||||
}
|
||||
|
||||
private:
|
||||
py::list state_;
|
||||
nb::list state_;
|
||||
};
|
||||
|
||||
PyKeySequence& default_key() {
|
||||
@ -54,7 +57,7 @@ PyKeySequence& default_key() {
|
||||
return ks;
|
||||
}
|
||||
|
||||
void init_random(py::module_& parent_module) {
|
||||
void init_random(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"random",
|
||||
"mlx.core.random: functionality related to random number generation");
|
||||
@ -85,10 +88,10 @@ void init_random(py::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"split",
|
||||
py::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
nb::overload_cast<const array&, int, StreamOrDevice>(&random::split),
|
||||
"key"_a,
|
||||
"num"_a = 2,
|
||||
"stream"_a = none,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Split a PRNG key into sub keys.
|
||||
|
||||
@ -119,9 +122,9 @@ void init_random(py::module_& parent_module) {
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"dtype"_a.none() = float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate uniformly distributed random numbers.
|
||||
|
||||
@ -151,11 +154,11 @@ void init_random(py::module_& parent_module) {
|
||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"dtype"_a.none() = float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
@ -184,9 +187,9 @@ void init_random(py::module_& parent_module) {
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = int32,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"dtype"_a.none() = int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate random integers from the given interval.
|
||||
|
||||
@ -219,9 +222,9 @@ void init_random(py::module_& parent_module) {
|
||||
}
|
||||
},
|
||||
"p"_a = 0.5,
|
||||
"shape"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"shape"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate Bernoulli random values.
|
||||
|
||||
@ -259,10 +262,10 @@ void init_random(py::module_& parent_module) {
|
||||
},
|
||||
"lower"_a,
|
||||
"upper"_a,
|
||||
"shape"_a = none,
|
||||
"dtype"_a = std::optional{float32},
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"shape"_a = nb::none(),
|
||||
"dtype"_a.none() = float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Generate values from a truncated normal distribution.
|
||||
|
||||
@ -292,9 +295,9 @@ void init_random(py::module_& parent_module) {
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"stream"_a = none,
|
||||
"key"_a = none,
|
||||
"dtype"_a.none() = float32,
|
||||
"stream"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Sample from the standard Gumbel distribution.
|
||||
|
||||
@ -331,10 +334,10 @@ void init_random(py::module_& parent_module) {
|
||||
},
|
||||
"logits"_a,
|
||||
"axis"_a = -1,
|
||||
"shape"_a = none,
|
||||
"num_samples"_a = none,
|
||||
"key"_a = none,
|
||||
"stream"_a = none,
|
||||
"shape"_a = nb::none(),
|
||||
"num_samples"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Sample from a categorical distribution.
|
||||
|
||||
@ -359,6 +362,6 @@ void init_random(py::module_& parent_module) {
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
)pbdoc");
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
|
||||
auto atexit = nb::module_::import_("atexit");
|
||||
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
|
||||
}
|
||||
|
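Several of the random bindings above annotate an argument with .none() while still giving it a concrete default, e.g. "dtype"_a.none() = float32, so callers may pass None explicitly. A minimal sketch of that annotation with a made-up function:

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>

namespace nb = nanobind;
using namespace nb::literals;

// `level` defaults to 3 but may also be passed as None, which maps to nullopt.
int resolution(std::optional<int> level) {
  return level.value_or(0);
}

NB_MODULE(none_example, m) {
  m.def("resolution", &resolution, "level"_a.none() = 3);
}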
@ -1,25 +1,54 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "mlx/stream.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_stream(py::module_& m) {
|
||||
py::class_<Stream>(
|
||||
// Create the StreamContext on enter and delete on exit.
|
||||
class PyStreamContext {
|
||||
public:
|
||||
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
|
||||
if (std::holds_alternative<std::monostate>(s)) {
|
||||
throw std::runtime_error(
|
||||
"[StreamContext] Invalid argument, please specify a stream or device.");
|
||||
}
|
||||
_s = s;
|
||||
}
|
||||
|
||||
void enter() {
|
||||
_inner = new StreamContext(_s);
|
||||
}
|
||||
|
||||
void exit() {
|
||||
if (_inner != nullptr) {
|
||||
delete _inner;
|
||||
_inner = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
StreamOrDevice _s;
|
||||
StreamContext* _inner;
|
||||
};
|
||||
|
||||
void init_stream(nb::module_& m) {
|
||||
nb::class_<Stream>(
|
||||
m,
|
||||
"Stream",
|
||||
R"pbdoc(
|
||||
A stream for running operations on a given device.
|
||||
)pbdoc")
|
||||
.def(py::init<int, Device>(), "index"_a, "device"_a)
|
||||
.def_readonly("device", &Stream::device)
|
||||
.def(nb::init<int, Device>(), "index"_a, "device"_a)
|
||||
.def_ro("device", &Stream::device)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const Stream& s) {
|
||||
@ -31,7 +60,7 @@ void init_stream(py::module_& m) {
|
||||
return s1 == s2;
|
||||
});
|
||||
|
||||
py::implicitly_convertible<Device::DeviceType, Device>();
|
||||
nb::implicitly_convertible<Device::DeviceType, Device>();
|
||||
|
||||
m.def(
|
||||
"default_stream",
|
||||
@ -56,4 +85,48 @@ void init_stream(py::module_& m) {
|
||||
&new_stream,
|
||||
"device"_a,
|
||||
R"pbdoc(Make a new stream on the given device.)pbdoc");
|
||||
|
||||
nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
|
||||
A context manager for setting the current device and stream.
|
||||
|
||||
See :func:`stream` for usage.
|
||||
|
||||
Args:
|
||||
s: The stream or device to set as the default.
|
||||
)pbdoc")
|
||||
.def(nb::init<StreamOrDevice>(), "s"_a)
|
||||
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
|
||||
.def(
|
||||
"__exit__",
|
||||
[](PyStreamContext& scm,
|
||||
const std::optional<nb::type_object>& exc_type,
|
||||
const std::optional<nb::object>& exc_value,
|
||||
const std::optional<nb::object>& traceback) { scm.exit(); },
|
||||
"exc_type"_a = nb::none(),
|
||||
"exc_value"_a = nb::none(),
|
||||
"traceback"_a = nb::none());
|
||||
m.def(
|
||||
"stream",
|
||||
[](StreamOrDevice s) { return PyStreamContext(s); },
|
||||
"s"_a,
|
||||
R"pbdoc(
|
||||
Create a context manager to set the default device and stream.
|
||||
|
||||
Args:
|
||||
s: The :obj:`Stream` or :obj:`Device` to set as the default.
|
||||
|
||||
Returns:
|
||||
A context manager that sets the default device and stream.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Create a context manager for the default device and stream.
|
||||
with mx.stream(mx.cpu):
|
||||
# Operations here will use mx.cpu by default.
|
||||
pass
|
||||
)pbdoc");
|
||||
}
|
||||
|
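The transforms bindings that follow replace py::function and py::args with nb::callable and nb::args. A short, self-contained sketch of receiving a Python callable plus arbitrary positional arguments in nanobind (the function name is made up):

#include <nanobind/nanobind.h>

namespace nb = nanobind;

NB_MODULE(callable_example, m) {
  // Call `fun` once per positional argument and collect the results.
  m.def("map_args", [](const nb::callable& fun, nb::args args) {
    nb::list results;
    for (nb::handle item : args)
      results.append(fun(item));
    return results;
  });
}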
@ -1,6 +1,11 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
@ -13,13 +18,17 @@
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
||||
|
||||
inline std::string type_name_str(const nb::handle& o) {
|
||||
return nb::cast<std::string>(nb::type_name(o.type()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
|
||||
std::vector<T> vals;
|
||||
@ -49,7 +58,7 @@ auto validate_argnums_argnames(
|
||||
}
|
||||
|
||||
auto py_value_and_grad(
|
||||
const py::function& fun,
|
||||
const nb::callable& fun,
|
||||
std::vector<int> argnums,
|
||||
std::vector<std::string> argnames,
|
||||
const std::string& error_msg_tag,
|
||||
@ -71,7 +80,7 @@ auto py_value_and_grad(
|
||||
}
|
||||
|
||||
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
||||
const py::args& args, const py::kwargs& kwargs) {
|
||||
const nb::args& args, const nb::kwargs& kwargs) {
|
||||
// Sanitize the input
|
||||
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
||||
std::ostringstream msg;
|
||||
@ -89,7 +98,7 @@ auto py_value_and_grad(
|
||||
<< "' because the function is called with the "
|
||||
<< "following keyword arguments {";
|
||||
for (auto item : kwargs) {
|
||||
msg << item.first.cast<std::string>() << ",";
|
||||
msg << nb::cast<std::string>(item.first) << ",";
|
||||
}
|
||||
msg << "}";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@ -115,7 +124,7 @@ auto py_value_and_grad(
|
||||
|
||||
// value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_value_out;
|
||||
nb::object py_value_out;
|
||||
auto value_and_grads = value_and_grad(
|
||||
[&fun,
|
||||
&args,
|
||||
@ -127,15 +136,15 @@ auto py_value_and_grad(
|
||||
&error_msg_tag,
|
||||
scalar_func_only](const std::vector<array>& a) {
|
||||
// Copy the arguments
|
||||
py::args args_cpy = py::tuple(args.size());
|
||||
py::kwargs kwargs_cpy = py::kwargs();
|
||||
nb::list args_cpy;
|
||||
nb::kwargs kwargs_cpy = nb::kwargs();
|
||||
int j = 0;
|
||||
for (int i = 0; i < args.size(); ++i) {
|
||||
if (j < argnums.size() && i == argnums[j]) {
|
||||
args_cpy[i] = tree_unflatten(args[i], a, counts[j]);
|
||||
args_cpy.append(tree_unflatten(args[i], a, counts[j]));
|
||||
j++;
|
||||
} else {
|
||||
args_cpy[i] = args[i];
|
||||
args_cpy.append(args[i]);
|
||||
}
|
||||
}
|
||||
for (auto& key : argnames) {
|
||||
@ -154,25 +163,25 @@ auto py_value_and_grad(
|
||||
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
||||
|
||||
// Validate the return value of the python function
|
||||
if (!py::isinstance<array>(py_value_out)) {
|
||||
if (!nb::isinstance<array>(py_value_out)) {
|
||||
if (scalar_func_only) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be a "
|
||||
<< "scalar array; but " << py_value_out.get_type()
|
||||
<< "scalar array; but " << type_name_str(py_value_out)
|
||||
<< " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!py::isinstance<py::tuple>(py_value_out)) {
|
||||
if (!nb::isinstance<nb::tuple>(py_value_out)) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
|
||||
<< py_value_out.get_type() << " was returned.";
|
||||
<< type_name_str(py_value_out) << " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
py::tuple ret = py::cast<py::tuple>(py_value_out);
|
||||
nb::tuple ret = nb::cast<nb::tuple>(py_value_out);
|
||||
if (ret.size() == 0) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
@ -182,14 +191,14 @@ auto py_value_and_grad(
|
||||
<< "we got an empty tuple.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (!py::isinstance<array>(ret[0])) {
|
||||
if (!nb::isinstance<array>(ret[0])) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
|
||||
<< "was a tuple with the first value being of type "
|
||||
<< ret[0].get_type() << " .";
|
||||
<< type_name_str(ret[0]) << " .";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
@ -212,61 +221,60 @@ auto py_value_and_grad(
|
||||
// In case 2 we return a tuple of the above.
|
||||
// In case 3 we return a tuple containing a tuple and dict (sth like
|
||||
// (tuple(), dict(x=mx.array(5))) ).
|
||||
py::object positional_grads;
|
||||
py::object keyword_grads;
|
||||
py::object py_grads;
|
||||
nb::object positional_grads;
|
||||
nb::object keyword_grads;
|
||||
nb::object py_grads;
|
||||
|
||||
// Collect the gradients for the positional arguments
|
||||
if (argnums.size() == 1) {
|
||||
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
|
||||
} else if (argnums.size() > 1) {
|
||||
py::tuple grads_(argnums.size());
|
||||
nb::list grads_;
|
||||
for (int i = 0; i < argnums.size(); i++) {
|
||||
grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]);
|
||||
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
|
||||
}
|
||||
positional_grads = py::cast<py::object>(grads_);
|
||||
positional_grads = nb::tuple(grads_);
|
||||
} else {
|
||||
positional_grads = py::none();
|
||||
positional_grads = nb::none();
|
||||
}
|
||||
|
||||
// No keyword argument gradients so return the tuple of gradients
|
||||
if (argnames.size() == 0) {
|
||||
py_grads = positional_grads;
|
||||
} else {
|
||||
py::dict grads_;
|
||||
nb::dict grads_;
|
||||
for (int i = 0; i < argnames.size(); i++) {
|
||||
auto& k = argnames[i];
|
||||
grads_[k.c_str()] = tree_unflatten(
|
||||
kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
|
||||
}
|
||||
keyword_grads = py::cast<py::object>(grads_);
|
||||
keyword_grads = grads_;
|
||||
|
||||
py_grads =
|
||||
py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
|
||||
py_grads = nb::make_tuple(positional_grads, keyword_grads);
|
||||
}
|
||||
|
||||
// Put the values back in the container
|
||||
py::object return_value = tree_unflatten(py_value_out, value);
|
||||
nb::object return_value = tree_unflatten(py_value_out, value);
|
||||
return std::make_pair(return_value, py_grads);
|
||||
};
|
||||
}
|
||||
|
||||
auto py_vmap(
|
||||
const py::function& fun,
|
||||
const py::object& in_axes,
|
||||
const py::object& out_axes) {
|
||||
return [fun, in_axes, out_axes](const py::args& args) {
|
||||
auto axes_to_flat_tree = [](const py::object& tree,
|
||||
const py::object& axes) {
|
||||
const nb::callable& fun,
|
||||
const nb::object& in_axes,
|
||||
const nb::object& out_axes) {
|
||||
return [fun, in_axes, out_axes](const nb::args& args) {
|
||||
auto axes_to_flat_tree = [](const nb::object& tree,
|
||||
const nb::object& axes) {
|
||||
auto tree_axes = tree_map(
|
||||
{tree, axes},
|
||||
[](const std::vector<py::object>& inputs) { return inputs[1]; });
|
||||
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
|
||||
std::vector<int> flat_axes;
|
||||
tree_visit(tree_axes, [&flat_axes](py::handle obj) {
|
||||
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
|
||||
if (obj.is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj)));
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
|
||||
} else {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
@ -280,7 +288,7 @@ auto py_vmap(
|
||||
|
||||
// py_value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_outputs;
|
||||
nb::object py_outputs;
|
||||
|
||||
auto vmap_fn =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||
@ -305,24 +313,24 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
std::unordered_map<size_t, nb::object>& tree_cache() {
|
||||
// This map is used to Cache the tree structure of the outputs
|
||||
static std::unordered_map<size_t, py::object> tree_cache_;
|
||||
static std::unordered_map<size_t, nb::object> tree_cache_;
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
nb::callable fun;
|
||||
size_t fun_id;
|
||||
py::object captured_inputs;
|
||||
py::object captured_outputs;
|
||||
nb::object captured_inputs;
|
||||
nb::object captured_outputs;
|
||||
bool shapeless;
|
||||
size_t num_outputs{0};
|
||||
mutable size_t num_outputs{0};
|
||||
|
||||
PyCompiledFun(
|
||||
const py::function& fun,
|
||||
py::object inputs,
|
||||
py::object outputs,
|
||||
const nb::callable& fun,
|
||||
nb::object inputs,
|
||||
nb::object outputs,
|
||||
bool shapeless)
|
||||
: fun(fun),
|
||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||
@ -342,7 +350,7 @@ struct PyCompiledFun {
|
||||
num_outputs = other.num_outputs;
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
// Flat array inputs
|
||||
std::vector<array> inputs;
|
||||
|
||||
@ -358,45 +366,45 @@ struct PyCompiledFun {
|
||||
constexpr uint64_t dict_identifier = 18446744073709551521UL;
|
||||
|
||||
// Flatten the tree with hashed constants and structure
|
||||
std::function<void(py::handle)> recurse;
|
||||
recurse = [&](py::handle obj) {
|
||||
if (py::isinstance<py::list>(obj)) {
|
||||
auto l = py::cast<py::list>(obj);
|
||||
std::function<void(nb::handle)> recurse;
|
||||
recurse = [&](nb::handle obj) {
|
||||
if (nb::isinstance<nb::list>(obj)) {
|
||||
auto l = nb::cast<nb::list>(obj);
|
||||
constants.push_back(list_identifier);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
recurse(l[i]);
|
||||
}
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
auto l = py::cast<py::tuple>(obj);
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
auto l = nb::cast<nb::tuple>(obj);
|
||||
constants.push_back(list_identifier);
|
||||
for (auto item : obj) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (py::isinstance<py::dict>(obj)) {
|
||||
auto d = py::cast<py::dict>(obj);
|
||||
} else if (nb::isinstance<nb::dict>(obj)) {
|
||||
auto d = nb::cast<nb::dict>(obj);
|
||||
constants.push_back(dict_identifier);
|
||||
for (auto item : d) {
|
||||
auto r = py::hash(item.first);
|
||||
auto r = item.first.attr("__hash__");
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
recurse(item.second);
|
||||
}
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
inputs.push_back(py::cast<array>(obj));
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
inputs.push_back(nb::cast<array>(obj));
|
||||
constants.push_back(array_identifier);
|
||||
} else if (py::isinstance<py::str>(obj)) {
|
||||
auto r = py::hash(obj);
|
||||
} else if (nb::isinstance<nb::str>(obj)) {
|
||||
auto r = obj.attr("__hash__");
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
auto r = obj.cast<int64_t>();
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
auto r = nb::cast<int64_t>(obj);
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
auto r = obj.cast<double>();
|
||||
} else if (nb::isinstance<nb::float_>(obj)) {
|
||||
auto r = nb::cast<double>(obj);
|
||||
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Function arguments must be trees of arrays "
|
||||
<< "or constants (floats, ints, or strings), but received "
|
||||
<< "type " << obj.get_type() << ".";
|
||||
<< "type " << type_name_str(obj) << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
};
|
||||
@ -404,13 +412,12 @@ struct PyCompiledFun {
|
||||
recurse(args);
|
||||
int num_args = inputs.size();
|
||||
recurse(kwargs);
|
||||
|
||||
auto compile_fun = [this, &args, &kwargs, num_args](
|
||||
const std::vector<array>& a) {
|
||||
// Put tracers into captured inputs
|
||||
std::vector<array> flat_in_captures;
|
||||
std::vector<array> trace_captures;
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
if (!captured_inputs.is_none()) {
|
||||
flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
trace_captures.insert(
|
||||
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
|
||||
@ -425,7 +432,7 @@ struct PyCompiledFun {
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
|
||||
num_outputs = outputs.size();
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
if (!captured_outputs.is_none()) {
|
||||
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
||||
outputs.insert(
|
||||
outputs.end(),
|
||||
@ -434,13 +441,13 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Replace tracers with originals in captured inputs
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
if (!captured_inputs.is_none()) {
|
||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||
}
|
||||
return outputs;
|
||||
};
|
||||
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
if (!captured_inputs.is_none()) {
|
||||
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
inputs.insert(
|
||||
inputs.end(),
|
||||
@ -451,7 +458,7 @@ struct PyCompiledFun {
|
||||
// Compile and call
|
||||
auto outputs =
|
||||
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
if (!captured_outputs.is_none()) {
|
||||
std::vector<array> captures(
|
||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||
std::make_move_iterator(outputs.end()));
|
||||
@ -459,12 +466,16 @@ struct PyCompiledFun {
|
||||
}
|
||||
|
||||
// Put the outputs back in the container
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
nb::object py_outputs = tree_cache().at(fun_id);
|
||||
return tree_unflatten_from_structure(py_outputs, outputs);
|
||||
}
|
||||
|
||||
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
||||
return const_cast<PyCompiledFun*>(this)->call_impl(args, kwargs);
|
||||
};
|
||||
|
||||
~PyCompiledFun() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
@ -476,35 +487,35 @@ struct PyCompiledFun {
|
||||
|
||||
class PyCheckpointedFun {
|
||||
public:
|
||||
PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
|
||||
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
|
||||
|
||||
~PyCheckpointedFun() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
}
|
||||
|
||||
struct InnerFunction {
|
||||
py::object fun_;
|
||||
py::object args_structure_;
|
||||
std::weak_ptr<py::object> output_structure_;
|
||||
nb::object fun_;
|
||||
nb::object args_structure_;
|
||||
std::weak_ptr<nb::object> output_structure_;
|
||||
|
||||
InnerFunction(
|
||||
py::object fun,
|
||||
py::object args_structure,
|
||||
std::weak_ptr<py::object> output_structure)
|
||||
nb::object fun,
|
||||
nb::object args_structure,
|
||||
std::weak_ptr<nb::object> output_structure)
|
||||
: fun_(std::move(fun)),
|
||||
args_structure_(std::move(args_structure)),
|
||||
output_structure_(output_structure) {}
|
||||
~InnerFunction() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
args_structure_.release().dec_ref();
|
||||
}
|
||||
|
||||
std::vector<array> operator()(const std::vector<array>& inputs) {
|
||||
auto args = py::cast<py::tuple>(
|
||||
auto args = nb::cast<nb::tuple>(
|
||||
tree_unflatten_from_structure(args_structure_, inputs));
|
||||
auto [outputs, output_structure] =
|
||||
tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
|
||||
@ -515,9 +526,9 @@ class PyCheckpointedFun {
|
||||
}
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
|
||||
auto output_structure = std::make_shared<py::object>();
|
||||
auto full_args = py::make_tuple(args, kwargs);
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
auto output_structure = std::make_shared<nb::object>();
|
||||
auto full_args = nb::make_tuple(args, kwargs);
|
||||
auto [inputs, args_structure] =
|
||||
tree_flatten_with_structure(full_args, false);
|
||||
|
||||
@ -527,26 +538,27 @@ class PyCheckpointedFun {
|
||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||
}
|
||||
|
||||
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
|
||||
return const_cast<PyCheckpointedFun*>(this)->call_impl(args, kwargs);
|
||||
}
|
||||
|
||||
private:
|
||||
py::function fun_;
|
||||
nb::callable fun_;
|
||||
};
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args) {
|
||||
[](const nb::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
nb::gil_scoped_release nogil;
|
||||
eval(arrays);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
nb::sig("def eval(*args) -> None"),
|
||||
R"pbdoc(
|
||||
eval(*args) -> None
|
||||
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
Args:
|
||||
@ -557,19 +569,15 @@ void init_transforms(py::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
[](const py::function& fun,
|
||||
[](const nb::callable& fun,
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents) {
|
||||
auto vfun = [&fun](const std::vector<array>& primals) {
|
||||
py::args args = py::tuple(primals.size());
|
||||
for (int i = 0; i < primals.size(); ++i) {
|
||||
args[i] = primals[i];
|
||||
}
|
||||
auto out = fun(*args);
|
||||
if (py::isinstance<array>(out)) {
|
||||
return std::vector<array>{py::cast<array>(out)};
|
||||
auto out = fun(*nb::cast(primals));
|
||||
if (nb::isinstance<array>(out)) {
|
||||
return std::vector<array>{nb::cast<array>(out)};
|
||||
} else {
|
||||
return py::cast<std::vector<array>>(out);
|
||||
return nb::cast<std::vector<array>>(out);
|
||||
}
|
||||
};
|
||||
return jvp(vfun, primals, tangents);
|
||||
@ -577,17 +585,16 @@ void init_transforms(py::module_& m) {
|
||||
"fun"_a,
|
||||
"primals"_a,
|
||||
"tangents"_a,
|
||||
nb::sig(
|
||||
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
|
||||
R"pbdoc(
|
||||
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
|
||||
|
||||
|
||||
Compute the Jacobian-vector product.
|
||||
|
||||
This computes the product of the Jacobian of a function ``fun`` evaluated
|
||||
at ``primals`` with the ``tangents``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of :class:`array`
|
||||
fun (callable): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
@ -601,19 +608,15 @@ void init_transforms(py::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
"vjp",
[](const py::function& fun,
[](const nb::callable& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
py::args args = py::tuple(primals.size());
for (int i = 0; i < primals.size(); ++i) {
args[i] = primals[i];
}
auto out = fun(*args);
if (py::isinstance<array>(out)) {
return std::vector<array>{py::cast<array>(out)};
auto out = fun(*nb::cast(primals));
if (nb::isinstance<array>(out)) {
return std::vector<array>{nb::cast<array>(out)};
} else {
return py::cast<std::vector<array>>(out);
return nb::cast<std::vector<array>>(out);
}
};
return vjp(vfun, primals, cotangents);
@ -621,16 +624,16 @@ void init_transforms(py::module_& m) {
"fun"_a,
"primals"_a,
"cotangents"_a,
nb::sig(
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
R"pbdoc(
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]

Compute the vector-Jacobian product.

Computes the product of the ``cotangents`` with the Jacobian of a
function ``fun`` evaluated at ``primals``.

Args:
fun (function): A function which takes a variable number of :class:`array`
fun (callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian.
@ -644,20 +647,20 @@ void init_transforms(py::module_& m) {
)pbdoc");
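The ``vjp`` binding is called the same way, with cotangents in place of tangents; a small sketch with illustrative values:

import mlx.core as mx

f = lambda x: x * x
x = mx.array(3.0)
cot = mx.array(1.0)
outs, vjps = mx.vjp(f, [x], [cot])
# vjps[0] pulls the cotangent back through f, giving 6.0 here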
m.def(
"value_and_grad",
[](const py::function& fun,
[](const nb::callable& fun,
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames);
return py::cpp_function(py_value_and_grad(
return nb::cpp_function(py_value_and_grad(
fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
},
"fun"_a,
"argnums"_a = std::nullopt,
"argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{},
nb::sig(
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
R"pbdoc(
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function

Returns a function which computes the value and gradient of ``fun``.

The function passed to :func:`value_and_grad` should return either
@ -688,7 +691,7 @@ void init_transforms(py::module_& m) {
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)

Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array` or a tuple the first element
of which should be a scalar :class:`array`.
@ -702,34 +705,34 @@ void init_transforms(py::module_& m) {
no gradients for keyword arguments by default.

Returns:
function: A function which returns a tuple where the first element
callable: A function which returns a tuple where the first element
is the output of `fun` and the second element is the gradients w.r.t.
the loss.
)pbdoc");
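A minimal sketch of ``value_and_grad`` in use; the loss function and shapes are illustrative:

import mlx.core as mx

def loss_fn(w, x):
    return mx.mean((w * x - 1.0) ** 2)

w = mx.array(0.5)
x = mx.ones((4,))
# gradients are taken w.r.t. the first positional argument by default
loss, dw = mx.value_and_grad(loss_fn)(w, x)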
m.def(
"grad",
[](const py::function& fun,
[](const nb::callable& fun,
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames);
auto fn =
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
return py::cpp_function(
[fn](const py::args& args, const py::kwargs& kwargs) {
return nb::cpp_function(
[fn](const nb::args& args, const nb::kwargs& kwargs) {
return fn(args, kwargs).second;
});
},
"fun"_a,
"argnums"_a = std::nullopt,
"argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{},
nb::sig(
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
R"pbdoc(
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function

Returns a function which computes the gradient of ``fun``.

Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices)
@ -742,26 +745,26 @@ void init_transforms(py::module_& m) {
no gradients for keyword arguments by default.

Returns:
function: A function which has the same input arguments as ``fun`` and
callable: A function which has the same input arguments as ``fun`` and
returns the gradient(s).
)pbdoc");
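A small sketch of the ``grad`` binding; the function is illustrative and must return a scalar:

import mlx.core as mx

df = mx.grad(lambda x: mx.sum(mx.sin(x)))
x = mx.zeros((3,))
print(df(x))  # cos(0) == 1 for every element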
m.def(
"vmap",
[](const py::function& fun,
const py::object& in_axes,
const py::object& out_axes) {
return py::cpp_function(py_vmap(fun, in_axes, out_axes));
[](const nb::callable& fun,
const nb::object& in_axes,
const nb::object& out_axes) {
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
},
"fun"_a,
"in_axes"_a = 0,
"out_axes"_a = 0,
nb::sig(
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
R"pbdoc(
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function

Returns a vectorized version of ``fun``.

Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or a tree of :class:`array` and returns
a variable number of :class:`array` or a tree of :class:`array`.
in_axes (int, optional): An integer or a valid prefix tree of the
@ -774,16 +777,16 @@ void init_transforms(py::module_& m) {
Defaults to ``0``.

Returns:
function: The vectorized function.
callable: The vectorized function.
)pbdoc");
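A sketch of ``vmap`` usage, assuming a tuple of axes is accepted as a prefix tree of the inputs, per the docstring above:

import mlx.core as mx

dot = lambda a, b: mx.sum(a * b)
batched_dot = mx.vmap(dot, in_axes=(0, 0))  # map over the leading axis of both inputs
a = mx.ones((8, 3))
b = mx.ones((8, 3))
print(batched_dot(a, b).shape)              # (8,)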
m.def(
"export_to_dot",
[](py::object file, const py::args& args) {
[](nb::object file, const nb::args& args) {
std::vector<array> arrays = tree_flatten(args);
if (py::isinstance<py::str>(file)) {
std::ofstream out(py::cast<std::string>(file));
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
} else if (py::hasattr(file, "write")) {
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
export_to_dot(out, arrays);
auto write = file.attr("write");
@ -793,57 +796,50 @@ void init_transforms(py::module_& m) {
"export_to_dot accepts file-like objects or strings to be used as filenames");
}
},
"file"_a);
"file"_a,
"args"_a);
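A short sketch of the ``export_to_dot`` binding; the filename and graph are illustrative:

import mlx.core as mx

x = mx.array(1.0)
y = mx.exp(x) + x
mx.export_to_dot("graph.dot", y)  # writes the compute graph of y in Graphviz format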
m.def(
"compile",
[](const py::function& fun,
const py::object& inputs,
const py::object& outputs,
[](const nb::callable& fun,
const nb::object& inputs,
const nb::object& outputs,
bool shapeless) {
py::options options;
options.disable_function_signatures();

std::ostringstream doc;
auto name = fun.attr("__name__").cast<std::string>();
doc << name;
// Try to get the name
auto n = fun.attr("__name__");
auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);

// Try to get the signature
auto inspect = py::module::import("inspect");
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
doc << inspect.attr("signature")(fun)
.attr("__str__")()
.cast<std::string>();
std::ostringstream sig;
sig << "def " << name;
auto inspect = nb::module_::import_("inspect");
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
sig << nb::cast<std::string>(
inspect.attr("signature")(fun).attr("__str__")());
} else {
sig << "(*args, **kwargs)";
}

// Try to get the doc string
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
doc << "\n\n";
auto dstr = d.cast<std::string>();
// Add spaces to match first line indentation with remainder of
// docstring
int i = 0;
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
doc << ' ';
}
doc << dstr;
}
auto doc_str = doc.str();
return py::cpp_function(
auto d = inspect.attr("getdoc")(fun);
std::string doc =
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);

auto sig_str = sig.str();
return nb::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
py::name(name.c_str()),
py::doc(doc_str.c_str()));
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str());
},
"fun"_a,
"inputs"_a = std::nullopt,
"outputs"_a = std::nullopt,
"inputs"_a = nb::none(),
"outputs"_a = nb::none(),
"shapeless"_a = false,
R"pbdoc(
compile(fun: function) -> function

Returns a compiled function which produces the same output as ``fun``.

Args:
fun (function): A function which takes a variable number of
fun (callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during
@ -864,15 +860,13 @@ void init_transforms(py::module_& m) {
``shapeless`` set to ``True``. Default: ``False``

Returns:
function: A compiled function which has the same input arguments
callable: A compiled function which has the same input arguments
as ``fun`` and returns the the same output(s).
)pbdoc");
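A minimal sketch of the ``compile`` binding; the function and arguments are illustrative:

import mlx.core as mx

def step(x, y):
    return mx.exp(-x) + y

cstep = mx.compile(step)  # the graph is traced and compiled on first call
out = cstep(mx.array(1.0), mx.array(2.0))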
m.def(
"disable_compile",
&disable_compile,
R"pbdoc(
disable_compile() -> None

Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
@ -880,17 +874,15 @@ void init_transforms(py::module_& m) {
"enable_compile",
&enable_compile,
R"pbdoc(
enable_compile() -> None

Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
m.def(
"checkpoint",
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
"fun"_a);
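A sketch of the checkpoint binding above, assuming it is exposed as ``mx.checkpoint``; the block function is illustrative:

import mlx.core as mx

def block(x):
    return mx.tanh(x) * x

# Checkpointing trades memory for recompute: intermediates of ``block`` are
# rematerialized during the backward pass instead of being stored.
g = mx.grad(lambda x: mx.sum(mx.checkpoint(block)(x)))
print(g(mx.ones((4,))))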

// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
auto atexit = nb::module_::import_("atexit");
atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
}

@ -2,16 +2,16 @@

#include "python/src/trees.h"

void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) ||
nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
} else if (nb::isinstance<nb::dict>(subtree)) {
for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second);
}
} else {
@ -23,63 +23,63 @@ void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
}

template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = nb::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) ||
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}

py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
nb::object tree_map(
const std::vector<nb::object>& trees,
std::function<nb::object(const std::vector<nb::object>&)> transform) {
std::function<nb::object(const std::vector<nb::object>&)> recurse;

recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
recurse = [&](const std::vector<nb::object>& subtrees) {
if (nb::isinstance<nb::list>(subtrees[0])) {
nb::list l;
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
std::vector<nb::object> items(subtrees.size());
int len = nb::cast<nb::tuple>(subtrees[0]).size();
nb::list l;
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
return nb::cast<nb::object>(nb::tuple(l));
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
nb::dict d;
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
@ -91,7 +91,7 @@ py::object tree_map(
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
return nb::cast<nb::object>(d);
} else {
return transform(subtrees);
}
@ -99,40 +99,40 @@ py::object tree_map(
return recurse(trees);
}

py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
nb::object tree_map(
nb::object tree,
std::function<nb::object(nb::handle)> transform) {
return tree_map({tree}, [&](std::vector<nb::object> inputs) {
return transform(inputs[0]);
});
}

void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
nb::object tree,
std::function<nb::object(nb::handle)> visitor) {
std::function<nb::object(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree)) {
auto l = nb::cast<nb::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
return nb::cast<nb::object>(subtree);
} else if (nb::isinstance<nb::dict>(subtree)) {
auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return nb::cast<nb::object>(d);
} else if (nb::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
return nb::cast<nb::object>(subtree);
}
};
recurse(tree);
@ -141,36 +141,36 @@ void tree_visit_update(
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
void tree_fill(nb::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
}

// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
nb::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
tree_visit_update(tree, [&](nb::handle node) {
auto arr = nb::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
return nb::cast(it->second);
}
return py::cast(arr);
return nb::cast(arr);
});
}

std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
std::vector<array> tree_flatten(nb::object tree, bool strict /* = true */) {
std::vector<array> flat_tree;

tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
@ -180,24 +180,24 @@ std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
return flat_tree;
}

py::object tree_unflatten(
py::object tree,
nb::object tree_unflatten(
nb::object tree,
const std::vector<array>& values,
int index /* = 0 */) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
return tree_map(tree, [&](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
return nb::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
return nb::cast<nb::object>(obj);
}
});
}

py::object structure_sentinel() {
static py::object sentinel;
nb::object structure_sentinel() {
static nb::object sentinel;

if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
sentinel = nb::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
@ -206,19 +206,19 @@ py::object structure_sentinel() {
return sentinel;
}

std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
nb::object tree,
bool strict /* = true */) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
[&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
if (nb::isinstance<array>(obj)) {
flat_tree.push_back(nb::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
return nb::cast<nb::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
@ -228,16 +228,16 @@ std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
return {flat_tree, structure};
}

py::object tree_unflatten_from_structure(
py::object structure,
nb::object tree_unflatten_from_structure(
nb::object structure,
const std::vector<array>& values,
int index /* = 0 */) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
return tree_map(structure, [&](nb::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
return nb::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
return nb::cast<nb::object>(obj);
}
});
}
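For orientation, the Python-facing counterparts of these tree helpers live in ``mlx.utils`` (assuming ``tree_flatten``, ``tree_map`` and ``tree_unflatten`` are exposed there, as in recent releases); note that the C++ helpers above flatten to a bare vector of arrays plus a structure object, while the Python utilities keep (key, value) pairs:

import mlx.core as mx
from mlx.utils import tree_flatten, tree_map, tree_unflatten

params = {"w": mx.ones((2, 2)), "layers": [{"b": mx.zeros((2,))}]}
flat = tree_flatten(params)                  # list of (dotted-key, array) pairs
doubled = tree_map(lambda a: a * 2, params)  # same structure, transformed leaves
restored = tree_unflatten(flat)              # rebuilds the nested containers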

@ -1,38 +1,37 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>

#include "mlx/array.h"

namespace py = pybind11;
namespace nb = nanobind;
using namespace mlx::core;

void tree_visit(py::object tree, std::function<void(py::handle)> visitor);
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);

py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform);
nb::object tree_map(
const std::vector<nb::object>& trees,
std::function<nb::object(const std::vector<nb::object>&)> transform);

py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform);
nb::object tree_map(
nb::object tree,
std::function<nb::object(nb::handle)> transform);

void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor);
nb::object tree,
std::function<nb::object(nb::handle)> visitor);

/**
 * Fill a pytree (recursive dict or list of dict or list) in place with the
 * given arrays. */
void tree_fill(py::object& tree, const std::vector<array>& values);
void tree_fill(nb::object& tree, const std::vector<array>& values);

/**
 * Replace all the arrays from the src values with the dst values in the
 * tree.
 */
void tree_replace(
py::object& tree,
nb::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst);

@ -40,21 +39,21 @@ void tree_replace(
 * Flatten a tree into a vector of arrays. If strict is true, then the
 * function will throw if the tree contains a leaf which is not an array.
 */
std::vector<array> tree_flatten(py::object tree, bool strict = true);
std::vector<array> tree_flatten(nb::object tree, bool strict = true);

/**
 * Unflatten a tree from a vector of arrays.
 */
py::object tree_unflatten(
py::object tree,
nb::object tree_unflatten(
nb::object tree,
const std::vector<array>& values,
int index = 0);

std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
std::pair<std::vector<array>, nb::object> tree_flatten_with_structure(
nb::object tree,
bool strict = true);

py::object tree_unflatten_from_structure(
py::object structure,
nb::object tree_unflatten_from_structure(
nb::object structure,
const std::vector<array>& values,
int index = 0);

@ -1,81 +0,0 @@

#include "mlx/utils.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <optional>

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

// Slightly different from the original, with python context on init we are not
// in the context yet. Only create the inner context on enter then delete on
// exit.
class PyStreamContext {
public:
PyStreamContext(StreamOrDevice s) : _inner(nullptr) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
_s = s;
}

void enter() {
_inner = new StreamContext(_s);
}

void exit() {
if (_inner != nullptr) {
delete _inner;
_inner = nullptr;
}
}

private:
StreamOrDevice _s;
StreamContext* _inner;
};

void init_utils(py::module_& m) {
py::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.

See :func:`stream` for usage.

Args:
s: The stream or device to set as the default.
)pbdoc")
.def(py::init<StreamOrDevice>(), "s"_a)
.def("__enter__", [](PyStreamContext& scm) { scm.enter(); })
.def(
"__exit__",
[](PyStreamContext& scm,
const std::optional<py::type>& exc_type,
const std::optional<py::object>& exc_value,
const std::optional<py::object>& traceback) { scm.exit(); });
m.def(
"stream",
[](StreamOrDevice s) { return PyStreamContext(s); },
"s"_a,
R"pbdoc(
Create a context manager to set the default device and stream.

Args:
s: The :obj:`Stream` or :obj:`Device` to set as the default.

Returns:
A context manager that sets the default device and stream.

Example:

.. code-block::python

import mlx.core as mx

# Create a context manager for the default device and stream.
with mx.stream(mx.cpu):
# Operations here will use mx.cpu by default.
pass
)pbdoc");
}

@ -1,23 +1,22 @@
// Copyright © 2023 Apple Inc.

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <numeric>
#include <optional>
#include <variant>

#include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/variant.h>

#include "mlx/array.h"

namespace py = pybind11;
namespace nb = nanobind;

using namespace mlx::core;

using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{};
variant<nb::bool_, nb::int_, nb::float_, std::complex<float>, nb::object>;

inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
@ -32,31 +31,36 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
return axes;
}

inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>();
inline array to_array_with_accessor(nb::object obj) {
if (nb::hasattr(obj, "__mlx_array__")) {
return nb::cast<array>(obj.attr("__mlx_array__")());
} else if (nb::isinstance<array>(obj)) {
return nb::cast<array>(obj);
} else {
return obj.cast<array>();
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(obj.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}
}

inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), dtype.value_or(bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(int32);
// bool_ is an exception and is always promoted
return array(py::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
return array(nb::cast<int>(*pv), (out_t == bool_) ? int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(float32);
return array(
py::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
nb::cast<float>(*pv), is_floating_point(out_t) ? out_t : float32);
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else {
return to_array_with_accessor(std::get<py::object>(v));
return to_array_with_accessor(std::get<nb::object>(v));
}
}

@ -68,14 +72,14 @@ inline std::pair<array, array> to_arrays(
// - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<py::object>(&a); pa) {
if (auto pa = std::get_if<nb::object>(&a); pa) {
auto arr_a = to_array_with_accessor(*pa);
if (auto pb = std::get_if<py::object>(&b); pb) {
if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {arr_a, arr_b};
}
return {arr_a, to_array(b, arr_a.dtype())};
} else if (auto pb = std::get_if<py::object>(&b); pb) {
} else if (auto pb = std::get_if<nb::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {to_array(a, arr_b.dtype()), arr_b};
} else {

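A small sketch of the scalar conversion rules implemented by ``to_array`` and ``to_arrays`` above, as seen from Python (dtypes shown in comments are the expected defaults):

import mlx.core as mx

# Python scalars pick the default dtypes used in to_array:
print(mx.array(True).dtype)  # mlx.core.bool_
print(mx.array(3).dtype)     # mlx.core.int32
print(mx.array(3.0).dtype)   # mlx.core.float32

# Combined with an existing array, a bare scalar adopts the array's dtype:
x = mx.array([1, 2], dtype=mx.float16)
print((x + 3.0).dtype)       # mlx.core.float16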
@ -308,9 +308,9 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.dtype, mx.bool_)
self.assertEqual(y.item(), True)

# y = mx.array(x, mx.complex64)
# self.assertEqual(y.dtype, mx.complex64)
# self.assertEqual(y.item(), 3.0+0j)
y = mx.array(x, mx.complex64)
self.assertEqual(y.dtype, mx.complex64)
self.assertEqual(y.item(), 3.0 + 0j)

def test_array_repr(self):
x = mx.array(True)
@ -682,7 +682,7 @@ class TestArray(mlx_tests.MLXTestCase):

# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(RuntimeError):
with self.assertRaises(TypeError):
pickle.dumps(x)

def test_array_copy(self):
@ -711,6 +711,11 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqualArray(y, x - 1)

def test_indexing(self):
# Only ellipsis is a no-op
a_mlx = mx.array([1])[...]
self.assertEqual(a_mlx.shape, (1,))
self.assertEqual(a_mlx.item(), 1)

# Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32)
a_mlx = mx.array(a_npy)
@ -1360,7 +1365,7 @@ class TestArray(mlx_tests.MLXTestCase):
for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
a_tf = tf.constant(a_np, dtype=tf_dtype)
a_mx = mx.array(a_tf)
a_mx = mx.array(np.array(a_tf))
for f in [
lambda x: x,
lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,

setup.py
@ -134,7 +134,9 @@ class GenerateStubs(Command):
pass

def run(self) -> None:
subprocess.run(["pybind11-stubgen", "mlx.core", "-o", "python"])
subprocess.run(
["python", "-m", "nanobind.stubgen", "-m", "mlx.core", "-r", "-O", "python"]
)


# Read the content of README.md
@ -165,7 +167,7 @@ if __name__ == "__main__":
include_package_data=True,
extras_require={
"testing": ["numpy", "torch"],
"dev": ["pre-commit", "pybind11-stubgen"],
"dev": ["pre-commit"],
},
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},