Change Load to be an IOPrimitive

This commit is contained in:
Angelos Katharopoulos
2024-05-08 18:59:20 -07:00
parent c8e2b42ced
commit b193741050
13 changed files with 101 additions and 29 deletions

View File

@@ -162,10 +162,10 @@ std::pair<
std::unordered_map<std::string, std::string>>
mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .safetensors file path string
return load_safetensors(nb::cast<std::string>(file), s);
return load_safetensors(nb::cast<std::string>(file));
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto res = load_safetensors(std::make_shared<PyFileReader>(file), s);
auto res = load_safetensors(std::make_shared<PyFileReader>(file));
{
nb::gil_scoped_release gil;
for (auto& [key, arr] : std::get<0>(res)) {
@@ -181,7 +181,7 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) {
GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) {
if (nb::isinstance<nb::str>(file)) { // Assume .gguf file path string
return load_gguf(nb::cast<std::string>(file), s);
return load_gguf(nb::cast<std::string>(file));
}
throw std::invalid_argument("[load_gguf] Input must be a string");