mlx/python/src/load.cpp

504 lines
14 KiB
C++
Raw Normal View History

// Copyright © 2023-2024 Apple Inc.
2023-11-30 02:30:41 +08:00
#include <nanobind/stl/vector.h>
2023-11-30 02:30:41 +08:00
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string_view>
#include <unordered_map>
#include <vector>
#include "mlx/io/load.h"
2023-11-30 02:30:41 +08:00
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
#include "python/src/utils.h"
namespace nb = nanobind;
using namespace nb::literals;
2023-11-30 02:30:41 +08:00
using namespace mlx::core;
///////////////////////////////////////////////////////////////////////////////
// Helpers
///////////////////////////////////////////////////////////////////////////////
// Duck-typing check for a readable Python stream: the object qualifies when
// it exposes `readinto`, `seek`, `tell` and `closed`.
bool is_istream_object(const nb::object& file) {
  static constexpr const char* kRequiredAttrs[] = {
      "readinto", "seek", "tell", "closed"};
  for (const char* attr_name : kRequiredAttrs) {
    if (!nb::hasattr(file, attr_name)) {
      return false;
    }
  }
  return true;
}
// Duck-typing check for a writable Python stream: the object qualifies when
// it exposes `write`, `seek`, `tell` and `closed`.
bool is_ostream_object(const nb::object& file) {
  if (!nb::hasattr(file, "write")) {
    return false;
  }
  if (!nb::hasattr(file, "seek")) {
    return false;
  }
  if (!nb::hasattr(file, "tell")) {
    return false;
  }
  return nb::hasattr(file, "closed");
}
// Ask Python's `zipfile.is_zipfile` whether `file` (a path or a stream) is a
// zip archive. For stream inputs the current position is recorded before the
// probe and restored afterwards.
bool is_zip_file(const nb::module_& zipfile, const nb::object& file) {
  if (!is_istream_object(file)) {
    // Not a stream: hand the object straight to zipfile.is_zipfile.
    return nb::cast<bool>(zipfile.attr("is_zipfile")(file));
  }

  auto saved_pos = file.attr("tell")();
  const bool looks_like_zip = nb::cast<bool>(zipfile.attr("is_zipfile")(file));
  // whence = 0: absolute seek back to where the stream was before the probe.
  file.attr("seek")(saved_pos, 0);
  return looks_like_zip;
}
class ZipFileWrapper {
public:
ZipFileWrapper(
const nb::module_& zipfile,
const nb::object& file,
2023-11-30 02:30:41 +08:00
char mode = 'r',
int compression = 0)
: zipfile_module_(zipfile),
zipfile_object_(zipfile.attr("ZipFile")(
file,
"mode"_a = mode,
"compression"_a = compression,
"allowZip64"_a = true)),
files_list_(zipfile_object_.attr("namelist")()),
open_func_(zipfile_object_.attr("open")),
read_func_(zipfile_object_.attr("read")),
close_func_(zipfile_object_.attr("close")) {}
std::vector<std::string> namelist() const {
return nb::cast<std::vector<std::string>>(files_list_);
2023-11-30 02:30:41 +08:00
}
nb::object open(const std::string& key, char mode = 'r') {
2023-11-30 02:30:41 +08:00
// Following numpy :
// https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47
if (mode == 'w') {
return open_func_(key, "mode"_a = mode, "force_zip64"_a = true);
}
return open_func_(key, "mode"_a = mode);
}
private:
nb::module_ zipfile_module_;
nb::object zipfile_object_;
nb::list files_list_;
nb::object open_func_;
nb::object read_func_;
nb::object close_func_;
2023-11-30 02:30:41 +08:00
};
///////////////////////////////////////////////////////////////////////////////
// Loading
///////////////////////////////////////////////////////////////////////////////
class PyFileReader : public io::Reader {
public:
PyFileReader(nb::object file)
2023-11-30 02:30:41 +08:00
: pyistream_(file),
readinto_func_(file.attr("readinto")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileReader() {
nb::gil_scoped_acquire gil;
pyistream_.release().dec_ref();
readinto_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}
2023-11-30 02:30:41 +08:00
bool is_open() const override {
bool out;
{
nb::gil_scoped_acquire gil;
out = !nb::cast<bool>(pyistream_.attr("closed"));
}
return out;
2023-11-30 02:30:41 +08:00
}
bool good() const override {
bool out;
{
nb::gil_scoped_acquire gil;
out = !pyistream_.is_none();
}
return out;
2023-11-30 02:30:41 +08:00
}
size_t tell() override {
size_t out;
{
nb::gil_scoped_acquire gil;
out = nb::cast<size_t>(tell_func_());
}
return out;
2023-11-30 02:30:41 +08:00
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
nb::gil_scoped_acquire gil;
2023-11-30 02:30:41 +08:00
seek_func_(off, (int)way);
}
void read(char* data, size_t n) override {
nb::gil_scoped_acquire gil;
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
nb::object bytes_read = readinto_func_(nb::handle(memview));
if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) {
2023-11-30 02:30:41 +08:00
throw std::runtime_error("[load] Failed to read from python stream");
}
}
std::string label() const override {
return "python file object";
}
private:
nb::object pyistream_;
nb::object readinto_func_;
nb::object seek_func_;
nb::object tell_func_;
2023-11-30 02:30:41 +08:00
};
// Load a safetensors file given either a path string or a readable Python
// stream. Returns the pair produced by load_safetensors; this file's caller
// treats it as (arrays, metadata).
// Throws std::invalid_argument for any other kind of input.
std::pair<
    std::unordered_map<std::string, array>,
    std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
  // A Python str is assumed to be a .safetensors file path.
  if (nb::isinstance<nb::str>(file)) {
    return load_safetensors(nb::cast<std::string>(file), s);
  }

  if (!is_istream_object(file)) {
    throw std::invalid_argument(
        "[load_safetensors] Input must be a file-like object, or string");
  }

  // The stream belongs to the caller, so evaluate every array right away
  // instead of leaving the reads pending.
  auto loaded = load_safetensors(std::make_shared<PyFileReader>(file), s);
  {
    nb::gil_scoped_release nogil;
    for (auto& entry : std::get<0>(loaded)) {
      entry.second.eval();
    }
  }
  return loaded;
}
// Load a GGUF file. Only a path string is accepted; anything else raises
// std::invalid_argument.
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
  if (!nb::isinstance<nb::str>(file)) {
    throw std::invalid_argument("[load_gguf] Input must be a string");
  }
  // A Python str is assumed to be a .gguf file path.
  const auto path = nb::cast<std::string>(file);
  return load_gguf(path, s);
}
// Load every member of an .npz (zip) archive into a map keyed by member name
// with a trailing ".npy" removed. `file` may be a path string or a file-like
// object accepted by zipfile.ZipFile.
// Throws std::invalid_argument when zipfile.is_zipfile rejects the input.
std::unordered_map<std::string, array> mlx_load_npz_helper(
    nb::object file,
    StreamOrDevice s) {
  // A path string means the archive is opened (and therefore owned) here.
  const bool path_given = nb::isinstance<nb::str>(file);

  nb::module_ zipfile = nb::module_::import_("zipfile");
  if (!is_zip_file(zipfile, file)) {
    throw std::invalid_argument(
        "[load_npz] Input must be a zip file or a file-like object that can be "
        "opened with zipfile.ZipFile");
  }

  // Result: member name in the zip -> loaded array
  std::unordered_map<std::string, array> loaded;

  // Python ZipFile object wrapping the input
  ZipFileWrapper archive(zipfile, file);

  for (const std::string& member : archive.namelist()) {
    // Each member is exposed as a Python file stream ...
    nb::object member_stream = archive.open(member);

    // ... from which the array is created.
    auto arr = load(std::make_shared<PyFileReader>(member_stream), s);

    // Drop a trailing ".npy" (only when something precedes it).
    const bool has_npy_suffix = member.length() > 4 &&
        member.compare(member.length() - 4, 4, ".npy") == 0;
    std::string key =
        has_npy_suffix ? member.substr(0, member.length() - 4) : member;

    loaded.emplace(key, arr);
  }

  // If the stream was handed to us rather than opened here, eval immediately.
  if (!path_given) {
    nb::gil_scoped_release nogil;
    for (auto& entry : loaded) {
      entry.second.eval();
    }
  }

  return loaded;
}
2023-11-30 02:30:41 +08:00
// Load a single array from a .npy path string or from a readable Python
// stream. Throws std::invalid_argument for any other kind of input.
array mlx_load_npy_helper(nb::object file, StreamOrDevice s) {
  // A Python str is assumed to be a .npy file path.
  if (nb::isinstance<nb::str>(file)) {
    return load(nb::cast<std::string>(file), s);
  }

  if (!is_istream_object(file)) {
    throw std::invalid_argument(
        "[load_npy] Input must be a file-like object, or string");
  }

  // The stream belongs to the caller, so evaluate the array right away.
  auto result = load(std::make_shared<PyFileReader>(file), s);
  {
    nb::gil_scoped_release nogil;
    result.eval();
  }
  return result;
}
// Front door for loading: picks the loader from `format`, or, when `format`
// is not given, from the extension of the path / of the stream's `name`
// attribute.
//
// Throws std::invalid_argument when the input is neither a string nor a
// readable stream, when no extension can be found, when metadata is requested
// for npy/npz, or when the format is unknown.
LoadOutputTypes mlx_load_helper(
    nb::object file,
    std::optional<std::string> format,
    bool return_metadata,
    StreamOrDevice s) {
  if (!format) {
    // Infer the format from the file name's extension.
    std::string fname;
    if (nb::isinstance<nb::str>(file)) {
      fname = nb::cast<std::string>(file);
    } else if (is_istream_object(file)) {
      fname = nb::cast<std::string>(file.attr("name"));
    } else {
      throw std::invalid_argument(
          "[load] Input must be a file-like object opened in binary mode, or string");
    }

    const size_t dot = fname.rfind('.');
    if (dot == std::string::npos) {
      throw std::invalid_argument(
          "[load] Could not infer file format from extension");
    }
    format = fname.substr(dot + 1);
  }

  const std::string& fmt = *format;

  if (return_metadata && (fmt == "npy" || fmt == "npz")) {
    throw std::invalid_argument(
        "[load] metadata not supported for format " + fmt);
  }

  if (fmt == "safetensors") {
    auto [dict, metadata] = mlx_load_safetensor_helper(file, s);
    if (!return_metadata) {
      return dict;
    }
    return std::make_pair(dict, metadata);
  }
  if (fmt == "npz") {
    return mlx_load_npz_helper(file, s);
  }
  if (fmt == "npy") {
    return mlx_load_npy_helper(file, s);
  }
  if (fmt == "gguf") {
    auto [weights, metadata] = mlx_load_gguf_helper(file, s);
    if (!return_metadata) {
      return weights;
    }
    return std::make_pair(weights, metadata);
  }

  throw std::invalid_argument("[load] Unknown file format " + fmt);
}
///////////////////////////////////////////////////////////////////////////////
// Saving
///////////////////////////////////////////////////////////////////////////////
class PyFileWriter : public io::Writer {
public:
PyFileWriter(nb::object file)
2023-11-30 02:30:41 +08:00
: pyostream_(file),
write_func_(file.attr("write")),
seek_func_(file.attr("seek")),
tell_func_(file.attr("tell")) {}
~PyFileWriter() {
nb::gil_scoped_acquire gil;
pyostream_.release().dec_ref();
write_func_.release().dec_ref();
seek_func_.release().dec_ref();
tell_func_.release().dec_ref();
}
2023-11-30 02:30:41 +08:00
bool is_open() const override {
bool out;
{
nb::gil_scoped_acquire gil;
out = !nb::cast<bool>(pyostream_.attr("closed"));
}
return out;
2023-11-30 02:30:41 +08:00
}
bool good() const override {
bool out;
{
nb::gil_scoped_acquire gil;
out = !pyostream_.is_none();
}
return out;
2023-11-30 02:30:41 +08:00
}
size_t tell() override {
size_t out;
{
nb::gil_scoped_acquire gil;
out = nb::cast<size_t>(tell_func_());
}
return out;
2023-11-30 02:30:41 +08:00
}
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
override {
nb::gil_scoped_acquire gil;
2023-11-30 02:30:41 +08:00
seek_func_(off, (int)way);
}
void write(const char* data, size_t n) override {
nb::gil_scoped_acquire gil;
auto memview =
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ);
nb::object bytes_written = write_func_(nb::handle(memview));
if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) {
2023-11-30 02:30:41 +08:00
throw std::runtime_error("[load] Failed to write to python stream");
}
}
std::string label() const override {
return "python file object";
}
private:
nb::object pyostream_;
nb::object write_func_;
nb::object seek_func_;
nb::object tell_func_;
2023-11-30 02:30:41 +08:00
};
// Save a single array to a path string or to a writable Python stream.
// Throws std::invalid_argument for any other kind of target.
void mlx_save_helper(nb::object file, array a) {
  const bool is_path = nb::isinstance<nb::str>(file);
  if (!is_path && !is_ostream_object(file)) {
    throw std::invalid_argument(
        "[save] Input must be a file-like object, or string");
  }

  if (is_path) {
    save(nb::cast<std::string>(file), a);
    return;
  }

  // Stream target: the writer re-acquires the GIL whenever it calls back
  // into Python, so the save itself runs with the GIL released.
  auto py_writer = std::make_shared<PyFileWriter>(file);
  {
    nb::gil_scoped_release nogil;
    save(py_writer, a);
  }
}
void mlx_savez_helper(
nb::object file_,
nb::args args,
const nb::kwargs& kwargs,
2023-11-30 02:30:41 +08:00
bool compressed) {
// Add .npz to the end of the filename if not already there
nb::object file = file_;
2023-11-30 02:30:41 +08:00
if (nb::isinstance<nb::str>(file_)) {
std::string fname = nb::cast<std::string>(file_);
2023-11-30 02:30:41 +08:00
// Add .npz to file name if it is not there
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz")
fname += ".npz";
file = nb::cast(fname);
2023-11-30 02:30:41 +08:00
}
// Collect args and kwargs
auto arrays_dict = nb::cast<std::unordered_map<std::string, array>>(kwargs);
auto arrays_list = nb::cast<std::vector<array>>(args);
2023-11-30 02:30:41 +08:00
for (int i = 0; i < arrays_list.size(); i++) {
std::string arr_name = "arr_" + std::to_string(i);
if (arrays_dict.count(arr_name) > 0) {
throw std::invalid_argument(
"[savez] Cannot use un-named variables and keyword " + arr_name);
}
arrays_dict.insert({arr_name, arrays_list[i]});
}
// Create python ZipFile object depending on compression
nb::module_ zipfile = nb::module_::import_("zipfile");
int compression = nb::cast<int>(
compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED"));
2023-11-30 02:30:41 +08:00
char mode = 'w';
ZipFileWrapper zipfile_object(zipfile, file, mode, compression);
// Save each array
for (auto [k, a] : arrays_dict) {
std::string fname = k + ".npy";
auto py_ostream = zipfile_object.open(fname, 'w');
auto writer = std::make_shared<PyFileWriter>(py_ostream);
{
nb::gil_scoped_release nogil;
save(writer, a);
}
2023-11-30 02:30:41 +08:00
}
return;
}
// Save a dict of arrays (plus optional string->string metadata) in the
// safetensors format to a path string or a writable Python stream.
//
// Throws std::invalid_argument when the metadata is not a dict of strings or
// when `file` is neither a string nor a writable stream.
void mlx_save_safetensor_helper(
    nb::object file,
    nb::dict d,
    std::optional<nb::dict> m) {
  // Stays empty when no metadata was supplied.
  std::unordered_map<std::string, std::string> metadata_map;
  if (m) {
    try {
      metadata_map =
          nb::cast<std::unordered_map<std::string, std::string>>(m.value());
    } catch (const nb::cast_error&) {
      throw std::invalid_argument(
          "[save_safetensors] Metadata must be a dictionary with string keys and values");
    }
  }

  auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(d);

  if (nb::isinstance<nb::str>(file)) {
    // Convert the Python str while the GIL is still held: nb::cast reads the
    // Python object and must not run inside the gil_scoped_release scope.
    const auto path = nb::cast<std::string>(file);
    {
      nb::gil_scoped_release nogil;
      save_safetensors(path, arrays_map, metadata_map);
    }
  } else if (is_ostream_object(file)) {
    auto writer = std::make_shared<PyFileWriter>(file);
    {
      nb::gil_scoped_release nogil;
      save_safetensors(writer, arrays_map, metadata_map);
    }
  } else {
    throw std::invalid_argument(
        "[save_safetensors] Input must be a file-like object, or string");
  }
}
// Save a dict of arrays (plus optional metadata) in the GGUF format.
// Only a path string is accepted as target; anything else raises
// std::invalid_argument.
void mlx_save_gguf_helper(
    nb::object file,
    nb::dict a,
    std::optional<nb::dict> m) {
  auto arrays_map = nb::cast<std::unordered_map<std::string, array>>(a);

  if (!nb::isinstance<nb::str>(file)) {
    throw std::invalid_argument("[save_gguf] Input must be a string");
  }

  // Convert the Python str while the GIL is still held: nb::cast reads the
  // Python object and must not run inside the gil_scoped_release scopes below.
  const auto path = nb::cast<std::string>(file);

  if (m) {
    auto metadata_map =
        nb::cast<std::unordered_map<std::string, GGUFMetaData>>(m.value());
    {
      nb::gil_scoped_release nogil;
      save_gguf(path, arrays_map, metadata_map);
    }
  } else {
    nb::gil_scoped_release nogil;
    save_gguf(path, arrays_map);
  }
}