diff --git a/python/src/ops.cpp b/python/src/ops.cpp index db405b127..8a65b1053 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2929,12 +2929,12 @@ void init_ops(py::module_& m) { "load", &mlx_load_helper, "file"_a, - "format"_a = none, py::pos_only(), + "format"_a = none, py::kw_only(), "stream"_a = none, R"pbdoc( - load(file: str, format: Optional[str] = None, /, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] + load(file: str, /, format: Optional[str] = None, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]] Load array(s) from a binary file in ``.npy``, ``.npz``, or ``.safetensors`` format.