mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 12:09:43 +08:00
Change Load to be an IOPrimitive
This commit is contained in:
@@ -162,10 +162,10 @@ std::pair<
|
||||
std::unordered_map<std::string, std::string>>
|
||||
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
|
||||
return load_safetensors(nb::cast<std::string>(file), s);
|
||||
return load_safetensors(nb::cast<std::string>(file));
|
||||
} else if (is_istream_object(file)) {
|
||||
// If we don't own the stream and it was passed to us, eval immediately
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
|
||||
auto res = load_safetensors(std::make_shared<PyFileReader>(file));
|
||||
{
|
||||
nb::gil_scoped_release gil;
|
||||
for (auto& [key, arr] : std::get<0>(res)) {
|
||||
@@ -181,7 +181,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
|
||||
|
||||
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
|
||||
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
|
||||
return load_gguf(nb::cast<std::string>(file), s);
|
||||
return load_gguf(nb::cast<std::string>(file));
|
||||
}
|
||||
|
||||
throw std::invalid_argument("[load_gguf] Input must be a string");
|
||||
|
||||
Reference in New Issue
Block a user