change name to safetensors

Awni Hannun 2023-12-22 21:06:49 -08:00
parent 227ef82784
commit 313f6bd9b1
7 changed files with 50 additions and 47 deletions

View File

@@ -83,6 +83,7 @@ Operations
    save
    savez
    savez_compressed
+   save_safetensors
    sigmoid
    sign
    sin

View File

@@ -69,21 +69,21 @@ Dtype dtype_from_safetensor_str(std::string str) {
 }
 
 /** Load array from reader in safetensor format */
-std::unordered_map<std::string, array> load_safetensor(
+std::unordered_map<std::string, array> load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
     StreamOrDevice s) {
   ////////////////////////////////////////////////////////
   // Open and check file
   if (!in_stream->good() || !in_stream->is_open()) {
     throw std::runtime_error(
-        "[load_safetensor] Failed to open " + in_stream->label());
+        "[load_safetensors] Failed to open " + in_stream->label());
   }
 
   uint64_t jsonHeaderLength = 0;
   in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
   if (jsonHeaderLength <= 0) {
     throw std::runtime_error(
-        "[load_safetensor] Invalid json header length " + in_stream->label());
+        "[load_safetensors] Invalid json header length " + in_stream->label());
   }
   // Load the json metadata
   char rawJson[jsonHeaderLength];
@@ -92,7 +92,7 @@ std::unordered_map<std::string, array> load_safetensor(
   // Should always be an object on the top-level
   if (!metadata.is_object()) {
     throw std::runtime_error(
-        "[load_safetensor] Invalid json metadata " + in_stream->label());
+        "[load_safetensors] Invalid json metadata " + in_stream->label());
   }
   size_t offset = jsonHeaderLength + 8;
   // Load the arrays using metadata
@@ -117,14 +117,14 @@ std::unordered_map<std::string, array> load_safetensor(
   return res;
 }
 
-std::unordered_map<std::string, array> load_safetensor(
+std::unordered_map<std::string, array> load_safetensors(
     const std::string& file,
     StreamOrDevice s) {
-  return load_safetensor(std::make_shared<io::FileReader>(file), s);
+  return load_safetensors(std::make_shared<io::FileReader>(file), s);
 }
 
 /** Save array to out stream in .npy format */
-void save_safetensor(
+void save_safetensors(
     std::shared_ptr<io::Writer> out_stream,
     std::unordered_map<std::string, array> a,
     std::optional<bool> retain_graph_) {
@@ -132,7 +132,7 @@ void save_safetensor(
   // Check file
   if (!out_stream->good() || !out_stream->is_open()) {
     throw std::runtime_error(
-        "[save_safetensor] Failed to open " + out_stream->label());
+        "[save_safetensors] Failed to open " + out_stream->label());
   }
 
   ////////////////////////////////////////////////////////
@@ -146,12 +146,12 @@ void save_safetensor(
     arr.eval(retain_graph_.value_or(arr.is_tracer()));
     if (arr.nbytes() == 0) {
       throw std::invalid_argument(
-          "[save_safetensor] cannot serialize an empty array key: " + key);
+          "[save_safetensors] cannot serialize an empty array key: " + key);
     }
     if (!arr.flags().contiguous) {
       throw std::invalid_argument(
-          "[save_safetensor] cannot serialize a non-contiguous array key: " +
+          "[save_safetensors] cannot serialize a non-contiguous array key: " +
           key);
     }
     json child;
@@ -171,7 +171,7 @@ void save_safetensor(
   }
 }
 
-void save_safetensor(
+void save_safetensors(
     const std::string& file_,
     std::unordered_map<std::string, array> a,
     std::optional<bool> retain_graph) {
@@ -184,7 +184,7 @@ void save_safetensor(
   file += ".safetensors";
 
   // Serialize array
-  save_safetensor(std::make_shared<io::FileWriter>(file), a, retain_graph);
+  save_safetensors(std::make_shared<io::FileWriter>(file), a, retain_graph);
 }
 
 } // namespace mlx::core
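
The loader above reads the on-disk safetensors layout directly: the first 8 bytes hold a little-endian 64-bit length for the JSON header, the JSON metadata follows, and raw tensor bytes start at byte offset header length + 8. As a minimal sketch (not part of MLX; the helper name is illustrative), the same header can be inspected from Python:

    import json
    import struct

    def read_safetensors_header(path):
        # First 8 bytes: little-endian uint64 giving the JSON header length.
        with open(path, "rb") as f:
            (header_len,) = struct.unpack("<Q", f.read(8))
            # The next header_len bytes hold the JSON metadata; tensor data
            # begins at byte offset header_len + 8.
            return json.loads(f.read(header_len))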

View File

@@ -1057,19 +1057,19 @@ array dequantize(
     int bits = 4,
     StreamOrDevice s = {});
 
-/** Load array map from .safetensor file format */
-std::unordered_map<std::string, array> load_safetensor(
+/** Load array map from .safetensors file format */
+std::unordered_map<std::string, array> load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
     StreamOrDevice s = {});
-std::unordered_map<std::string, array> load_safetensor(
+std::unordered_map<std::string, array> load_safetensors(
     const std::string& file,
     StreamOrDevice s = {});
 
-void save_safetensor(
+void save_safetensors(
     std::shared_ptr<io::Writer> in_stream,
     std::unordered_map<std::string, array>,
     std::optional<bool> retain_graph = std::nullopt);
-void save_safetensor(
+void save_safetensors(
     const std::string& file,
     std::unordered_map<std::string, array>,
     std::optional<bool> retain_graph = std::nullopt);

View File

@@ -164,10 +164,10 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
     py::object file,
     StreamOrDevice s) {
   if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
-    return {load_safetensor(py::cast<std::string>(file), s)};
+    return {load_safetensors(py::cast<std::string>(file), s)};
   } else if (is_istream_object(file)) {
     // If we don't own the stream and it was passed to us, eval immediately
-    auto arr = load_safetensor(std::make_shared<PyFileReader>(file), s);
+    auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
     {
       py::gil_scoped_release gil;
       for (auto& [key, arr] : arr) {
@@ -178,7 +178,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
   }
 
   throw std::invalid_argument(
-      "[load_safetensor] Input must be a file-like object, or string");
+      "[load_safetensors] Input must be a file-like object, or string");
 }
 
 std::unordered_map<std::string, array> mlx_load_npz_helper(
@@ -427,18 +427,18 @@ void mlx_save_safetensor_helper(
     std::optional<bool> retain_graph) {
   auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
   if (py::isinstance<py::str>(file)) {
-    save_safetensor(py::cast<std::string>(file), arrays_map, retain_graph);
+    save_safetensors(py::cast<std::string>(file), arrays_map, retain_graph);
     return;
   } else if (is_ostream_object(file)) {
     auto writer = std::make_shared<PyFileWriter>(file);
     {
       py::gil_scoped_release gil;
-      save_safetensor(writer, arrays_map, retain_graph);
+      save_safetensors(writer, arrays_map, retain_graph);
     }
     return;
   }
 
   throw std::invalid_argument(
-      "[save_safetensor] Input must be a file-like object, or string");
+      "[save_safetensors] Input must be a file-like object, or string");
 }
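
The helpers above accept either a path string or an open Python file object for both saving and loading. A short sketch of the stream-based path through the renamed API (file name is illustrative):

    import mlx.core as mx

    arrays = {"w": mx.ones((2, 2)), "b": mx.zeros((2,))}

    # Writing through an open binary stream takes the PyFileWriter branch.
    with open("params.safetensors", "wb") as f:
        mx.save_safetensors(f, arrays)

    # Reading through an open stream takes the PyFileReader branch and
    # returns a dict mapping names to arrays.
    with open("params.safetensors", "rb") as f:
        loaded = mx.load(f)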

View File

@@ -2867,11 +2867,9 @@ void init_ops(py::module_& m) {
         Args:
             file (str): File to which the array is saved
             arr (array): Array to be saved.
-            retain_graph (bool, optional): Optional argument to retain graph
-              during array evaluation before saving. If not provided the graph
-              is retained if we are during a function transformation. Default:
-              None
+            retain_graph (bool, optional): Whether or not to retain the graph
+              during array evaluation. If left unspecified the graph is retained
+              only if saving is done in a function transformation. Default: ``None``
       )pbdoc");
   m.def(
       "savez",
@@ -2941,32 +2939,36 @@ void init_ops(py::module_& m) {
         Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
 
         Args:
-            file (file, str): File in which the array is saved
+            file (file, str): File in which the array is saved.
             format (str, optional): Format of the file. If ``None``, the format
-              is inferred from the file extension. Supported formats: ``npy``, ``npz``, and ``safetensors``. (default: ``None``)
+              is inferred from the file extension. Supported formats: ``npy``,
+              ``npz``, and ``safetensors``. Default: ``None``.
 
         Returns:
-            result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` or ``.safetensors`` file
+            result (array, dict):
+              A single array if loading from a ``.npy`` file or a dict mapping
+              names to arrays if loading from a ``.npz`` or ``.safetensors`` file.
       )pbdoc");
   m.def(
-      "save_safetensor",
+      "save_safetensors",
       &mlx_save_safetensor_helper,
       "file"_a,
-      "d"_a,
+      "arrays"_a,
       py::pos_only(),
       "retain_graph"_a = std::nullopt,
       py::kw_only(),
       R"pbdoc(
-        save_safetensor(file: str, d: Dict[str, array], /, retain_graph: Optional[bool] = None, *)
+        save_safetensors(file: str, arrays: Dict[str, array], /, retain_graph: Optional[bool] = None)
 
         Save array(s) to a binary file in ``.safetensors`` format.
 
         For more information on the format see https://huggingface.co/docs/safetensors/index.
 
         Args:
-            file (file, str): File in which the array is saved
-            d (Dict[str, array]): The dict mapping name to array to be saved
-            retain_graph(Optional[bool]): Optional argument to retain graph
-              during array evaluation before saving. If not provided the graph
-              is retained if we are during a function transformation. Default: None
+            file (file, str): File in which the array is saved.
+            arrays (dict(str, array)): The dictionary of names to arrays to be saved.
+            retain_graph (bool, optional): Whether or not to retain the graph
+              during array evaluation. If left unspecified the graph is retained
+              only if saving is done in a function transformation. Default: ``None``.
       )pbdoc");
   m.def(
       "where",

View File

@@ -64,7 +64,7 @@ class TestLoad(mlx_tests.MLXTestCase):
         load_arr_mlx_npy = np.load(save_file_mlx)
         self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
 
-    def test_save_and_load_safetensor(self):
+    def test_save_and_load_safetensors(self):
         if not os.path.isdir(self.test_dir):
             os.mkdir(self.test_dir)
@@ -82,7 +82,7 @@ class TestLoad(mlx_tests.MLXTestCase):
         }
         with open(save_file_mlx, "wb") as f:
-            mx.save_safetensor(f, save_dict)
+            mx.save_safetensors(f, save_dict)
 
         with open(save_file_mlx, "rb") as f:
             load_dict = mx.load(f)

View File

@@ -14,13 +14,13 @@ std::string get_temp_file(const std::string& name) {
   return std::filesystem::temp_directory_path().append(name);
 }
 
-TEST_CASE("test save_safetensor") {
+TEST_CASE("test save_safetensors") {
   std::string file_path = get_temp_file("test_arr.safetensors");
   auto map = std::unordered_map<std::string, array>();
   map.insert({"test", array({1.0, 2.0, 3.0, 4.0})});
   map.insert({"test2", ones({2, 2})});
-  save_safetensor(file_path, map);
-  auto safeDict = load_safetensor(file_path);
+  save_safetensors(file_path, map);
+  auto safeDict = load_safetensors(file_path);
   CHECK_EQ(safeDict.size(), 2);
   CHECK_EQ(safeDict.count("test"), 1);
   CHECK_EQ(safeDict.count("test2"), 1);