mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Couple fixes for LoRA (#711)
* don't overwrite in test only mode
* only load model specific safetensors
This commit is contained in:
@@ -14,7 +14,7 @@ from mlx.utils import tree_flatten
|
|||||||
|
|
||||||
from .tuner.datasets import load_dataset
|
from .tuner.datasets import load_dataset
|
||||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||||
from .tuner.utils import build_schedule, linear_to_lora_layers
|
from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers
|
||||||
from .utils import load, save_config
|
from .utils import load, save_config
|
||||||
|
|
||||||
yaml_loader = yaml.SafeLoader
|
yaml_loader = yaml.SafeLoader
|
||||||
@@ -170,6 +170,17 @@ def run(args, training_callback: TrainingCallback = None):
|
|||||||
|
|
||||||
# Freeze all layers
|
# Freeze all layers
|
||||||
model.freeze()
|
model.freeze()
|
||||||
|
|
||||||
|
adapter_path = Path(args.adapter_path)
|
||||||
|
adapter_file = adapter_path / "adapters.safetensors"
|
||||||
|
|
||||||
|
if args.test and not args.train:
|
||||||
|
apply_lora_layers(model, adapter_path)
|
||||||
|
|
||||||
|
else:
|
||||||
|
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||||
|
|
||||||
# Convert linear layers to lora layers and unfreeze in the process
|
# Convert linear layers to lora layers and unfreeze in the process
|
||||||
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
|
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
|
||||||
|
|
||||||
@@ -183,11 +194,6 @@ def run(args, training_callback: TrainingCallback = None):
|
|||||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||||
model.load_weights(args.resume_adapter_file, strict=False)
|
model.load_weights(args.resume_adapter_file, strict=False)
|
||||||
|
|
||||||
adapter_path = Path(args.adapter_path)
|
|
||||||
adapter_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
|
||||||
adapter_file = adapter_path / "adapters.safetensors"
|
|
||||||
|
|
||||||
if args.train:
|
if args.train:
|
||||||
print("Training")
|
print("Training")
|
||||||
# init training args
|
# init training args
|
||||||
@@ -222,14 +228,6 @@ def run(args, training_callback: TrainingCallback = None):
|
|||||||
training_callback=training_callback,
|
training_callback=training_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load the LoRA adapter weights which we assume should exist by this point
|
|
||||||
if not adapter_file.is_file():
|
|
||||||
raise ValueError(
|
|
||||||
f"Adapter file {adapter_file} missing. "
|
|
||||||
"Use --train to learn and save the adapters"
|
|
||||||
)
|
|
||||||
model.load_weights(str(adapter_file), strict=False)
|
|
||||||
|
|
||||||
if args.test:
|
if args.test:
|
||||||
print("Testing")
|
print("Testing")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@@ -311,7 +311,12 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
|
|||||||
|
|
||||||
config = load_config(model_path)
|
config = load_config(model_path)
|
||||||
|
|
||||||
weight_files = glob.glob(str(model_path / "*.safetensors"))
|
weight_files = glob.glob(str(model_path / "model*.safetensors"))
|
||||||
|
|
||||||
|
if not weight_files:
|
||||||
|
# Try weight for back-compat
|
||||||
|
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||||
|
|
||||||
if not weight_files:
|
if not weight_files:
|
||||||
logging.error(f"No safetensors found in {model_path}")
|
logging.error(f"No safetensors found in {model_path}")
|
||||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||||
|
Reference in New Issue
Block a user