diff --git a/python/src/load.cpp b/python/src/load.cpp index a2e605811..2590b1821 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -178,7 +178,7 @@ std::unordered_map mlx_load_safetensor_helper( } throw std::invalid_argument( - "[load_safetensor] Input must be a file-like object, string, or pathlib.Path"); + "[load_safetensor] Input must be a file-like object, or string"); } std::unordered_map mlx_load_npz_helper( @@ -233,7 +233,7 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) { return {arr}; } throw std::invalid_argument( - "[load_npy] Input must be a file-like object, string, or pathlib.Path"); + "[load_npy] Input must be a file-like object, or string"); } DictOrArray mlx_load_helper( @@ -244,11 +244,9 @@ DictOrArray mlx_load_helper( std::string fname; if (py::isinstance(file)) { fname = py::cast(file); - } else if (is_istream_object(file)) { - fname = file.attr("name").cast(); } else { throw std::invalid_argument( - "[load] Input must be a file-like object, string, or pathlib.Path"); + "[load] Input must be a file-like object, or string"); } size_t ext = fname.find_last_of('.'); if (ext == std::string::npos) { @@ -364,7 +362,7 @@ void mlx_save_helper( } throw std::invalid_argument( - "[save] Input must be a file-like object, string, or pathlib.Path"); + "[save] Input must be a file-like object, or string"); } void mlx_savez_helper( @@ -440,5 +438,5 @@ void mlx_save_safetensor_helper( } throw std::invalid_argument( - "[save_safetensor] Input must be a file-like object, string, or pathlib.Path"); + "[save_safetensor] Input must be a file-like object, or string"); } \ No newline at end of file diff --git a/python/src/load.h b/python/src/load.h index e34977151..4dc6fcda7 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -20,7 +20,7 @@ std::unordered_map mlx_load_safetensor_helper( void mlx_save_safetensor_helper( py::object file, py::dict d, - bool retain_graph = true); + std::optional retain_graph = std::nullopt); DictOrArray mlx_load_helper( py::object file,