mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 12:09:43 +08:00
GGUF: Load and save metadata (#446)
* gguf metadata --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -181,9 +181,10 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
"[load_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_gguf_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s) {
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(py::cast<std::string>(file), s);
|
||||
}
|
||||
@@ -246,9 +247,10 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||
"[load_npy] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
DictOrArray mlx_load_helper(
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s) {
|
||||
if (!format.has_value()) {
|
||||
std::string fname;
|
||||
@@ -268,6 +270,10 @@ DictOrArray mlx_load_helper(
|
||||
format.emplace(fname.substr(ext + 1));
|
||||
}
|
||||
|
||||
if (return_metadata && format.value() != "gguf") {
|
||||
throw std::invalid_argument(
|
||||
"[load] metadata not supported for format " + format.value());
|
||||
}
|
||||
if (format.value() == "safetensors") {
|
||||
return mlx_load_safetensor_helper(file, s);
|
||||
} else if (format.value() == "npz") {
|
||||
@@ -275,7 +281,12 @@ DictOrArray mlx_load_helper(
|
||||
} else if (format.value() == "npy") {
|
||||
return mlx_load_npy_helper(file, s);
|
||||
} else if (format.value() == "gguf") {
|
||||
return mlx_load_gguf_helper(file, s);
|
||||
auto [weights, metadata] = mlx_load_gguf_helper(file, s);
|
||||
if (return_metadata) {
|
||||
return std::make_pair(weights, metadata);
|
||||
} else {
|
||||
return weights;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("[load] Unknown file format " + format.value());
|
||||
}
|
||||
@@ -448,10 +459,19 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
"[save_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
void mlx_save_gguf_helper(py::object file, py::dict d) {
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
py::dict a,
|
||||
std::optional<py::dict> m) {
|
||||
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, MetaData>>();
|
||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
} else {
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user