mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	@@ -6,12 +6,11 @@
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <fstream>
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mlx/load.h"
 | 
			
		||||
#include "mlx/io/load.h"
 | 
			
		||||
#include "mlx/ops.h"
 | 
			
		||||
#include "mlx/utils.h"
 | 
			
		||||
#include "python/src/load.h"
 | 
			
		||||
@@ -161,40 +160,68 @@ class PyFileReader : public io::Reader {
 | 
			
		||||
  py::object tell_func_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
  py::module_ zipfile = py::module_::import("zipfile");
 | 
			
		||||
 | 
			
		||||
  // Assume .npz file if it is zipped
 | 
			
		||||
  if (is_zip_file(zipfile, file)) {
 | 
			
		||||
    // Output dictionary filename in zip -> loaded array
 | 
			
		||||
    std::unordered_map<std::string, array> array_dict;
 | 
			
		||||
 | 
			
		||||
    // Create python ZipFile object
 | 
			
		||||
    ZipFileWrapper zipfile_object(zipfile, file);
 | 
			
		||||
    for (const std::string& st : zipfile_object.namelist()) {
 | 
			
		||||
      // Open zip file as a python file stream
 | 
			
		||||
      py::object sub_file = zipfile_object.open(st);
 | 
			
		||||
 | 
			
		||||
      // Create array from python file stream
 | 
			
		||||
      auto arr = load(std::make_shared<PyFileReader>(sub_file), s);
 | 
			
		||||
 | 
			
		||||
      // Remove .npy from file if it is there
 | 
			
		||||
      auto key = st;
 | 
			
		||||
      if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy")
 | 
			
		||||
        key = st.substr(0, st.length() - 4);
 | 
			
		||||
 | 
			
		||||
      // Add array to dict
 | 
			
		||||
      array_dict.insert({key, arr});
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    StreamOrDevice s) {
 | 
			
		||||
  if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
 | 
			
		||||
    return {load_safetensors(py::cast<std::string>(file), s)};
 | 
			
		||||
  } else if (is_istream_object(file)) {
 | 
			
		||||
    // If we don't own the stream and it was passed to us, eval immediately
 | 
			
		||||
    for (auto& [key, arr] : array_dict) {
 | 
			
		||||
    auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_release gil;
 | 
			
		||||
      arr.eval();
 | 
			
		||||
      for (auto& [key, arr] : arr) {
 | 
			
		||||
        arr.eval();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return {arr};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
    return {array_dict};
 | 
			
		||||
  } else if (py::isinstance<py::str>(file)) { // Assume .npy file path string
 | 
			
		||||
  throw std::invalid_argument(
 | 
			
		||||
      "[load_safetensors] Input must be a file-like object, or string");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Load every member of a zip (`.npz`) archive into a name -> array map.
///
/// @param file  Python object that `is_zip_file` accepts and `ZipFileWrapper`
///              can open.
/// @param s     Stream or device forwarded to `load` for each archive member.
/// @return Map from member name (with one trailing ".npy" removed when the
///         name is longer than that suffix) to the loaded array.
/// @throws std::invalid_argument if `is_zip_file` rejects `file`.
std::unordered_map<std::string, array> mlx_load_npz_helper(
    py::object file,
    StreamOrDevice s) {
  py::module_ zipfile = py::module_::import("zipfile");
  if (!is_zip_file(zipfile, file)) {
    throw std::invalid_argument(
        "[load_npz] Input must be a zip file or a file-like object that can be "
        "opened with zipfile.ZipFile");
  }

  // Output dictionary filename in zip -> loaded array
  std::unordered_map<std::string, array> array_dict;

  // Create python ZipFile object
  ZipFileWrapper zipfile_object(zipfile, file);
  for (const std::string& st : zipfile_object.namelist()) {
    // Open zip file as a python file stream
    py::object sub_file = zipfile_object.open(st);

    // Create array from python file stream
    auto arr = load(std::make_shared<PyFileReader>(sub_file), s);

    // Remove .npy from file if it is there
    auto key = st;
    if (st.length() > 4 && st.compare(st.length() - 4, 4, ".npy") == 0) {
      key = st.substr(0, st.length() - 4);
    }

    // Add array to dict
    array_dict.insert({key, arr});
  }

  // Evaluate everything before returning, while the zip wrapper is still
  // alive. NOTE(review): presumably required because the arrays read lazily
  // from the Python stream -- confirm against `load`.
  {
    // Release the GIL once for the whole batch rather than once per array.
    py::gil_scoped_release gil;
    for (auto& [key, arr] : array_dict) {
      arr.eval();
    }
  }

  // Return the local by name so it is moved (or elided) instead of copied;
  // `return {array_dict};` would copy-construct the whole map.
  return array_dict;
}
 | 
			
		||||
 | 
			
		||||
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
  if (py::isinstance<py::str>(file)) { // Assume .npy file path string
 | 
			
		||||
    return {load(py::cast<std::string>(file), s)};
 | 
			
		||||
  } else if (is_istream_object(file)) {
 | 
			
		||||
    // If we don't own the stream and it was passed to us, eval immediately
 | 
			
		||||
@@ -205,9 +232,41 @@ DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
    }
 | 
			
		||||
    return {arr};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  throw std::invalid_argument(
 | 
			
		||||
      "[load] Input must be a file-like object, string, or pathlib.Path");
 | 
			
		||||
      "[load_npy] Input must be a file-like object, or string");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Pick the loader that matches the requested (or inferred) file format.
///
/// @param file    Python `str` path, or an object for which
///                `is_istream_object` returns true.
/// @param format  One of "safetensors", "npz" or "npy". When empty, the text
///                after the last '.' of the file name is used instead; for
///                stream objects the name is read from their `name` attribute.
/// @param s       Stream or device forwarded to the selected helper.
/// @return The selected helper's result wrapped in `DictOrArray`.
/// @throws std::invalid_argument if `file` is neither kind of input while the
///         format must be inferred, if the name contains no '.', or if the
///         format is not one of the three supported values.
DictOrArray mlx_load_helper(
    py::object file,
    std::optional<std::string> format,
    StreamOrDevice s) {
  if (!format) {
    // No explicit format: derive it from the file name's extension.
    std::string path;
    if (py::isinstance<py::str>(file)) {
      path = py::cast<std::string>(file);
    } else if (is_istream_object(file)) {
      path = file.attr("name").cast<std::string>();
    } else {
      throw std::invalid_argument(
          "[load] Input must be a file-like object, or string");
    }

    const auto dot = path.rfind('.');
    if (dot == std::string::npos) {
      throw std::invalid_argument(
          "[load] Could not infer file format from extension");
    }
    format = path.substr(dot + 1);
  }

  // `format` is guaranteed to hold a value from here on.
  const std::string& kind = *format;
  if (kind == "safetensors") {
    return mlx_load_safetensor_helper(file, s);
  }
  if (kind == "npz") {
    return mlx_load_npz_helper(file, s);
  }
  if (kind == "npy") {
    return mlx_load_npy_helper(file, s);
  }
  throw std::invalid_argument("[load] Unknown file format " + kind);
}
 | 
			
		||||
 | 
			
		||||
///////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
@@ -305,7 +364,7 @@ void mlx_save_helper(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  throw std::invalid_argument(
 | 
			
		||||
      "[save] Input must be a file-like object, string, or pathlib.Path");
 | 
			
		||||
      "[save] Input must be a file-like object, or string");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void mlx_savez_helper(
 | 
			
		||||
@@ -361,3 +420,25 @@ void mlx_savez_helper(
 | 
			
		||||
 | 
			
		||||
  return;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Write a dictionary of arrays in safetensors format.
///
/// @param file          Python `str` path, or an object for which
///                      `is_ostream_object` returns true.
/// @param d             Python dict converted to a name -> array map.
/// @param retain_graph  Forwarded unchanged to `save_safetensors`.
/// @throws std::invalid_argument if `file` is neither a `str` nor an
///         ostream-like object.
void mlx_save_safetensor_helper(
    py::object file,
    py::dict d,
    std::optional<bool> retain_graph) {
  // Convert up front so a dict that cannot be cast is reported before the
  // destination is inspected.
  auto tensors = d.cast<std::unordered_map<std::string, array>>();

  const bool is_path = py::isinstance<py::str>(file);
  if (!is_path && !is_ostream_object(file)) {
    throw std::invalid_argument(
        "[save_safetensors] Input must be a file-like object, or string");
  }

  if (is_path) {
    save_safetensors(py::cast<std::string>(file), tensors, retain_graph);
    return;
  }

  // Stream destination: wrap the Python object, then save with the GIL
  // released for the duration of the call only.
  auto sink = std::make_shared<PyFileWriter>(file);
  {
    py::gil_scoped_release gil;
    save_safetensors(sink, tensors, retain_graph);
  }
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,8 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
#include <variant>
 | 
			
		||||
#include "mlx/ops.h"
 | 
			
		||||
@@ -12,7 +14,18 @@ using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
 | 
			
		||||
 | 
			
		||||
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
 | 
			
		||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    StreamOrDevice s);
 | 
			
		||||
void mlx_save_safetensor_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    py::dict d,
 | 
			
		||||
    std::optional<bool> retain_graph = std::nullopt);
 | 
			
		||||
 | 
			
		||||
DictOrArray mlx_load_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    std::optional<std::string> format,
 | 
			
		||||
    StreamOrDevice s);
 | 
			
		||||
void mlx_save_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    array a,
 | 
			
		||||
 
 | 
			
		||||
@@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) {
 | 
			
		||||
        Args:
 | 
			
		||||
            file (str): File to which the array is saved
 | 
			
		||||
            arr (array): Array to be saved.
 | 
			
		||||
            retain_graph (bool, optional): Optional argument to retain graph
 | 
			
		||||
              during array evaluation before saving. If not provided the graph
 | 
			
		||||
              is retained if we are during a function transformation. Default:
 | 
			
		||||
              None
 | 
			
		||||
 | 
			
		||||
            retain_graph (bool, optional): Whether or not to retain the graph
 | 
			
		||||
              during array evaluation. If left unspecified the graph is retained
 | 
			
		||||
              only if saving is done in a function transformation. Default: ``None``
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "savez",
 | 
			
		||||
@@ -2932,18 +2930,45 @@ void init_ops(py::module_& m) {
 | 
			
		||||
      &mlx_load_helper,
 | 
			
		||||
      "file"_a,
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "format"_a = none,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      "stream"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        load(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
 | 
			
		||||
        load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
 | 
			
		||||
 | 
			
		||||
        Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
 | 
			
		||||
        Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            file (file, str): File in which the array is saved
 | 
			
		||||
 | 
			
		||||
            file (file, str): File in which the array is saved.
 | 
			
		||||
            format (str, optional): Format of the file. If ``None``, the format
 | 
			
		||||
              is inferred from the file extension. Supported formats: ``npy``,
 | 
			
		||||
              ``npz``, and ``safetensors``. Default: ``None``.
 | 
			
		||||
        Returns:
 | 
			
		||||
            result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
 | 
			
		||||
            result (array, dict):
 | 
			
		||||
                A single array if loading from a ``.npy`` file or a dict mapping
 | 
			
		||||
                names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "save_safetensors",
 | 
			
		||||
      &mlx_save_safetensor_helper,
 | 
			
		||||
      "file"_a,
 | 
			
		||||
      "arrays"_a,
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "retain_graph"_a = std::nullopt,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
 | 
			
		||||
 | 
			
		||||
        Save array(s) to a binary file in ``.safetensors`` format.
 | 
			
		||||
 | 
			
		||||
        For more information on the format see https://huggingface.co/docs/safetensors/index.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            file (file, str): File in which the array is saved.
 | 
			
		||||
            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
 | 
			
		||||
            retain_graph (bool, optional): Whether or not to retain the graph
 | 
			
		||||
              during array evaluation. If left unspecified the graph is retained
 | 
			
		||||
              only if saving is done in a function transformation. Default: ``None``.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "where",
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user