Couple fixes for LoRA (#711)

* don't overwrite in test only mode

* only load model specific safetensors
This commit is contained in:
Awni Hannun
2024-04-25 14:16:13 -07:00
committed by GitHub
parent 109ee2f2f8
commit 685012c2ad
2 changed files with 21 additions and 18 deletions

View File

@@ -311,7 +311,12 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
config = load_config(model_path)
weight_files = glob.glob(str(model_path / "*.safetensors"))
weight_files = glob.glob(str(model_path / "model*.safetensors"))
if not weight_files:
# Try weight for back-compat
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")