mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
GGUF: Load and save metadata (#446)
* gguf metadata --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -181,9 +181,10 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
"[load_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_gguf_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s) {
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(py::cast<std::string>(file), s);
|
||||
}
|
||||
@@ -246,9 +247,10 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
|
||||
"[load_npy] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
DictOrArray mlx_load_helper(
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s) {
|
||||
if (!format.has_value()) {
|
||||
std::string fname;
|
||||
@@ -268,6 +270,10 @@ DictOrArray mlx_load_helper(
|
||||
format.emplace(fname.substr(ext + 1));
|
||||
}
|
||||
|
||||
if (return_metadata && format.value() != "gguf") {
|
||||
throw std::invalid_argument(
|
||||
"[load] metadata not supported for format " + format.value());
|
||||
}
|
||||
if (format.value() == "safetensors") {
|
||||
return mlx_load_safetensor_helper(file, s);
|
||||
} else if (format.value() == "npz") {
|
||||
@@ -275,7 +281,12 @@ DictOrArray mlx_load_helper(
|
||||
} else if (format.value() == "npy") {
|
||||
return mlx_load_npy_helper(file, s);
|
||||
} else if (format.value() == "gguf") {
|
||||
return mlx_load_gguf_helper(file, s);
|
||||
auto [weights, metadata] = mlx_load_gguf_helper(file, s);
|
||||
if (return_metadata) {
|
||||
return std::make_pair(weights, metadata);
|
||||
} else {
|
||||
return weights;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("[load] Unknown file format " + format.value());
|
||||
}
|
||||
@@ -448,10 +459,19 @@ void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
"[save_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
void mlx_save_gguf_helper(py::object file, py::dict d) {
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
py::dict a,
|
||||
std::optional<py::dict> m) {
|
||||
auto arrays_map = a.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, MetaData>>();
|
||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
} else {
|
||||
save_gguf(py::cast<std::string>(file), arrays_map);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@@ -7,26 +7,36 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/io.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlx::core;
|
||||
|
||||
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
|
||||
using LoadOutputTypes = std::variant<
|
||||
array,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>>;
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_gguf_helper(
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s);
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_gguf_helper(py::object file, py::dict d);
|
||||
py::dict d,
|
||||
std::optional<py::dict> m);
|
||||
|
||||
DictOrArray mlx_load_helper(
|
||||
LoadOutputTypes mlx_load_helper(
|
||||
py::object file,
|
||||
std::optional<std::string> format,
|
||||
bool return_metadata,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_helper(py::object file, array a);
|
||||
void mlx_savez_helper(
|
||||
|
@@ -1867,11 +1867,11 @@ void init_ops(py::module_& m) {
|
||||
isposinf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return a boolean array indicating which elements are positive infinity.
|
||||
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are positive infinity.
|
||||
)pbdoc");
|
||||
@@ -1886,11 +1886,11 @@ void init_ops(py::module_& m) {
|
||||
isneginf(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Return a boolean array indicating which elements are negative infinity.
|
||||
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Union[None, Stream, Device]): Optional stream or device.
|
||||
|
||||
|
||||
Returns:
|
||||
array: The boolean array indicating which elements are negative infinity.
|
||||
)pbdoc");
|
||||
@@ -3117,10 +3117,11 @@ void init_ops(py::module_& m) {
|
||||
"file"_a,
|
||||
py::pos_only(),
|
||||
"format"_a = none,
|
||||
"return_metadata"_a = false,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||
load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||
|
||||
Load array(s) from a binary file.
|
||||
|
||||
@@ -3131,10 +3132,15 @@ void init_ops(py::module_& m) {
|
||||
format (str, optional): Format of the file. If ``None``, the format
|
||||
is inferred from the file extension. Supported formats: ``npy``,
|
||||
``npz``, and ``safetensors``. Default: ``None``.
|
||||
return_metadata (bool, optional): Load the metadata for formats which
|
||||
support matadata. The metadata will be returned as an additional
|
||||
dictionary.
|
||||
Returns:
|
||||
result (array, dict):
|
||||
A single array if loading from a ``.npy`` file or a dict mapping
|
||||
names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
|
||||
If ``return_metadata` is ``True`` an additional dictionary of metadata
|
||||
will be returned.
|
||||
|
||||
Warning:
|
||||
|
||||
@@ -3164,8 +3170,9 @@ void init_ops(py::module_& m) {
|
||||
&mlx_save_gguf_helper,
|
||||
"file"_a,
|
||||
"arrays"_a,
|
||||
"metadata"_a = none,
|
||||
R"pbdoc(
|
||||
save_gguf(file: str, arrays: Dict[str, array])
|
||||
save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]])
|
||||
|
||||
Save array(s) to a binary file in ``.gguf`` format.
|
||||
|
||||
@@ -3175,6 +3182,9 @@ void init_ops(py::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||
metadata (dict(str, Union[array, str, list(str)])): The dictionary of
|
||||
metadata to be saved. The values can be a scalar or 1D obj:`array`,
|
||||
a :obj:`str`, or a :obj:`list` of :obj:`str`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"where",
|
||||
@@ -3499,7 +3509,7 @@ void init_ops(py::module_& m) {
|
||||
c (array): Input array or scalar.
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
alpha (float, optional): Scaling factor for the
|
||||
alpha (float, optional): Scaling factor for the
|
||||
matrix product of ``a`` and ``b`` (default: ``1``)
|
||||
beta (float, optional): Scaling factor for ``c`` (default: ``1``)
|
||||
|
||||
|
@@ -576,8 +576,8 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
|
||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
|
||||
|
||||
for r, t in zip(dout_ref, dout_test):
|
||||
self.assertListEqual(r.shape, t.shape)
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-5).item())
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
@@ -117,6 +117,115 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
mx.array_equal(load_dict["test"], save_dict["test"])
|
||||
)
|
||||
|
||||
def test_save_and_load_gguf_metadata_basic(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
metadata = {}
|
||||
|
||||
# Empty works
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
|
||||
# Loads without the metadata
|
||||
load_dict = mx.load(save_file_mlx)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
|
||||
# Loads empty metadata
|
||||
load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
self.assertEqual(len(meta_load_dict), 0)
|
||||
|
||||
# Loads string metadata
|
||||
metadata = {"meta": "data"}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
load_dict, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertTrue("test" in load_dict)
|
||||
self.assertTrue(mx.array_equal(load_dict["test"], save_dict["test"]))
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertTrue("meta" in meta_load_dict)
|
||||
self.assertEqual(meta_load_dict["meta"], "data")
|
||||
|
||||
def test_save_and_load_gguf_metadata_arrays(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
|
||||
# Test scalars and one dimensional arrays
|
||||
for t in [
|
||||
mx.uint8,
|
||||
mx.int8,
|
||||
mx.uint16,
|
||||
mx.int16,
|
||||
mx.uint32,
|
||||
mx.int32,
|
||||
mx.uint64,
|
||||
mx.int64,
|
||||
mx.float32,
|
||||
]:
|
||||
for shape in [(), (2,)]:
|
||||
arr = mx.random.uniform(shape=shape).astype(t)
|
||||
metadata = {"meta": arr}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertTrue("meta" in meta_load_dict)
|
||||
self.assertTrue(mx.array_equal(meta_load_dict["meta"], arr))
|
||||
self.assertEqual(meta_load_dict["meta"].dtype, arr.dtype)
|
||||
|
||||
for t in [mx.float16, mx.bfloat16, mx.complex64]:
|
||||
with self.assertRaises(ValueError):
|
||||
arr = mx.array(1, t)
|
||||
metadata = {"meta": arr}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
|
||||
def test_save_and_load_gguf_metadata_mixed(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
||||
save_file_mlx = os.path.join(self.test_dir, f"mlx_gguf_with_metadata.gguf")
|
||||
save_dict = {"test": mx.ones((4, 4), dtype=mx.int32)}
|
||||
|
||||
# Test string and array
|
||||
arr = mx.array(1.5)
|
||||
metadata = {"meta1": arr, "meta2": "data"}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 2)
|
||||
self.assertTrue("meta1" in meta_load_dict)
|
||||
self.assertTrue(mx.array_equal(meta_load_dict["meta1"], arr))
|
||||
self.assertEqual(meta_load_dict["meta1"].dtype, arr.dtype)
|
||||
self.assertTrue("meta2" in meta_load_dict)
|
||||
self.assertEqual(meta_load_dict["meta2"], "data")
|
||||
|
||||
# Test list of strings
|
||||
metadata = {"meta": ["data1", "data2", "data345"]}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 1)
|
||||
self.assertEqual(meta_load_dict["meta"], metadata["meta"])
|
||||
|
||||
# Test a combination of stuff
|
||||
metadata = {
|
||||
"meta1": ["data1", "data2", "data345"],
|
||||
"meta2": mx.array([1, 2, 3, 4]),
|
||||
"meta3": "data",
|
||||
"meta4": mx.array(1.5),
|
||||
}
|
||||
mx.save_gguf(save_file_mlx, save_dict, metadata)
|
||||
_, meta_load_dict = mx.load(save_file_mlx, return_metadata=True)
|
||||
self.assertEqual(len(meta_load_dict), 4)
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, mx.array):
|
||||
self.assertTrue(mx.array_equal(meta_load_dict[k], v))
|
||||
else:
|
||||
self.assertEqual(meta_load_dict[k], v)
|
||||
|
||||
def test_save_and_load_fs(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
|
Reference in New Issue
Block a user