GGUF support (#350)

* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Juarez Bochi
2024-01-10 16:22:48 -05:00
committed by GitHub
parent e3e933c6bc
commit b7f905787e
12 changed files with 362 additions and 55 deletions

View File

@@ -181,6 +181,16 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
"[load_safetensors] Input must be a file-like object, or string");
}
// Python-binding helper for loading a GGUF file.
//
// file: must be a Python `str`; it is treated as the path of a .gguf file
//       and forwarded to load_gguf together with `s`.
// s:    stream or device handed through to load_gguf unchanged.
//
// Returns whatever load_gguf produces (a map from names to arrays).
// Throws std::invalid_argument when `file` is not a Python string.
std::unordered_map<std::string, array> mlx_load_gguf_helper(
    py::object file,
    StreamOrDevice s) {
  // Reject anything that is not a path string up front.
  const bool is_path_string = py::isinstance<py::str>(file);
  if (!is_path_string) {
    throw std::invalid_argument("[load_gguf] Input must be a string");
  }

  return load_gguf(py::cast<std::string>(file), s);
}
std::unordered_map<std::string, array> mlx_load_npz_helper(
py::object file,
StreamOrDevice s) {
@@ -264,6 +274,8 @@ DictOrArray mlx_load_helper(
return mlx_load_npz_helper(file, s);
} else if (format.value() == "npy") {
return mlx_load_npy_helper(file, s);
} else if (format.value() == "gguf") {
return mlx_load_gguf_helper(file, s);
} else {
throw std::invalid_argument("[load] Unknown file format " + format.value());
}
@@ -435,3 +447,13 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
throw std::invalid_argument(
"[save_safetensors] Input must be a file-like object, or string");
}
// Python-binding helper for saving arrays to a GGUF file.
//
// file: must be a Python `str`; it is used as the destination path passed
//       to save_gguf.
// d:    Python dict, converted to std::unordered_map<std::string, array>
//       before the path check.
//
// Throws std::invalid_argument when `file` is not a Python string.
void mlx_save_gguf_helper(py::object file, py::dict d) {
  auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
  if (py::isinstance<py::str>(file)) {
    save_gguf(py::cast<std::string>(file), arrays_map);
    return;
  }
  // Tag the error with this function's own name so a failed GGUF save is
  // not reported as a safetensors failure.
  throw std::invalid_argument("[save_gguf] Input must be a string");
}

View File

@@ -19,6 +19,11 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
StreamOrDevice s);
void mlx_save_safetensor_helper(py::object file, py::dict d);
// Loads a GGUF file via load_gguf and returns its name -> array map.
// `file` must be a Python str (the file path); any other object type makes
// the helper throw std::invalid_argument. `s` is forwarded to load_gguf.
std::unordered_map<std::string, array> mlx_load_gguf_helper(
py::object file,
StreamOrDevice s);
// Saves the arrays in `d` (a dict of name -> array) to a GGUF file via
// save_gguf. `file` must be a Python str (the destination path); any other
// object type makes the helper throw std::invalid_argument.
void mlx_save_gguf_helper(py::object file, py::dict d);
DictOrArray mlx_load_helper(
py::object file,
std::optional<std::string> format,

View File

@@ -3048,7 +3048,9 @@ void init_ops(py::module_& m) {
R"pbdoc(
load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
Load array(s) from a binary file.
The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and ``.gguf``.
Args:
file (file, str): File in which the array is saved.
@@ -3059,6 +3061,12 @@ void init_ops(py::module_& m) {
result (array, dict):
A single array if loading from a ``.npy`` file or a dict mapping
names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
Warning:
When loading unsupported quantization formats from GGUF, tensors will
automatically cast to ``mx.float16``
)pbdoc");
m.def(
"save_safetensors",
@@ -3070,10 +3078,28 @@ void init_ops(py::module_& m) {
Save array(s) to a binary file in ``.safetensors`` format.
For more information on the format see https://huggingface.co/docs/safetensors/index.
See the `Safetensors documentation <https://huggingface.co/docs/safetensors/index>`_
for more information on the format.
Args:
file (file, str): File in which the array is saved>
file (file, str): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
)pbdoc");
m.def(
"save_gguf",
&mlx_save_gguf_helper,
"file"_a,
"arrays"_a,
R"pbdoc(
save_gguf(file: str, arrays: Dict[str, array])
Save array(s) to a binary file in ``.gguf`` format.
See the `GGUF documentation <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for
more information on the format.
Args:
file (file, str): File in which the array is saved.
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
)pbdoc");
m.def(
@@ -3306,7 +3332,7 @@ void init_ops(py::module_& m) {
``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. (default: 2)
Returns:
result (array): The tensor dot product.
)pbdoc");