mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
GGUF support (#350)
* Initial GGUF support for tensor fields.

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -181,6 +181,16 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
       "[load_safetensors] Input must be a file-like object, or string");
 }
 
+std::unordered_map<std::string, array> mlx_load_gguf_helper(
+    py::object file,
+    StreamOrDevice s) {
+  if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
+    return load_gguf(py::cast<std::string>(file), s);
+  }
+
+  throw std::invalid_argument("[load_gguf] Input must be a string");
+}
+
 std::unordered_map<std::string, array> mlx_load_npz_helper(
     py::object file,
     StreamOrDevice s) {
@@ -264,6 +274,8 @@ DictOrArray mlx_load_helper(
     return mlx_load_npz_helper(file, s);
   } else if (format.value() == "npy") {
     return mlx_load_npy_helper(file, s);
+  } else if (format.value() == "gguf") {
+    return mlx_load_gguf_helper(file, s);
   } else {
     throw std::invalid_argument("[load] Unknown file format " + format.value());
   }
@@ -435,3 +447,13 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
   throw std::invalid_argument(
       "[save_safetensors] Input must be a file-like object, or string");
 }
+
+void mlx_save_gguf_helper(py::object file, py::dict d) {
+  auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
+  if (py::isinstance<py::str>(file)) {
+    save_gguf(py::cast<std::string>(file), arrays_map);
+    return;
+  }
+
+  throw std::invalid_argument("[save_gguf] Input must be a string");
+}
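Both GGUF helpers above accept only a path string: unlike the safetensors helpers, a file-like object is rejected with std::invalid_argument, which pybind11 surfaces as a ValueError in Python. A minimal Python-side sketch of that behavior (file name and array contents are illustrative):

    import mlx.core as mx

    # Saving and loading go through a plain path string.
    mx.save_gguf("weights.gguf", {"w": mx.zeros((4, 4))})
    weights = mx.load("weights.gguf")

    # Passing a file object instead of a string is not supported for GGUF.
    try:
        with open("weights.gguf", "rb") as f:
            mx.load(f, format="gguf")
    except ValueError as e:
        print(e)  # "[load_gguf] Input must be a string"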
@@ -19,6 +19,11 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
     StreamOrDevice s);
 void mlx_save_safetensor_helper(py::object file, py::dict d);
 
+std::unordered_map<std::string, array> mlx_load_gguf_helper(
+    py::object file,
+    StreamOrDevice s);
+void mlx_save_gguf_helper(py::object file, py::dict d);
+
 DictOrArray mlx_load_helper(
     py::object file,
     std::optional<std::string> format,
@@ -3048,7 +3048,9 @@ void init_ops(py::module_& m) {
       R"pbdoc(
         load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
 
-        Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
+        Load array(s) from a binary file.
+
+        The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and ``.gguf``.
 
         Args:
             file (file, str): File in which the array is saved.
@@ -3059,6 +3061,12 @@ void init_ops(py::module_& m) {
           result (array, dict):
               A single array if loading from a ``.npy`` file or a dict mapping
               names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
+
+        Warning:
+
+          When loading unsupported quantization formats from GGUF, tensors will
+          be automatically cast to ``mx.float16``.
+
       )pbdoc");
   m.def(
       "save_safetensors",
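A short usage sketch for the extended load (the file name is illustrative; with format=None the format is picked from the file extension, or it can be requested explicitly):

    import mlx.core as mx

    weights = mx.load("model.gguf")                  # format inferred from ".gguf"
    weights = mx.load("model.gguf", format="gguf")   # or forced explicitly
    # weights is a dict mapping tensor names to mx.array values
    print({k: v.dtype for k, v in weights.items()})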
@@ -3070,10 +3078,28 @@ void init_ops(py::module_& m) {
 
       Save array(s) to a binary file in ``.safetensors`` format.
 
-      For more information on the format see https://huggingface.co/docs/safetensors/index.
+      See the `Safetensors documentation <https://huggingface.co/docs/safetensors/index>`_
+      for more information on the format.
 
       Args:
-          file (file, str): File in which the array is saved>
+          file (file, str): File in which the array is saved.
           arrays (dict(str, array)): The dictionary of names to arrays to be saved.
       )pbdoc");
+  m.def(
+      "save_gguf",
+      &mlx_save_gguf_helper,
+      "file"_a,
+      "arrays"_a,
+      R"pbdoc(
+        save_gguf(file: str, arrays: Dict[str, array])
+
+        Save array(s) to a binary file in ``.gguf`` format.
+
+        See the `GGUF documentation <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for
+        more information on the format.
+
+        Args:
+            file (file, str): File in which the array is saved.
+            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
+      )pbdoc");
   m.def(
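The new save_gguf binding mirrors save_safetensors; a round-trip sketch under the signature documented above (names and shapes are illustrative):

    import mlx.core as mx

    arrays = {
        "embedding": mx.random.normal(shape=(8, 16)),
        "bias": mx.zeros((16,)),
    }
    mx.save_gguf("checkpoint.gguf", arrays)

    restored = mx.load("checkpoint.gguf")
    assert mx.array_equal(restored["bias"], arrays["bias"])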
@@ -3306,7 +3332,7 @@ void init_ops(py::module_& m) {
           ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
           ``b``. If a list of lists is provided, then sum over the
           corresponding dimensions of ``a`` and ``b``. (default: 2)
-
+
       Returns:
           result (array): The tensor dot product.
       )pbdoc");
@@ -90,6 +90,33 @@ class TestLoad(mlx_tests.MLXTestCase):
                 mx.array_equal(load_dict["test"], save_dict["test"])
             )
 
+    def test_save_and_load_gguf(self):
+        if not os.path.isdir(self.test_dir):
+            os.mkdir(self.test_dir)
+
+        # TODO: Add support for other dtypes (self.dtypes + ["bfloat16"])
+        supported_dtypes = ["float16", "float32", "int8", "int16", "int32"]
+        for dt in supported_dtypes:
+            with self.subTest(dtype=dt):
+                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
+                    with self.subTest(shape=shape):
+                        save_file_mlx = os.path.join(
+                            self.test_dir, f"mlx_{dt}_{i}_fs.gguf"
+                        )
+                        save_dict = {
+                            "test": mx.random.normal(shape=shape, dtype=getattr(mx, dt))
+                            if dt in ["float32", "float16", "bfloat16"]
+                            else mx.ones(shape, dtype=getattr(mx, dt))
+                        }
+
+                        mx.save_gguf(save_file_mlx, save_dict)
+                        load_dict = mx.load(save_file_mlx)
+
+                        self.assertTrue("test" in load_dict)
+                        self.assertTrue(
+                            mx.array_equal(load_dict["test"], save_dict["test"])
+                        )
+
     def test_save_and_load_fs(self):
         if not os.path.isdir(self.test_dir):
             os.mkdir(self.test_dir)
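Per the TODO in the test above, bfloat16 is not yet in the supported dtype list, so a caller holding bfloat16 weights would need to cast before saving. A hedged workaround sketch (assuming float16 is an acceptable stand-in for the checkpoint):

    import mlx.core as mx

    w = mx.random.normal(shape=(4, 4), dtype=mx.bfloat16)
    # Cast down to a dtype the GGUF writer currently handles.
    mx.save_gguf("w.gguf", {"w": w.astype(mx.float16)})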
@@ -194,13 +221,24 @@ class TestLoad(mlx_tests.MLXTestCase):
         aload = mx.load(save_file)["a"]
         self.assertTrue(mx.array_equal(a, aload))
 
-        # safetensors only works with row contiguous
+        save_file = os.path.join(self.test_dir, "a.gguf")
+        mx.save_gguf(save_file, {"a": a})
+        aload = mx.load(save_file)["a"]
+        self.assertTrue(mx.array_equal(a, aload))
+
+        # safetensors and gguf only work with row contiguous
+        # make sure col contiguous is handled properly
         save_file = os.path.join(self.test_dir, "a.safetensors")
         a = mx.arange(4).reshape(2, 2).T
         mx.save_safetensors(save_file, {"a": a})
         aload = mx.load(save_file)["a"]
         self.assertTrue(mx.array_equal(a, aload))
+
+        save_file = os.path.join(self.test_dir, "a.gguf")
+        mx.save_gguf(save_file, {"a": a})
+        aload = mx.load(save_file)["a"]
+        self.assertTrue(mx.array_equal(a, aload))
 
 
 if __name__ == "__main__":
     unittest.main()