GGUF: Load and save metadata (#446)

* GGUF metadata
---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Juarez Bochi
2024-01-19 23:06:05 +01:00
committed by GitHub
parent 6589c869d6
commit ddf50113c5
11 changed files with 668 additions and 94 deletions

View File

@@ -181,9 +181,10 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
"[load_safetensors] Input must be a file-like object, or string");
}
std::unordered_map<std::string, array> mlx_load_gguf_helper(
py::object file,
StreamOrDevice s) {
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
return load_gguf(py::cast<std::string>(file), s);
}
@@ -246,9 +247,10 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
"[load_npy] Input must be a file-like object, or string");
}
DictOrArray mlx_load_helper(
LoadOutputTypes mlx_load_helper(
py::object file,
std::optional<std::string> format,
bool return_metadata,
StreamOrDevice s) {
if (!format.has_value()) {
std::string fname;
@@ -268,6 +270,10 @@ DictOrArray mlx_load_helper(
format.emplace(fname.substr(ext + 1));
}
if (return_metadata && format.value() != "gguf") {
throw std::invalid_argument(
"[load] metadata not supported for format " + format.value());
}
if (format.value() == "safetensors") {
return mlx_load_safetensor_helper(file, s);
} else if (format.value() == "npz") {
@@ -275,7 +281,12 @@ DictOrArray mlx_load_helper(
} else if (format.value() == "npy") {
return mlx_load_npy_helper(file, s);
} else if (format.value() == "gguf") {
return mlx_load_gguf_helper(file, s);
auto [weights, metadata] = mlx_load_gguf_helper(file, s);
if (return_metadata) {
return std::make_pair(weights, metadata);
} else {
return weights;
}
} else {
throw std::invalid_argument("[load] Unknown file format " + format.value());
}
@@ -448,10 +459,19 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
"[save_safetensors] Input must be a file-like object, or string");
}
void mlx_save_gguf_helper(py::object file, py::dict d) {
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
void mlx_save_gguf_helper(
py::object file,
py::dict a,
std::optional<py::dict> m) {
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
if (py::isinstance<py::str>(file)) {
save_gguf(py::cast<std::string>(file), arrays_map);
if (m) {
auto metadata_map =
m.value().cast<std::unordered_map<std::string, MetaData>>();
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
} else {
save_gguf(py::cast<std::string>(file), arrays_map);
}
return;
}