mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Metadata support for safetensors (#639)
* metadata support for safetensors * aliases making it a little more readable * addressing comments * python binding tests
This commit is contained in:
		@@ -160,31 +160,29 @@ class PyFileReader : public io::Reader {
 | 
			
		||||
  py::object tell_func_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    StreamOrDevice s) {
 | 
			
		||||
std::pair<
 | 
			
		||||
    std::unordered_map<std::string, array>,
 | 
			
		||||
    std::unordered_map<std::string, std::string>>
 | 
			
		||||
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
  if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
 | 
			
		||||
    return load_safetensors(py::cast<std::string>(file), s);
 | 
			
		||||
  } else if (is_istream_object(file)) {
 | 
			
		||||
    // If we don't own the stream and it was passed to us, eval immediately
 | 
			
		||||
    auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
 | 
			
		||||
    auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_release gil;
 | 
			
		||||
      for (auto& [key, arr] : arr) {
 | 
			
		||||
      for (auto& [key, arr] : std::get<0>(res)) {
 | 
			
		||||
        arr.eval();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return arr;
 | 
			
		||||
    return res;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  throw std::invalid_argument(
 | 
			
		||||
      "[load_safetensors] Input must be a file-like object, or string");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::pair<
 | 
			
		||||
    std::unordered_map<std::string, array>,
 | 
			
		||||
    std::unordered_map<std::string, MetaData>>
 | 
			
		||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
  if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
 | 
			
		||||
    return load_gguf(py::cast<std::string>(file), s);
 | 
			
		||||
  }
 | 
			
		||||
@@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper(
 | 
			
		||||
    format.emplace(fname.substr(ext + 1));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (return_metadata && format.value() != "gguf") {
 | 
			
		||||
  if (return_metadata && (format.value() == "npy" || format.value() == "npz")) {
 | 
			
		||||
    throw std::invalid_argument(
 | 
			
		||||
        "[load] metadata not supported for format " + format.value());
 | 
			
		||||
  }
 | 
			
		||||
  if (format.value() == "safetensors") {
 | 
			
		||||
    return mlx_load_safetensor_helper(file, s);
 | 
			
		||||
    auto [dict, metadata] = mlx_load_safetensor_helper(file, s);
 | 
			
		||||
    if (return_metadata) {
 | 
			
		||||
      return std::make_pair(dict, metadata);
 | 
			
		||||
    }
 | 
			
		||||
    return dict;
 | 
			
		||||
  } else if (format.value() == "npz") {
 | 
			
		||||
    return mlx_load_npz_helper(file, s);
 | 
			
		||||
  } else if (format.value() == "npy") {
 | 
			
		||||
@@ -444,18 +446,33 @@ void mlx_savez_helper(
 | 
			
		||||
  return;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
 | 
			
		||||
void mlx_save_safetensor_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    py::dict d,
 | 
			
		||||
    std::optional<py::dict> m) {
 | 
			
		||||
  std::unordered_map<std::string, std::string> metadata_map;
 | 
			
		||||
  if (m) {
 | 
			
		||||
    try {
 | 
			
		||||
      metadata_map =
 | 
			
		||||
          m.value().cast<std::unordered_map<std::string, std::string>>();
 | 
			
		||||
    } catch (const py::cast_error& e) {
 | 
			
		||||
      throw std::invalid_argument(
 | 
			
		||||
          "[save_safetensors] Metadata must be a dictionary with string keys and values");
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    metadata_map = std::unordered_map<std::string, std::string>();
 | 
			
		||||
  }
 | 
			
		||||
  auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
 | 
			
		||||
  if (py::isinstance<py::str>(file)) {
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_release nogil;
 | 
			
		||||
      save_safetensors(py::cast<std::string>(file), arrays_map);
 | 
			
		||||
      save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
 | 
			
		||||
    }
 | 
			
		||||
  } else if (is_ostream_object(file)) {
 | 
			
		||||
    auto writer = std::make_shared<PyFileWriter>(file);
 | 
			
		||||
    {
 | 
			
		||||
      py::gil_scoped_release nogil;
 | 
			
		||||
      save_safetensors(writer, arrays_map);
 | 
			
		||||
      save_safetensors(writer, arrays_map, metadata_map);
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    throw std::invalid_argument(
 | 
			
		||||
@@ -471,7 +488,7 @@ void mlx_save_gguf_helper(
 | 
			
		||||
  if (py::isinstance<py::str>(file)) {
 | 
			
		||||
    if (m) {
 | 
			
		||||
      auto metadata_map =
 | 
			
		||||
          m.value().cast<std::unordered_map<std::string, MetaData>>();
 | 
			
		||||
          m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
 | 
			
		||||
      {
 | 
			
		||||
        py::gil_scoped_release nogil;
 | 
			
		||||
        save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
 | 
			
		||||
 
 | 
			
		||||
@@ -15,19 +15,17 @@ using namespace mlx::core;
 | 
			
		||||
using LoadOutputTypes = std::variant<
 | 
			
		||||
    array,
 | 
			
		||||
    std::unordered_map<std::string, array>,
 | 
			
		||||
    std::pair<
 | 
			
		||||
        std::unordered_map<std::string, array>,
 | 
			
		||||
        std::unordered_map<std::string, MetaData>>>;
 | 
			
		||||
    SafetensorsLoad,
 | 
			
		||||
    GGUFLoad>;
 | 
			
		||||
 | 
			
		||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
			
		||||
SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s);
 | 
			
		||||
void mlx_save_safetensor_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    StreamOrDevice s);
 | 
			
		||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
 | 
			
		||||
    py::dict d,
 | 
			
		||||
    std::optional<py::dict> m);
 | 
			
		||||
 | 
			
		||||
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s);
 | 
			
		||||
 | 
			
		||||
std::pair<
 | 
			
		||||
    std::unordered_map<std::string, array>,
 | 
			
		||||
    std::unordered_map<std::string, MetaData>>
 | 
			
		||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s);
 | 
			
		||||
void mlx_save_gguf_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    py::dict d,
 | 
			
		||||
 
 | 
			
		||||
@@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) {
 | 
			
		||||
      &mlx_save_safetensor_helper,
 | 
			
		||||
      "file"_a,
 | 
			
		||||
      "arrays"_a,
 | 
			
		||||
      "metadata"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        save_safetensors(file: str, arrays: Dict[str, array])
 | 
			
		||||
        save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)
 | 
			
		||||
 | 
			
		||||
        Save array(s) to a binary file in ``.safetensors`` format.
 | 
			
		||||
 | 
			
		||||
@@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) {
 | 
			
		||||
        Args:
 | 
			
		||||
            file (file, str): File in which the array is saved.
 | 
			
		||||
            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
 | 
			
		||||
            metadata (dict(str, str), optional): The dictionary of metadata to be saved.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "save_gguf",
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user