Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun
2024-03-18 20:12:25 -07:00
committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions

View File

@@ -1,8 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/stl/vector.h>
#include <cstring>
#include <fstream>
#include <stdexcept>
@@ -16,39 +14,39 @@
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
///////////////////////////////////////////////////////////////////////////////
// Helpers
///////////////////////////////////////////////////////////////////////////////
bool is_istream_object(const py::object& file) {
return py::hasattr(file, "readinto") && py::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed");
bool is_istream_object(const nb::object& file) {
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") &&
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
}
bool is_ostream_object(const py::object& file) {
return py::hasattr(file, "write") && py::hasattr(file, "seek") &&
py::hasattr(file, "tell") && py::hasattr(file, "closed");
bool is_ostream_object(const nb::object& file) {
return nb::hasattr(file, "write") && nb::hasattr(file, "seek") &&
nb::hasattr(file, "tell") && nb::hasattr(file, "closed");
}
bool is_zip_file(const py::module_& zipfile, const py::object& file) {
bool is_zip_file(const nb::module_& zipfile, const nb::object& file) {
if (is_istream_object(file)) {
auto st_pos = file.attr("tell")();
bool r = (zipfile.attr("is_zipfile")(file)).cast<bool>();
bool r = nb::cast<bool>(zipfile.attr("is_zipfile")(file));
file.attr("seek")(st_pos, 0);
return r;
}
return zipfile.attr("is_zipfile")(file).cast<bool>();
return nb::cast<bool>(zipfile.attr("is_zipfile")(file));
}
class ZipFileWrapper {
public:
ZipFileWrapper(
const py::module_& zipfile,
const py::object& file,
const nb::module_& zipfile,
const nb::object& file,
char mode = 'r',
int compression = 0)
: zipfile_module_(zipfile),
@@ -63,10 +61,10 @@ class ZipFileWrapper {
close_func_(zipfile_object_.attr("close")) {}
std::vector<std::string> namelist() const {
return files_list_.cast<std::vector<std::string>>();
return nb::cast<std::vector<std::string>>(files_list_);
}
py::object open(const std::string& key, char mode = 'r') {
nb::object open(const std::string& key, char mode = 'r') {
// Following numpy :
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
if (mode == 'w') {
@@ -76,12 +74,12 @@ class ZipFileWrapper {
}
private:
py::module_ zipfile_module_;
py::object zipfile_object_;
py::list files_list_;
py::object open_func_;
py::object read_func_;
py::object close_func_;
nb::module_ zipfile_module_;
nb::object zipfile_object_;
nb::list files_list_;
nb::object open_func_;
nb::object read_func_;
nb::object close_func_;
};
///////////////////////////////////////////////////////////////////////////////
@@ -90,14 +88,14 @@ class ZipFileWrapper {
class PyFileReader : public io::Reader {
public:
PyFileReader(py::object file)
PyFileReader(nb::object file)
: pyistream_(file),
readinto_func_(file.attr("readinto")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileReader() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
pyistream_.release().dec_ref();
readinto_func_.release().dec_ref();
@@ -108,8 +106,8 @@ class PyFileReader : public io::Reader {
bool is_open() const override {
bool out;
{
py::gil_scoped_acquire gil;
out = !pyistream_.attr("closed").cast<bool>();
nb::gil_scoped_acquire gil;
out = !nb::cast<bool>(pyistream_.attr("closed"));
}
return out;
}
@@ -117,7 +115,7 @@ class PyFileReader : public io::Reader {
bool good() const override {
bool out;
{
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
out = !pyistream_.is_none();
}
return out;
@@ -126,25 +124,24 @@ class PyFileReader : public io::Reader {
size_t tell() const override {
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
nb::gil_scoped_acquire gil;
out = nb::cast<size_t>(tell_func_());
}
return out;
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}
void read(char* data, size_t n) override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
nb::object bytes_read = readinto_func_(nb::handle(memview));
py::object bytes_read =
readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
if (bytes_read.is_none() || py::cast<size_t>(bytes_read) < n) {
if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {
throw std::runtime_error("[load] Failed to read from python stream");
}
}
@@ -154,23 +151,23 @@ class PyFileReader : public io::Reader {
}
private:
py::object pyistream_;
py::object readinto_func_;
py::object seek_func_;
py::object tell_func_;
nb::object pyistream_;
nb::object readinto_func_;
nb::object seek_func_;
nb::object tell_func_;
};
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
return load_safetensors(py::cast<std::string>(file), s);
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
{
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) {
arr.eval();
}
@@ -182,20 +179,20 @@ mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
"[load_safetensors] Input must be a file-like object, or string");
}
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
return load_gguf(py::cast<std::string>(file), s);
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s);
}
throw std::invalid_argument("[load_gguf] Input must be a string");
}
std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file,
nb::object file,
StreamOrDevice s) {
bool own_file = py::isinstance<py::str>(file);
bool own_file = nb::isinstance<nb::str>(file);
py::module_ zipfile = py::module_::import("zipfile");
nb::module_ zipfile = nb::module_::import_("zipfile");
if (!is_zip_file(zipfile, file)) {
throw std::invalid_argument(
"[load_npz] Input must be a zip file or a file-like object that can be "
@@ -208,7 +205,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
ZipFileWrapper zipfile_object(zipfile, file);
for (const std::string& st : zipfile_object.namelist()) {
// Open zip file as a python file stream
py::object sub_file = zipfile_object.open(st);
nb::object sub_file = zipfile_object.open(st);
// Create array from python fille stream
auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
@@ -224,7 +221,7 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
// If we don't own the stream and it was passed to us, eval immediately
if (!own_file) {
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
for (auto& [key, arr] : array_dict) {
arr.eval();
}
@@ -233,14 +230,14 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
return array_dict;
}
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
return load(py::cast<std::string>(file), s);
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .npy file path string
return load(nb::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
{
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
arr.eval();
}
return arr;
@@ -250,16 +247,16 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
}
LoadOutputTypes mlx_load_helper(
py::object file,
nb::object file,
std::optional<std::string> format,
bool return_metadata,
StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
if (py::isinstance<py::str>(file)) {
fname = py::cast<std::string>(file);
if (nb::isinstance<nb::str>(file)) {
fname = nb::cast<std::string>(file);
} else if (is_istream_object(file)) {
fname = file.attr("name").cast<std::string>();
fname = nb::cast<std::string>(file.attr("name"));
} else {
throw std::invalid_argument(
"[load] Input must be a file-like object opened in binary mode, or string");
@@ -304,14 +301,14 @@ LoadOutputTypes mlx_load_helper(
class PyFileWriter : public io::Writer {
public:
PyFileWriter(py::object file)
PyFileWriter(nb::object file)
: pyostream_(file),
write_func_(file.attr("write")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileWriter() {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
pyostream_.release().dec_ref();
write_func_.release().dec_ref();
@@ -322,8 +319,8 @@ class PyFileWriter : public io::Writer {
bool is_open() const override {
bool out;
{
py::gil_scoped_acquire gil;
out = !pyostream_.attr("closed").cast<bool>();
nb::gil_scoped_acquire gil;
out = !nb::cast<bool>(pyostream_.attr("closed"));
}
return out;
}
@@ -331,7 +328,7 @@ class PyFileWriter : public io::Writer {
bool good() const override {
bool out;
{
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
out = !pyostream_.is_none();
}
return out;
@@ -340,25 +337,26 @@ class PyFileWriter : public io::Writer {
size_t tell() const override {
size_t out;
{
py::gil_scoped_acquire gil;
out = tell_func_().cast<size_t>();
nb::gil_scoped_acquire gil;
out = nb::cast<size_t>(tell_func_());
}
return out;
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
seek_func_(off, (int)way);
}
void write(const char* data, size_t n) override {
py::gil_scoped_acquire gil;
nb::gil_scoped_acquire gil;
py::object bytes_written =
write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)}));
auto memview =
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ);
nb::object bytes_written = write_func_(nb::handle(memview));
if (bytes_written.is_none() || py::cast<size_t>(bytes_written) < n) {
if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {
throw std::runtime_error("[load] Failed to write to python stream");
}
}
@@ -368,20 +366,20 @@ class PyFileWriter : public io::Writer {
}
private:
py::object pyostream_;
py::object write_func_;
py::object seek_func_;
py::object tell_func_;
nb::object pyostream_;
nb::object write_func_;
nb::object seek_func_;
nb::object tell_func_;
};
void mlx_save_helper(py::object file, array a) {
if (py::isinstance<py::str>(file)) {
save(py::cast<std::string>(file), a);
void mlx_save_helper(nb::object file, array a) {
if (nb::isinstance<nb::str>(file)) {
save(nb::cast<std::string>(file), a);
return;
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release gil;
nb::gil_scoped_release gil;
save(writer, a);
}
@@ -393,26 +391,26 @@ void mlx_save_helper(py::object file, array a) {
}
void mlx_savez_helper(
py::object file_,
py::args args,
const py::kwargs& kwargs,
nb::object file_,
nb::args args,
const nb::kwargs& kwargs,
bool compressed) {
// Add .npz to the end of the filename if not already there
py::object file = file_;
nb::object file = file_;
if (py::isinstance<py::str>(file_)) {
std::string fname = file_.cast<std::string>();
if (nb::isinstance<nb::str>(file_)) {
std::string fname = nb::cast<std::string>(file_);
// Add .npz to file name if it is not there
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
fname += ".npz";
file = py::str(fname);
file = nb::cast(fname);
}
// Collect args and kwargs
auto arrays_dict = kwargs.cast<std::unordered_map<std::string, array>>();
auto arrays_list = args.cast<std::vector<array>>();
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
auto arrays_list = nb::cast<std::vector<array>>(args);
for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i);
@@ -426,9 +424,9 @@ void mlx_savez_helper(
}
// Create python ZipFile object depending on compression
py::module_ zipfile = py::module_::import("zipfile");
int compression = compressed ? zipfile.attr("ZIP_DEFLATED").cast<int>()
: zipfile.attr("ZIP_STORED").cast<int>();
nb::module_ zipfile = nb::module_::import_("zipfile");
int compression = nb::cast<int>(
compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED"));
char mode = 'w';
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
@@ -438,7 +436,7 @@ void mlx_savez_helper(
auto py_ostream = zipfile_object.open(fname, 'w');
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
py::gil_scoped_release nogil;
nb::gil_scoped_release nogil;
save(writer, a);
}
}
@@ -447,31 +445,31 @@ void mlx_savez_helper(
}
void mlx_save_safetensor_helper(
py::object file,
py::dict d,
std::optional<py::dict> m) {
nb::object file,
nb::dict d,
std::optional<nb::dict> m) {
std::unordered_map<std::string, std::string> metadata_map;
if (m) {
try {
metadata_map =
m.value().cast<std::unordered_map<std::string, std::string>>();
} catch (const py::cast_error& e) {
nb::cast<std::unordered_map<std::string, std::string>>(m.value());
} catch (const nb::cast_error& e) {
throw std::invalid_argument(
"[save_safetensors] Metadata must be a dictionary with string keys and values");
}
} else {
metadata_map = std::unordered_map<std::string, std::string>();
}
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);
if (nb::isinstance<nb::str>(file)) {
{
py::gil_scoped_release nogil;
save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
nb::gil_scoped_release nogil;
save_safetensors(nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else if (is_ostream_object(file)) {
auto writer = std::make_shared<PyFileWriter>(file);
{
py::gil_scoped_release nogil;
nb::gil_scoped_release nogil;
save_safetensors(writer, arrays_map, metadata_map);
}
} else {
@@ -481,22 +479,22 @@ void mlx_save_safetensor_helper(
}
void mlx_save_gguf_helper(
py::object file,
py::dict a,
std::optional<py::dict> m) {
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
nb::object file,
nb::dict a,
std::optional<nb::dict> m) {
auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);
if (nb::isinstance<nb::str>(file)) {
if (m) {
auto metadata_map =
m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
{
py::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map, metadata_map);
}
} else {
{
py::gil_scoped_release nogil;
save_gguf(py::cast<std::string>(file), arrays_map);
nb::gil_scoped_release nogil;
save_gguf(nb::cast<std::string>(file), arrays_map);
}
}
} else {