mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-07 04:38:13 +08:00
Switch to nanobind (#839)
* mostly builds
* most tests pass
* fix circle build
* add back buffer protocol
* includes
* fix for py38
* limit to cpu device
* include
* fix stubs
* move signatures for docs
* stubgen + docs fix
* doc for compiled function, comments
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/stl/vector.h>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
@@ -16,39 +14,39 @@
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Helpers
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool is_istream_object(const py::object& file) {
|
||||
return py::hasattr(file, "readinto") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
bool is_istream_object(const nb::object& file) {
|
||||
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
|
||||
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_ostream_object(const py::object& file) {
|
||||
return py::hasattr(file, "write") && py::hasattr(file, "seek") &&
|
||||
py::hasattr(file, "tell") && py::hasattr(file, "closed");
|
||||
bool is_ostream_object(const nb::object& file) {
|
||||
return nb::hasattr(file, "write") && nb::hasattr(file, "seek") &&
|
||||
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
|
||||
}
|
||||
|
||||
bool is_zip_file(const py::module_& zipfile, const py::object& file) {
|
||||
bool is_zip_file(const nb::module_& zipfile, const nb::object& file) {
|
||||
if (is_istream_object(file)) {
|
||||
auto st_pos = file.attr("tell")();
|
||||
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>();
|
||||
bool r = nb::cast<bool>(zipfile.attr("is_zipfile")(file));
|
||||
file.attr("seek")(st_pos, 0);
|
||||
return r;
|
||||
}
|
||||
return zipfile.attr("is_zipfile")(file).cast<bool>();
|
||||
return nb::cast<bool>(zipfile.attr("is_zipfile")(file));
|
||||
}
|
||||
|
||||
class ZipFileWrapper {
|
||||
public:
|
||||
ZipFileWrapper(
|
||||
const py::module_& zipfile,
|
||||
const py::object& file,
|
||||
const nb::module_& zipfile,
|
||||
const nb::object& file,
|
||||
char mode = 'r',
|
||||
int compression = 0)
|
||||
: zipfile_module_(zipfile),
|
||||
@@ -63,10 +61,10 @@ class ZipFileWrapper {
|
||||
close_func_(zipfile_object_.attr("close")) {}
|
||||
|
||||
std::vector<std::string> namelist() const {
|
||||
return files_list_.cast<std::vector<std::string>>();
|
||||
return nb::cast<std::vector<std::string>>(files_list_);
|
||||
}
|
||||
|
||||
py::object open(const std::string& key, char mode = 'r') {
|
||||
nb::object open(const std::string& key, char mode = 'r') {
|
||||
// Following numpy :
|
||||
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
|
||||
if (mode == 'w') {
|
||||
@@ -76,12 +74,12 @@ class ZipFileWrapper {
|
||||
}
|
||||
|
||||
private:
|
||||
py::module_ zipfile_module_;
|
||||
py::object zipfile_object_;
|
||||
py::list files_list_;
|
||||
py::object open_func_;
|
||||
py::object read_func_;
|
||||
py::object close_func_;
|
||||
nb::module_ zipfile_module_;
|
||||
nb::object zipfile_object_;
|
||||
nb::list files_list_;
|
||||
nb::object open_func_;
|
||||
nb::object read_func_;
|
||||
nb::object close_func_;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -90,14 +88,14 @@ class ZipFileWrapper {
|
||||
|
||||
class PyFileReader : public io::Reader {
|
||||
public:
|
||||
PyFileReader(py::object file)
|
||||
PyFileReader(nb::object file)
|
||||
: pyistream_(file),
|
||||
readinto_func_(file.attr("readinto")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
~PyFileReader() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
pyistream_.release().dec_ref();
|
||||
readinto_func_.release().dec_ref();
|
||||
@@ -108,8 +106,8 @@ class PyFileReader : public io::Reader {
|
||||
bool is_open() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = !pyistream_.attr("closed").cast<bool>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !nb::cast<bool>(pyistream_.attr("closed"));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -117,7 +115,7 @@ class PyFileReader : public io::Reader {
|
||||
bool good() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !pyistream_.is_none();
|
||||
}
|
||||
return out;
|
||||
@@ -126,25 +124,24 @@ class PyFileReader : public io::Reader {
|
||||
size_t tell() const override {
|
||||
size_t out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = tell_func_().cast<size_t>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = nb::cast<size_t>(tell_func_());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
void read(char* data, size_t n) override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
|
||||
nb::object bytes_read = readinto_func_(nb::handle(memview));
|
||||
|
||||
py::object bytes_read =
|
||||
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
|
||||
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
|
||||
if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {
|
||||
throw std::runtime_error("[load] Failed to read from python stream");
|
||||
}
|
||||
}
|
||||
@@ -154,23 +151,23 @@ class PyFileReader : public io::Reader {
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyistream_;
|
||||
py::object readinto_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
nb::object pyistream_;
|
||||
nb::object readinto_func_;
|
||||
nb::object seek_func_;
|
||||
nb::object tell_func_;
|
||||
};
|
||||
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(py::cast<std::string>(file), s);
|
||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(nb::cast<std::string>(file), s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : std::get<0>(res)) {
|
||||
arr.eval();
|
||||
}
|
||||
@@ -182,20 +179,20 @@ mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
|
||||
"[load_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(py::cast<std::string>(file), s);
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(nb::cast<std::string>(file), s);
|
||||
}
|
||||
|
||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
py::object file,
|
||||
nb::object file,
|
||||
StreamOrDevice s) {
|
||||
bool own_file = py::isinstance<py::str>(file);
|
||||
bool own_file = nb::isinstance<nb::str>(file);
|
||||
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
nb::module_ zipfile = nb::module_::import_("zipfile");
|
||||
if (!is_zip_file(zipfile, file)) {
|
||||
throw std::invalid_argument(
|
||||
"[load_npz] Input must be a zip file or a file-like object that can be "
|
||||
@@ -208,7 +205,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
ZipFileWrapper zipfile_object(zipfile, file);
|
||||
for (const std::string& st : zipfile_object.namelist()) {
|
||||
// Open zip file as a python file stream
|
||||
py::object sub_file = zipfile_object.open(st);
|
||||
nb::object sub_file = zipfile_object.open(st);
|
||||
|
||||
// Create array from python file stream
|
||||
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
|
||||
@@ -224,7 +221,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
if (!own_file) {
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : array_dict) {
|
||||
arr.eval();
|
||||
}
|
||||
@@ -233,14 +230,14 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||
return array_dict;
|
||||
}
|
||||
|
||||
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
|
||||
return load(py::cast<std::string>(file), s);
|
||||
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
|
||||
return load(nb::cast<std::string>(file), s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto arr = load(std::make_shared<PyFileReader>(file), s);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
arr.eval();
|
||||
}
|
||||
return arr;
|
||||
@@ -250,16 +247,16 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||
}
|
||||
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
nb::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s) {
|
||||
if (!format.has_value()) {
|
||||
std::string fname;
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
fname = py::cast<std::string>(file);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
fname = nb::cast<std::string>(file);
|
||||
} else if (is_istream_object(file)) {
|
||||
fname = file.attr("name").cast<std::string>();
|
||||
fname = nb::cast<std::string>(file.attr("name"));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[load] Input must be a file-like object opened in binary mode, or string");
|
||||
@@ -304,14 +301,14 @@ LoadOutputTypes mlx_load_helper(
|
||||
|
||||
class PyFileWriter : public io::Writer {
|
||||
public:
|
||||
PyFileWriter(py::object file)
|
||||
PyFileWriter(nb::object file)
|
||||
: pyostream_(file),
|
||||
write_func_(file.attr("write")),
|
||||
seek_func_(file.attr("seek")),
|
||||
tell_func_(file.attr("tell")) {}
|
||||
|
||||
~PyFileWriter() {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
pyostream_.release().dec_ref();
|
||||
write_func_.release().dec_ref();
|
||||
@@ -322,8 +319,8 @@ class PyFileWriter : public io::Writer {
|
||||
bool is_open() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = !pyostream_.attr("closed").cast<bool>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !nb::cast<bool>(pyostream_.attr("closed"));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -331,7 +328,7 @@ class PyFileWriter : public io::Writer {
|
||||
bool good() const override {
|
||||
bool out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = !pyostream_.is_none();
|
||||
}
|
||||
return out;
|
||||
@@ -340,25 +337,26 @@ class PyFileWriter : public io::Writer {
|
||||
size_t tell() const override {
|
||||
size_t out;
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
out = tell_func_().cast<size_t>();
|
||||
nb::gil_scoped_acquire gil;
|
||||
out = nb::cast<size_t>(tell_func_());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
seek_func_(off, (int)way);
|
||||
}
|
||||
|
||||
void write(const char* data, size_t n) override {
|
||||
py::gil_scoped_acquire gil;
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
py::object bytes_written =
|
||||
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
|
||||
auto memview =
|
||||
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ);
|
||||
nb::object bytes_written = write_func_(nb::handle(memview));
|
||||
|
||||
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
|
||||
if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {
|
||||
throw std::runtime_error("[load] Failed to write to python stream");
|
||||
}
|
||||
}
|
||||
@@ -368,20 +366,20 @@ class PyFileWriter : public io::Writer {
|
||||
}
|
||||
|
||||
private:
|
||||
py::object pyostream_;
|
||||
py::object write_func_;
|
||||
py::object seek_func_;
|
||||
py::object tell_func_;
|
||||
nb::object pyostream_;
|
||||
nb::object write_func_;
|
||||
nb::object seek_func_;
|
||||
nb::object tell_func_;
|
||||
};
|
||||
|
||||
void mlx_save_helper(py::object file, array a) {
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save(py::cast<std::string>(file), a);
|
||||
void mlx_save_helper(nb::object file, array a) {
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
save(nb::cast<std::string>(file), a);
|
||||
return;
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
nb::gil_scoped_release gil;
|
||||
save(writer, a);
|
||||
}
|
||||
|
||||
@@ -393,26 +391,26 @@ void mlx_save_helper(py::object file, array a) {
|
||||
}
|
||||
|
||||
void mlx_savez_helper(
|
||||
py::object file_,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
nb::object file_,
|
||||
nb::args args,
|
||||
const nb::kwargs& kwargs,
|
||||
bool compressed) {
|
||||
// Add .npz to the end of the filename if not already there
|
||||
py::object file = file_;
|
||||
nb::object file = file_;
|
||||
|
||||
if (py::isinstance<py::str>(file_)) {
|
||||
std::string fname = file_.cast<std::string>();
|
||||
if (nb::isinstance<nb::str>(file_)) {
|
||||
std::string fname = nb::cast<std::string>(file_);
|
||||
|
||||
// Add .npz to file name if it is not there
|
||||
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
|
||||
fname += ".npz";
|
||||
|
||||
file = py::str(fname);
|
||||
file = nb::cast(fname);
|
||||
}
|
||||
|
||||
// Collect args and kwargs
|
||||
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>();
|
||||
auto arrays_list = args.cast<std::vector<array>>();
|
||||
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
|
||||
auto arrays_list = nb::cast<std::vector<array>>(args);
|
||||
|
||||
for (int i = 0; i < arrays_list.size(); i++) {
|
||||
std::string arr_name = "arr_" + std::to_string(i);
|
||||
@@ -426,9 +424,9 @@ void mlx_savez_helper(
|
||||
}
|
||||
|
||||
// Create python ZipFile object depending on compression
|
||||
py::module_ zipfile = py::module_::import("zipfile");
|
||||
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>()
|
||||
: zipfile.attr("ZIP_STORED").cast<int>();
|
||||
nb::module_ zipfile = nb::module_::import_("zipfile");
|
||||
int compression = nb::cast<int>(
|
||||
compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED"));
|
||||
char mode = 'w';
|
||||
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
|
||||
|
||||
@@ -438,7 +436,7 @@ void mlx_savez_helper(
|
||||
auto py_ostream = zipfile_object.open(fname, 'w');
|
||||
auto writer = std::make_shared<PyFileWriter>(py_ostream);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
nb::gil_scoped_release nogil;
|
||||
save(writer, a);
|
||||
}
|
||||
}
|
||||
@@ -447,31 +445,31 @@ void mlx_savez_helper(
|
||||
}
|
||||
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<py::dict> m) {
|
||||
nb::object file,
|
||||
nb::dict d,
|
||||
std::optional<nb::dict> m) {
|
||||
std::unordered_map<std::string, std::string> metadata_map;
|
||||
if (m) {
|
||||
try {
|
||||
metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, std::string>>();
|
||||
} catch (const py::cast_error& e) {
|
||||
nb::cast<std::unordered_map<std::string, std::string>>(m.value());
|
||||
} catch (const nb::cast_error& e) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] Metadata must be a dictionary with string keys and values");
|
||||
}
|
||||
} else {
|
||||
metadata_map = std::unordered_map<std::string, std::string>();
|
||||
}
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
nb::gil_scoped_release nogil;
|
||||
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
}
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
nb::gil_scoped_release nogil;
|
||||
save_safetensors(writer, arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
@@ -481,22 +479,22 @@ void mlx_save_safetensor_helper(
|
||||
}
|
||||
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
py::dict a,
|
||||
std::optional<py::dict> m) {
|
||||
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
nb::object file,
|
||||
nb::dict a,
|
||||
std::optional<nb::dict> m) {
|
||||
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
|
||||
if (nb::isinstance<nb::str>(file)) {
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
|
||||
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
nb::gil_scoped_release nogil;
|
||||
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
nb::gil_scoped_release nogil;
|
||||
save_gguf(nb::cast<std::string>(file), arrays_map);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user