diff --git a/python/src/load.cpp b/python/src/load.cpp index 1a52930b2..82b018189 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -161,6 +161,27 @@ class PyFileReader : public io::Reader { py::object tell_func_; }; +std::unordered_map mlx_load_safetensor_helper( + py::object file, + StreamOrDevice s) { + if (py::isinstance(file)) { // Assume .safetensors file path string + return {load_safetensor(py::cast(file), s)}; + } else if (is_istream_object(file)) { + // If we don't own the stream and it was passed to us, eval immediately + auto arr = load_safetensor(std::make_shared(file), s); + { + py::gil_scoped_release gil; + for (auto& [key, arr] : arr) { + arr.eval(); + } + } + return {arr}; + } + + throw std::invalid_argument( + "[load] Input must be a file-like object, string, or pathlib.Path"); +} + DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { py::module_ zipfile = py::module_::import("zipfile"); diff --git a/python/src/load.h b/python/src/load.h index 8f64a64d1..20f0c79dd 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -12,6 +12,9 @@ using namespace mlx::core; using DictOrArray = std::variant>; +std::unordered_map mlx_load_safetensor_helper( + py::object file, + StreamOrDevice s); DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); void mlx_save_helper( py::object file, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 23a6ec2c6..a49b82317 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2945,6 +2945,24 @@ void init_ops(py::module_& m) { Returns: result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file )pbdoc"); + m.def( + "load_safetensor", + &mlx_load_safetensor_helper, + "file"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + load_safetensor(file: str, /, *, stream: Union[None, Stream, Device] = None) -> Dict[str, array] + + Load array(s) from a binary file in ``.safetensors`` format. + + Args: + file (file, str): File in which the array is saved + + Returns: + result dict: The loaded dict mapping name to array from the ``.safetensors`` file + )pbdoc"); m.def( "where", [](const ScalarOrArray& condition,