mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Fix load compilation (#298)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							1f6ab6a556
						
					
				
				
					commit
					79c95b6919
				
			@@ -164,7 +164,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
			
		||||
    py::object file,
 | 
			
		||||
    StreamOrDevice s) {
 | 
			
		||||
  if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
 | 
			
		||||
    return {load_safetensors(py::cast<std::string>(file), s)};
 | 
			
		||||
    return load_safetensors(py::cast<std::string>(file), s);
 | 
			
		||||
  } else if (is_istream_object(file)) {
 | 
			
		||||
    // If we don't own the stream and it was passed to us, eval immediately
 | 
			
		||||
    auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
 | 
			
		||||
@@ -174,7 +174,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
 | 
			
		||||
        arr.eval();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return {arr};
 | 
			
		||||
    return arr;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  throw std::invalid_argument(
 | 
			
		||||
@@ -217,12 +217,12 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
 | 
			
		||||
    arr.eval();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return {array_dict};
 | 
			
		||||
  return array_dict;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
  if (py::isinstance<py::str>(file)) { // Assume .npy file path string
 | 
			
		||||
    return {load(py::cast<std::string>(file), s)};
 | 
			
		||||
    return load(py::cast<std::string>(file), s);
 | 
			
		||||
  } else if (is_istream_object(file)) {
 | 
			
		||||
    // If we don't own the stream and it was passed to us, eval immediately
 | 
			
		||||
    auto arr = load(std::make_shared<PyFileReader>(file), s);
 | 
			
		||||
@@ -230,7 +230,7 @@ array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
 | 
			
		||||
      py::gil_scoped_release gil;
 | 
			
		||||
      arr.eval();
 | 
			
		||||
    }
 | 
			
		||||
    return {arr};
 | 
			
		||||
    return arr;
 | 
			
		||||
  }
 | 
			
		||||
  throw std::invalid_argument(
 | 
			
		||||
      "[load_npy] Input must be a file-like object, or string");
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user