mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:38:07 +08:00
docs and made retain_graph optional bool
This commit is contained in:
parent
fa093967ec
commit
ee6ce00aee
@ -127,7 +127,7 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
void save_safetensor(
|
void save_safetensor(
|
||||||
std::shared_ptr<io::Writer> out_stream,
|
std::shared_ptr<io::Writer> out_stream,
|
||||||
std::unordered_map<std::string, array> a,
|
std::unordered_map<std::string, array> a,
|
||||||
bool retain_graph) {
|
std::optional<bool> retain_graph_) {
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check file
|
// Check file
|
||||||
if (!out_stream->good() || !out_stream->is_open()) {
|
if (!out_stream->good() || !out_stream->is_open()) {
|
||||||
@ -143,7 +143,7 @@ void save_safetensor(
|
|||||||
});
|
});
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
for (auto& [key, arr] : a) {
|
for (auto& [key, arr] : a) {
|
||||||
arr.eval(retain_graph);
|
arr.eval(retain_graph_.value_or(arr.is_tracer()));
|
||||||
if (arr.nbytes() == 0) {
|
if (arr.nbytes() == 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[save_safetensor] cannot serialize an empty array key: " + key);
|
"[save_safetensor] cannot serialize an empty array key: " + key);
|
||||||
@ -174,7 +174,7 @@ void save_safetensor(
|
|||||||
void save_safetensor(
|
void save_safetensor(
|
||||||
const std::string& file_,
|
const std::string& file_,
|
||||||
std::unordered_map<std::string, array> a,
|
std::unordered_map<std::string, array> a,
|
||||||
bool retain_graph) {
|
std::optional<bool> retain_graph) {
|
||||||
// Open and check file
|
// Open and check file
|
||||||
std::string file = file_;
|
std::string file = file_;
|
||||||
|
|
||||||
|
@ -1068,9 +1068,9 @@ std::unordered_map<std::string, array> load_safetensor(
|
|||||||
void save_safetensor(
|
void save_safetensor(
|
||||||
std::shared_ptr<io::Writer> in_stream,
|
std::shared_ptr<io::Writer> in_stream,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>,
|
||||||
bool retain_graph = true);
|
std::optional<bool> retain_graph = std::nullopt);
|
||||||
void save_safetensor(
|
void save_safetensor(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
std::unordered_map<std::string, array>,
|
std::unordered_map<std::string, array>,
|
||||||
bool retain_graph = true);
|
std::optional<bool> retain_graph = std::nullopt);
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -424,7 +424,7 @@ void mlx_savez_helper(
|
|||||||
void mlx_save_safetensor_helper(
|
void mlx_save_safetensor_helper(
|
||||||
py::object file,
|
py::object file,
|
||||||
py::dict d,
|
py::dict d,
|
||||||
bool retain_graph) {
|
std::optional<bool> retain_graph) {
|
||||||
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
auto arrays_map = d.cast<std::unordered_map<std::string, array>>();
|
||||||
if (py::isinstance<py::str>(file)) {
|
if (py::isinstance<py::str>(file)) {
|
||||||
save_safetensor(py::cast<std::string>(file), arrays_map, retain_graph);
|
save_safetensor(py::cast<std::string>(file), arrays_map, retain_graph);
|
||||||
|
@ -2938,14 +2938,14 @@ void init_ops(py::module_& m) {
|
|||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
load(file: str, format: Optional[str] = None, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
load(file: str, format: Optional[str] = None, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]
|
||||||
|
|
||||||
Load array(s) from a binary file in ``.npy`` or ``.npz`` format.
|
Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved
|
file (file, str): File in which the array is saved
|
||||||
format (str, optional): Format of the file. If ``None``, the format
|
format (str, optional): Format of the file. If ``None``, the format
|
||||||
is inferred from the file extension. Supported formats: npy, npz, and safetensors. (default: ``None``)
|
is inferred from the file extension. Supported formats: ``npy``, ``npz``, and ``safetensors``. (default: ``None``)
|
||||||
Returns:
|
Returns:
|
||||||
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file
|
result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` or ``.safetensors`` file
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"save_safetensor",
|
"save_safetensor",
|
||||||
@ -2953,18 +2953,20 @@ void init_ops(py::module_& m) {
|
|||||||
"file"_a,
|
"file"_a,
|
||||||
"d"_a,
|
"d"_a,
|
||||||
py::pos_only(),
|
py::pos_only(),
|
||||||
"retain_graph"_a = true,
|
"retain_graph"_a = std::nullopt,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
save_safetensor(file: str, d: Dict[str, array], /, retain_graph: bool = True, *)
|
save_safetensor(file: str, d: Dict[str, array], /, retain_graph: Optional[bool] = None, *)
|
||||||
|
|
||||||
Save array(s) to a binary file in ``.safetensors`` format.
|
Save array(s) to a binary file in ``.safetensors`` format.
|
||||||
|
For more information on the format see https://huggingface.co/docs/safetensors/index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (file, str): File in which the array is saved
|
file (file, str): File in which the array is saved
|
||||||
d (Dict[str, array]): The dict mapping name to array to be saved
|
d (Dict[str, array]): The dict mapping name to array to be saved
|
||||||
retain_graph(bool): Optional argument to retain graph
|
retain_graph(Optional[bool]): Optional argument to retain graph
|
||||||
during array evaluation before saving. Default: True
|
during array evaluation before saving. If not provided the graph
|
||||||
|
is retained if we are during a function transformation. Default: None
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"where",
|
"where",
|
||||||
|
Loading…
Reference in New Issue
Block a user