mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
change name to safetensors
This commit is contained in:
parent
227ef82784
commit
313f6bd9b1
@ -83,6 +83,7 @@ Operations
|
|||||||
save
|
save
|
||||||
savez
|
savez
|
||||||
savez_compressed
|
savez_compressed
|
||||||
|
save_safetensors
|
||||||
sigmoid
|
sigmoid
|
||||||
sign
|
sign
|
||||||
sin
|
sin
|
||||||
|
@ -69,21 +69,21 @@ Dtype dtype_from_safetensor_str(std::string str) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Load array from reader in safetensor format */
|
/** Load array from reader in safetensor format */
|
||||||
std::unordered_map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensors(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Open and check file
|
// Open and check file
|
||||||
if (!in_stream->good() || !in_stream->is_open()) {
|
if (!in_stream->good() || !in_stream->is_open()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load_safetensor] Failed to open " + in_stream->label());
|
"[load_safetensors] Failed to open " + in_stream->label());
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t jsonHeaderLength = 0;
|
uint64_t jsonHeaderLength = 0;
|
||||||
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
|
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
|
||||||
if (jsonHeaderLength <= 0) {
|
if (jsonHeaderLength <= 0) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load_safetensor] Invalid json header length " + in_stream->label());
|
"[load_safetensors] Invalid json header length " + in_stream->label());
|
||||||
}
|
}
|
||||||
// Load the json metadata
|
// Load the json metadata
|
||||||
char rawJson[jsonHeaderLength];
|
char rawJson[jsonHeaderLength];
|
||||||
@ -92,7 +92,7 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
// Should always be an object on the top-level
|
// Should always be an object on the top-level
|
||||||
if (!metadata.is_object()) {
|
if (!metadata.is_object()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[load_safetensor] Invalid json metadata " + in_stream->label());
|
"[load_safetensors] Invalid json metadata " + in_stream->label());
|
||||||
}
|
}
|
||||||
size_t offset = jsonHeaderLength + 8;
|
size_t offset = jsonHeaderLength + 8;
|
||||||
// Load the arrays using metadata
|
// Load the arrays using metadata
|
||||||
@ -117,14 +117,14 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensors(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
return load_safetensor(std::make_shared<io::FileReader>(file), s);
|
return load_safetensors(std::make_shared<io::FileReader>(file), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Save array to out stream in .npy format */
|
/** Save array to out stream in .npy format */
|
||||||
void save_safetensor(
|
void save_safetensors(
|
||||||
std::shared_ptr<io::Writer> out_stream,
|
std::shared_ptr<io::Writer> out_stream,
|
||||||
std::unordered_map<std::string, array> a,
|
std::unordered_map<std::string, array> a,
|
||||||
std::optional<bool> retain_graph_) {
|
std::optional<bool> retain_graph_) {
|
||||||
@ -132,7 +132,7 @@ void save_safetensor(
|
|||||||
// Check file
|
// Check file
|
||||||
if (!out_stream->good() || !out_stream->is_open()) {
|
if (!out_stream->good() || !out_stream->is_open()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[save_safetensor] Failed to open " + out_stream->label());
|
"[save_safetensors] Failed to open " + out_stream->label());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
@ -146,12 +146,12 @@ void save_safetensor(
|
|||||||
arr.eval(retain_graph_.value_or(arr.is_tracer()));
|
arr.eval(retain_graph_.value_or(arr.is_tracer()));
|
||||||
if (arr.nbytes() == 0) {
|
if (arr.nbytes() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save_safetensor] cannot serialize an empty array key: " + key);
|
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!arr.flags().contiguous) {
|
if (!arr.flags().contiguous) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save_safetensor] cannot serialize a non-contiguous array key: " +
|
"[save_safetensors] cannot serialize a non-contiguous array key: " +
|
||||||
key);
|
key);
|
||||||
}
|
}
|
||||||
json child;
|
json child;
|
||||||
@ -171,7 +171,7 @@ void save_safetensor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void save_safetensor(
|
void save_safetensors(
|
||||||
const std::string& file_,
|
const std::string& file_,
|
||||||
std::unordered_map<std::string, array> a,
|
std::unordered_map<std::string, array> a,
|
||||||
std::optional<bool> retain_graph) {
|
std::optional<bool> retain_graph) {
|
||||||
@ -184,7 +184,7 @@ void save_safetensor(
|
|||||||
file += ".safetensors";
|
file += ".safetensors";
|
||||||
|
|
||||||
// Serialize array
|
// Serialize array
|
||||||
save_safetensor(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
10
mlx/ops.h
10
mlx/ops.h
@ -1057,19 +1057,19 @@ array dequantize(
|
|||||||
int bits = 4,
|
int bits = 4,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Load array map from .safetensor file format */
|
/** Load array map from .safetensors file format */
|
||||||
std::unordered_map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensors(
|
||||||
std::shared_ptr<io::Reader> in_stream,
|
std::shared_ptr<io::Reader> in_stream,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
std::unordered_map<std::string, array> load_safetensor(
|
std::unordered_map<std::string, array> load_safetensors(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
void save_safetensor(
|
void save_safetensors(
|
||||||
std::shared_ptr<io::Writer> in_stream,
|
std::shared_ptr<io::Writer> in_stream,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>,
|
||||||
std::optional<bool> retain_graph = std::nullopt);
|
std::optional<bool> retain_graph = std::nullopt);
|
||||||
void save_safetensor(
|
void save_safetensors(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>,
|
||||||
std::optional<bool> retain_graph = std::nullopt);
|
std::optional<bool> retain_graph = std::nullopt);
|
||||||
|
@ -164,10 +164,10 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
|||||||
py::object file,
|
py::object file,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
|
||||||
return {load_safetensor(py::cast<std::string>(file), s)};
|
return {load_safetensors(py::cast<std::string>(file), s)};
|
||||||
} else if (is_istream_object(file)) {
|
} else if (is_istream_object(file)) {
|
||||||
// If we don't own the stream and it was passed to us, eval immediately
|
// If we don't own the stream and it was passed to us, eval immediately
|
||||||
auto arr = load_safetensor(std::make_shared<PyFileReader>(file), s);
|
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
for (auto& [key, arr] : arr) {
|
for (auto& [key, arr] : arr) {
|
||||||
@ -178,7 +178,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[load_safetensor] Input must be a file-like object, or string");
|
"[load_safetensors] Input must be a file-like object, or string");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
std::unordered_map<std::string, array> mlx_load_npz_helper(
|
||||||
@ -427,18 +427,18 @@ void mlx_save_safetensor_helper(
|
|||||||
std::optional<bool> retain_graph) {
|
std::optional<bool> retain_graph) {
|
||||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
save_safetensor(py::cast<std::string>(file), arrays_map, retain_graph);
|
save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
|
||||||
return;
|
return;
|
||||||
} else if (is_ostream_object(file)) {
|
} else if (is_ostream_object(file)) {
|
||||||
auto writer = std::make_shared<PyFileWriter>(file);
|
auto writer = std::make_shared<PyFileWriter>(file);
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil;
|
py::gil_scoped_release gil;
|
||||||
save_safetensor(writer, arrays_map, retain_graph);
|
save_safetensors(writer, arrays_map, retain_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save_safetensor] Input must be a file-like object, or string");
|
"[save_safetensors] Input must be a file-like object, or string");
|
||||||
}
|
}
|
@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
file (str): File to which the array is saved
|
file (str): File to which the array is saved
|
||||||
arr (array): Array to be saved.
|
arr (array): Array to be saved.
|
||||||
retain_graph (bool, optional): Optional argument to retain graph
|
retain_graph (bool, optional): Whether or not to retain the graph
|
||||||
during array evaluation before saving. If not provided the graph
|
during array evaluation. If left unspecified the graph is retained
|
||||||
is retained if we are during a function transformation. Default:
|
only if saving is done in a function transformation. Default: ``None``
|
||||||
None
|
|
||||||
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"savez",
|
"savez",
|
||||||
@ -2941,32 +2939,36 @@ void init_ops(py::module_& m) {
|
|||||||
Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
|
Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved
|
file (file, str): File in which the array is saved.
|
||||||
format (str, optional): Format of the file. If ``None``, the format
|
format (str, optional): Format of the file. If ``None``, the format
|
||||||
is inferred from the file extension. Supported formats: ``npy``, ``npz``, and ``safetensors``. (default: ``None``)
|
is inferred from the file extension. Supported formats: ``npy``,
|
||||||
|
``npz``, and ``safetensors``. Default: ``None``.
|
||||||
Returns:
|
Returns:
|
||||||
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` or ``.safetensors`` file
|
result (array, dict):
|
||||||
|
A single array if loading from a ``.npy`` file or a dict mapping
|
||||||
|
names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"save_safetensor",
|
"save_safetensors",
|
||||||
&mlx_save_safetensor_helper,
|
&mlx_save_safetensor_helper,
|
||||||
"file"_a,
|
"file"_a,
|
||||||
"d"_a,
|
"arrays"_a,
|
||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
"retain_graph"_a = std::nullopt,
|
"retain_graph"_a = std::nullopt,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
save_safetensor(file: str, d: Dict[str, array], /, retain_graph: Optional[bool] = None, *)
|
save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
|
||||||
|
|
||||||
Save array(s) to a binary file in ``.safetensors`` format.
|
Save array(s) to a binary file in ``.safetensors`` format.
|
||||||
|
|
||||||
For more information on the format see https://huggingface.co/docs/safetensors/index.
|
For more information on the format see https://huggingface.co/docs/safetensors/index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved
|
file (file, str): File in which the array is saved>
|
||||||
d (Dict[str, array]): The dict mapping name to array to be saved
|
arrays (dict(str, array)): The dictionary of names to arrays to be saved.
|
||||||
retain_graph(Optional[bool]): Optional argument to retain graph
|
retain_graph (bool, optional): Whether or not to retain the graph
|
||||||
during array evaluation before saving. If not provided the graph
|
during array evaluation. If left unspecified the graph is retained
|
||||||
is retained if we are during a function transformation. Default: None
|
only if saving is done in a function transformation. Default: ``None``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"where",
|
"where",
|
||||||
|
@ -64,7 +64,7 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
load_arr_mlx_npy = np.load(save_file_mlx)
|
load_arr_mlx_npy = np.load(save_file_mlx)
|
||||||
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
|
||||||
|
|
||||||
def test_save_and_load_safetensor(self):
|
def test_save_and_load_safetensors(self):
|
||||||
if not os.path.isdir(self.test_dir):
|
if not os.path.isdir(self.test_dir):
|
||||||
os.mkdir(self.test_dir)
|
os.mkdir(self.test_dir)
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
with open(save_file_mlx, "wb") as f:
|
with open(save_file_mlx, "wb") as f:
|
||||||
mx.save_safetensor(f, save_dict)
|
mx.save_safetensors(f, save_dict)
|
||||||
with open(save_file_mlx, "rb") as f:
|
with open(save_file_mlx, "rb") as f:
|
||||||
load_dict = mx.load(f)
|
load_dict = mx.load(f)
|
||||||
|
|
||||||
|
@ -14,13 +14,13 @@ std::string get_temp_file(const std::string& name) {
|
|||||||
return std::filesystem::temp_directory_path().append(name);
|
return std::filesystem::temp_directory_path().append(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test save_safetensor") {
|
TEST_CASE("test save_safetensors") {
|
||||||
std::string file_path = get_temp_file("test_arr.safetensors");
|
std::string file_path = get_temp_file("test_arr.safetensors");
|
||||||
auto map = std::unordered_map<std::string, array>();
|
auto map = std::unordered_map<std::string, array>();
|
||||||
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
|
||||||
map.insert({"test2", ones({2, 2})});
|
map.insert({"test2", ones({2, 2})});
|
||||||
save_safetensor(file_path, map);
|
save_safetensors(file_path, map);
|
||||||
auto safeDict = load_safetensor(file_path);
|
auto safeDict = load_safetensors(file_path);
|
||||||
CHECK_EQ(safeDict.size(), 2);
|
CHECK_EQ(safeDict.size(), 2);
|
||||||
CHECK_EQ(safeDict.count("test"), 1);
|
CHECK_EQ(safeDict.count("test"), 1);
|
||||||
CHECK_EQ(safeDict.count("test2"), 1);
|
CHECK_EQ(safeDict.count("test2"), 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user