mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Metadata support for safetensors (#639)
* metadata support for safetensors * aliases making it alittle more readable * addressing comments * python binding tests
This commit is contained in:
@@ -160,31 +160,29 @@ class PyFileReader : public io::Reader {
|
||||
py::object tell_func_;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s) {
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(py::cast<std::string>(file), s);
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
{
|
||||
py::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : arr) {
|
||||
for (auto& [key, arr] : std::get<0>(res)) {
|
||||
arr.eval();
|
||||
}
|
||||
}
|
||||
return arr;
|
||||
return res;
|
||||
}
|
||||
|
||||
throw std::invalid_argument(
|
||||
"[load_safetensors] Input must be a file-like object, or string");
|
||||
}
|
||||
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s) {
|
||||
if (py::isinstance<py::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(py::cast<std::string>(file), s);
|
||||
}
|
||||
@@ -274,12 +272,16 @@ LoadOutputTypes mlx_load_helper(
|
||||
format.emplace(fname.substr(ext + 1));
|
||||
}
|
||||
|
||||
if (return_metadata && format.value() != "gguf") {
|
||||
if (return_metadata && (format.value() == "npy" || format.value() == "npz")) {
|
||||
throw std::invalid_argument(
|
||||
"[load] metadata not supported for format " + format.value());
|
||||
}
|
||||
if (format.value() == "safetensors") {
|
||||
return mlx_load_safetensor_helper(file, s);
|
||||
auto [dict, metadata] = mlx_load_safetensor_helper(file, s);
|
||||
if (return_metadata) {
|
||||
return std::make_pair(dict, metadata);
|
||||
}
|
||||
return dict;
|
||||
} else if (format.value() == "npz") {
|
||||
return mlx_load_npz_helper(file, s);
|
||||
} else if (format.value() == "npy") {
|
||||
@@ -444,18 +446,33 @@ void mlx_savez_helper(
|
||||
return;
|
||||
}
|
||||
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d) {
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
std::optional<py::dict> m) {
|
||||
std::unordered_map<std::string, std::string> metadata_map;
|
||||
if (m) {
|
||||
try {
|
||||
metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, std::string>>();
|
||||
} catch (const py::cast_error& e) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] Metadata must be a dictionary with string keys and values");
|
||||
}
|
||||
} else {
|
||||
metadata_map = std::unordered_map<std::string, std::string>();
|
||||
}
|
||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map);
|
||||
save_safetensors(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
}
|
||||
} else if (is_ostream_object(file)) {
|
||||
auto writer = std::make_shared<PyFileWriter>(file);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_safetensors(writer, arrays_map);
|
||||
save_safetensors(writer, arrays_map, metadata_map);
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -471,7 +488,7 @@ void mlx_save_gguf_helper(
|
||||
if (py::isinstance<py::str>(file)) {
|
||||
if (m) {
|
||||
auto metadata_map =
|
||||
m.value().cast<std::unordered_map<std::string, MetaData>>();
|
||||
m.value().cast<std::unordered_map<std::string, GGUFMetaData>>();
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
save_gguf(py::cast<std::string>(file), arrays_map, metadata_map);
|
||||
|
@@ -15,19 +15,17 @@ using namespace mlx::core;
|
||||
using LoadOutputTypes = std::variant<
|
||||
array,
|
||||
std::unordered_map<std::string, array>,
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>>;
|
||||
SafetensorsLoad,
|
||||
GGUFLoad>;
|
||||
|
||||
std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
||||
SafetensorsLoad mlx_load_safetensor_helper(py::object file, StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(
|
||||
py::object file,
|
||||
StreamOrDevice s);
|
||||
void mlx_save_safetensor_helper(py::object file, py::dict d);
|
||||
py::dict d,
|
||||
std::optional<py::dict> m);
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(py::object file, StreamOrDevice s);
|
||||
|
||||
std::pair<
|
||||
std::unordered_map<std::string, array>,
|
||||
std::unordered_map<std::string, MetaData>>
|
||||
mlx_load_gguf_helper(py::object file, StreamOrDevice s);
|
||||
void mlx_save_gguf_helper(
|
||||
py::object file,
|
||||
py::dict d,
|
||||
|
@@ -3214,8 +3214,9 @@ void init_ops(py::module_& m) {
|
||||
&mlx_save_safetensor_helper,
|
||||
"file"_a,
|
||||
"arrays"_a,
|
||||
"metadata"_a = none,
|
||||
R"pbdoc(
|
||||
save_safetensors(file: str, arrays: Dict[str, array])
|
||||
save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)
|
||||
|
||||
Save array(s) to a binary file in ``.safetensors`` format.
|
||||
|
||||
@@ -3225,6 +3226,7 @@ void init_ops(py::module_& m) {
|
||||
Args:
|
||||
file (file, str): File in which the array is saved.
|
||||
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||
metadata (dict(str, str), optional): The dictionary of metadata to be saved.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"save_gguf",
|
||||
|
@@ -66,6 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
def test_save_and_load_safetensors(self):
|
||||
if not os.path.isdir(self.test_dir):
|
||||
os.mkdir(self.test_dir)
|
||||
with self.assertRaises(Exception):
|
||||
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
|
||||
|
||||
mx.save_safetensors(
|
||||
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
|
||||
)
|
||||
res = mx.load("test.safetensors", return_metadata=True)
|
||||
self.assertEqual(len(res), 2)
|
||||
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
|
||||
|
||||
for dt in self.dtypes + ["bfloat16"]:
|
||||
with self.subTest(dtype=dt):
|
||||
|
Reference in New Issue
Block a user