mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Couple fixes for LoRA (#711)
* don't overwrite in test only mode
* only load model specific safetensors

parent 109ee2f2f8
commit 685012c2ad
@@ -14,7 +14,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import build_schedule, linear_to_lora_layers
+from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers
 from .utils import load, save_config

 yaml_loader = yaml.SafeLoader
@@ -170,10 +170,21 @@ def run(args, training_callback: TrainingCallback = None):
     # Freeze all layers
     model.freeze()
-    # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
-
-    print_trainable_parameters(model)
+    adapter_path = Path(args.adapter_path)
+    adapter_file = adapter_path / "adapters.safetensors"
+
+    if args.test and not args.train:
+        apply_lora_layers(model, adapter_path)
+
+    else:
+        adapter_path.mkdir(parents=True, exist_ok=True)
+        save_config(vars(args), adapter_path / "adapter_config.json")
+
+        # Convert linear layers to lora layers and unfreeze in the process
+        linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
+
+        print_trainable_parameters(model)

     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
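
For context on the test-only branch above, here is a minimal sketch of what loading previously saved adapters amounts to, assuming apply_lora_layers reads the adapter_config.json and adapters.safetensors produced by the training branch. It is a hypothetical composition of calls that already appear in this diff, not the actual mlx_lm implementation.

# Hypothetical sketch, not the real apply_lora_layers: rebuild the LoRA layers
# from the saved adapter config, then overwrite their fresh (random) weights
# with the trained adapter weights, instead of leaving untrained adapters in place.
import json
from pathlib import Path

from mlx_lm.tuner.utils import linear_to_lora_layers


def load_saved_adapters(model, adapter_dir: str):
    adapter_dir = Path(adapter_dir)
    with open(adapter_dir / "adapter_config.json") as f:
        config = json.load(f)

    # Convert the same linear layers that were converted at training time.
    linear_to_lora_layers(model, config["lora_layers"], config["lora_parameters"])
    # Load the trained low-rank weights on top of them.
    model.load_weights(str(adapter_dir / "adapters.safetensors"), strict=False)
    return model
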
@@ -183,11 +194,6 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)

-    adapter_path = Path(args.adapter_path)
-    adapter_path.mkdir(parents=True, exist_ok=True)
-    save_config(vars(args), adapter_path / "adapter_config.json")
-    adapter_file = adapter_path / "adapters.safetensors"
-
     if args.train:
         print("Training")
         # init training args
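
On the branch that does train, save_config(vars(args), ...) writes the parsed run arguments into the adapter directory; that record is what the test-only path relies on and what should not be overwritten. Roughly what adapter_config.json might contain: the keys mirror the argparse attributes used in this diff, and the values here are invented for illustration.

# Purely illustrative adapter_config.json contents; the nested lora_parameters
# values in particular are hypothetical, not mlx_lm defaults.
example_adapter_config = {
    "lora_layers": 16,           # args.lora_layers: how many layers get LoRA
    "lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 20.0},
    "adapter_path": "adapters",  # args.adapter_path: where adapters.safetensors is saved
    "train": True,
    "test": False,
}
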
@@ -222,14 +228,6 @@ def run(args, training_callback: TrainingCallback = None):
             training_callback=training_callback,
         )

-    # Load the LoRA adapter weights which we assume should exist by this point
-    if not adapter_file.is_file():
-        raise ValueError(
-            f"Adapter file {adapter_file} missing. "
-            "Use --train to learn and save the adapters"
-        )
-    model.load_weights(str(adapter_file), strict=False)
-
     if args.test:
         print("Testing")
         model.eval()
@@ -311,7 +311,12 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     config = load_config(model_path)

-    weight_files = glob.glob(str(model_path / "*.safetensors"))
+    weight_files = glob.glob(str(model_path / "model*.safetensors"))
+
+    if not weight_files:
+        # Try weight for back-compat
+        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

     if not weight_files:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
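
A small illustration of the second fix: restricting the glob to model*.safetensors picks up only the model's own shards and ignores any other safetensors files that happen to sit in the same directory (the file names below are made up). The follow-up weight*.safetensors glob, per the back-compat comment, keeps older conversions with a weights-style file name working.

# Illustration of the glob change in load_model above; file names are hypothetical.
from fnmatch import fnmatch

files = [
    "model-00001-of-00002.safetensors",  # model shard
    "model-00002-of-00002.safetensors",  # model shard
    "adapters.safetensors",              # e.g. LoRA adapters stored next to the model
]

print([f for f in files if fnmatch(f, "*.safetensors")])
# -> all three files, including the adapters

print([f for f in files if fnmatch(f, "model*.safetensors")])
# -> only the two model shards
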