Couple fixes for LoRA (#711)

* don't overwrite the adapter config in test-only mode

* only load model-specific safetensors
Author: Awni Hannun
Date: 2024-04-25 14:16:13 -07:00 (committed by GitHub)
Parent: 109ee2f2f8
Commit: 685012c2ad
2 changed files with 21 additions and 18 deletions
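
The first fix addresses test-only runs of the LoRA script. Roughly, an invocation along the following lines used to recreate the adapter directory and overwrite adapter_config.json before testing; after this change it only loads the adapters that are already there. The exact CLI flags are not part of this diff and are inferred from the args attributes used below (args.test, args.train, args.adapter_path), so treat this as a hypothetical example:

    # hypothetical test-only invocation; flag spellings assumed
    python -m mlx_lm.lora --model <path-or-hf-repo> --test --adapter-path adapters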

Changed file 1 of 2:

@@ -14,7 +14,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import build_schedule, linear_to_lora_layers
+from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers
 from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
@@ -170,10 +170,21 @@ def run(args, training_callback: TrainingCallback = None):
     # Freeze all layers
     model.freeze()
 
-    # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
-    print_trainable_parameters(model)
+    adapter_path = Path(args.adapter_path)
+    adapter_file = adapter_path / "adapters.safetensors"
+
+    if args.test and not args.train:
+        apply_lora_layers(model, adapter_path)
+    else:
+        adapter_path.mkdir(parents=True, exist_ok=True)
+        save_config(vars(args), adapter_path / "adapter_config.json")
+
+        # Convert linear layers to lora layers and unfreeze in the process
+        linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
+
+        print_trainable_parameters(model)
+
     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
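
The newly imported apply_lora_layers helper is what reattaches saved adapters in the test-only branch above. Its implementation is not shown in this diff; the sketch below is an assumption about what such a helper does (the function name, config keys, and exact signature are made up for illustration), not the real code in mlx_lm.tuner.utils:

    import json
    from pathlib import Path

    from mlx_lm.tuner.utils import linear_to_lora_layers

    def apply_lora_layers_sketch(model, adapter_path):
        # Read the config that run() saved next to the adapters, rebuild the
        # LoRA layers it describes, then load the trained adapter weights.
        adapter_path = Path(adapter_path)
        with open(adapter_path / "adapter_config.json") as f:
            config = json.load(f)
        linear_to_lora_layers(model, config["lora_layers"], config["lora_parameters"])
        model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
        return model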
@@ -183,11 +194,6 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
-    adapter_path = Path(args.adapter_path)
-    adapter_path.mkdir(parents=True, exist_ok=True)
-    save_config(vars(args), adapter_path / "adapter_config.json")
-    adapter_file = adapter_path / "adapters.safetensors"
-
     if args.train:
         print("Training")
         # init training args
@@ -222,14 +228,6 @@ def run(args, training_callback: TrainingCallback = None):
             training_callback=training_callback,
         )
 
-    # Load the LoRA adapter weights which we assume should exist by this point
-    if not adapter_file.is_file():
-        raise ValueError(
-            f"Adapter file {adapter_file} missing. "
-            "Use --train to learn and save the adapters"
-        )
-    model.load_weights(str(adapter_file), strict=False)
-
     if args.test:
         print("Testing")
         model.eval()
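
With this block removed, run() no longer reloads adapters or raises the explicit "adapter file missing" error before testing: after training the adapters are already in the model, and in the test-only path apply_lora_layers handles loading them. If an up-front existence check is still wanted, a caller could keep something like the following, which simply reuses the removed message and is not part of this commit:

    from pathlib import Path

    adapter_file = Path(args.adapter_path) / "adapters.safetensors"
    if args.test and not args.train and not adapter_file.is_file():
        raise ValueError(
            f"Adapter file {adapter_file} missing. "
            "Use --train to learn and save the adapters"
        )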

Changed file 2 of 2:

@@ -311,7 +311,12 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     config = load_config(model_path)
 
-    weight_files = glob.glob(str(model_path / "*.safetensors"))
+    weight_files = glob.glob(str(model_path / "model*.safetensors"))
+
+    if not weight_files:
+        # Try weight for back-compat
+        weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+
     if not weight_files:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
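
The second fix narrows the weight glob so that unrelated .safetensors files sitting next to the model are not mistaken for weight shards. The directory listing below is a made-up example to show how the old and new patterns differ; only the glob patterns themselves come from the diff:

    import glob
    from pathlib import Path

    model_path = Path("mlx_model")
    # Suppose the directory contains:
    #   model-00001-of-00002.safetensors, model-00002-of-00002.safetensors,
    #   adapters.safetensors
    glob.glob(str(model_path / "*.safetensors"))        # old: also matches adapters.safetensors
    glob.glob(str(model_path / "model*.safetensors"))   # new: only the model shards
    glob.glob(str(model_path / "weight*.safetensors"))  # fallback for older conversions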