diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 18840cf4..333e447c 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -14,7 +14,7 @@ from mlx.utils import tree_flatten
 
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import build_schedule, linear_to_lora_layers
+from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers
 from .utils import load, save_config
 
 yaml_loader = yaml.SafeLoader
@@ -170,10 +170,21 @@ def run(args, training_callback: TrainingCallback = None):
 
     # Freeze all layers
     model.freeze()
 
-    # Convert linear layers to lora layers and unfreeze in the process
-    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
-    print_trainable_parameters(model)
+    adapter_path = Path(args.adapter_path)
+    adapter_file = adapter_path / "adapters.safetensors"
+
+    if args.test and not args.train:
+        apply_lora_layers(model, adapter_path)
+
+    else:
+        adapter_path.mkdir(parents=True, exist_ok=True)
+        save_config(vars(args), adapter_path / "adapter_config.json")
+
+        # Convert linear layers to lora layers and unfreeze in the process
+        linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
+
+    print_trainable_parameters(model)
 
     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
@@ -183,11 +194,6 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Loading pretrained adapters from {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
-    adapter_path = Path(args.adapter_path)
-    adapter_path.mkdir(parents=True, exist_ok=True)
-    save_config(vars(args), adapter_path / "adapter_config.json")
-    adapter_file = adapter_path / "adapters.safetensors"
-
     if args.train:
         print("Training")
         # init training args
@@ -222,14 +228,6 @@ def run(args, training_callback: TrainingCallback = None):
             training_callback=training_callback,
         )
 
-    # Load the LoRA adapter weights which we assume should exist by this point
-    if not adapter_file.is_file():
-        raise ValueError(
-            f"Adapter file {adapter_file} missing. "
-            "Use --train to learn and save the adapters"
-        )
-    model.load_weights(str(adapter_file), strict=False)
-
     if args.test:
         print("Testing")
         model.eval()
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 2d0767c7..fba4f328 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -311,7 +311,12 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 
     config = load_config(model_path)
 
-    weight_files = glob.glob(str(model_path / "*.safetensors"))
+    weight_files = glob.glob(str(model_path / "model*.safetensors"))
+
+    if not weight_files:
+        # Try weight for back-compat
+        weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+
     if not weight_files:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
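
Note: the lora.py hunks move the adapter-path setup ahead of dataset loading, so a test-only run (--test without --train) loads existing adapters via apply_lora_layers, while any other run creates the adapter directory, saves adapter_config.json, and converts linear layers to LoRA layers as before. The utils.py hunk narrows weight discovery to model*.safetensors with a weight*.safetensors fallback. Below is a minimal, self-contained sketch of that fallback pattern, not the patched load_model itself; the "./my-model" directory is a hypothetical placeholder.

import glob
from pathlib import Path

# Hypothetical model directory, used only for illustration.
model_path = Path("./my-model")

# Prefer the model*.safetensors pattern, as in the patch above.
weight_files = glob.glob(str(model_path / "model*.safetensors"))

if not weight_files:
    # Fall back to the older weight*.safetensors naming for back-compat.
    weight_files = glob.glob(str(model_path / "weight*.safetensors"))

if not weight_files:
    raise FileNotFoundError(f"No safetensors found in {model_path}")

print(sorted(weight_files))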