GGUF support (#350)

* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Juarez Bochi
2024-01-10 16:22:48 -05:00
committed by GitHub
parent e3e933c6bc
commit b7f905787e
12 changed files with 362 additions and 55 deletions

View File

@@ -181,6 +181,16 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
"[load_safetensors] Input must be a file-like object, or string");
}
// Python-binding helper that loads arrays from a GGUF file.
//
// `file` must be a Python `str`; it is treated as the path of a .gguf file
// and forwarded, together with `s`, to `load_gguf`, whose result is returned
// unchanged. Any other Python object is rejected with std::invalid_argument.
std::unordered_map<std::string, array> mlx_load_gguf_helper(
    py::object file,
    StreamOrDevice s) {
  // Reject anything that is not a path string up front.
  const bool is_path_string = py::isinstance<py::str>(file);
  if (!is_path_string) {
    throw std::invalid_argument("[load_gguf] Input must be a string");
  }
  return load_gguf(py::cast<std::string>(file), s);
}
std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file,
StreamOrDevice s) {
@@ -264,6 +274,8 @@ DictOrArray mlx_load_helper(
return mlx_load_npz_helper(file, s);
} else if (format.value() == "npy") {
return mlx_load_npy_helper(file, s);
} else if (format.value() == "gguf") {
return mlx_load_gguf_helper(file, s);
} else {
throw std::invalid_argument("[load] Unknown file format " + format.value());
}
@@ -435,3 +447,13 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
throw std::invalid_argument(
"[save_safetensors] Input must be a file-like object, or string");
}
// Python-binding helper that saves a dict of arrays to a GGUF file.
//
// `file` must be a Python `str`; it is passed to `save_gguf` as the
// destination path. `d` is converted to a map from string keys to arrays
// before the path check, so a dict that cannot be converted fails in the cast
// first. A non-string `file` is rejected with std::invalid_argument.
void mlx_save_gguf_helper(py::object file, py::dict d) {
  auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
  if (py::isinstance<py::str>(file)) {
    save_gguf(py::cast<std::string>(file), arrays_map);
    return;
  }
  // The prefix names this GGUF saver so the error points at the right API
  // (it previously carried the safetensors prefix).
  throw std::invalid_argument("[save_gguf] Input must be a string");
}