Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-11 11:48:39 +08:00)
fix optimizer

commit 73cc094681
parent 700c3ef5cc
@@ -255,11 +255,21 @@ def train_model(
     save_config(vars(args), adapter_path / "adapter_config.json")

     model.train()
-    opt = optim.Adam( # need to correct that part
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)

     # Train model based on training mode
     if args.training_mode == "orpo":
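The hunk above drops the hard-coded optim.Adam (flagged in the source with "# need to correct that part") in favor of a lookup keyed on args.optimizer, so both "adam" and "adamw" are supported and per-optimizer keyword arguments flow in through args.optimizer_config. Below is a minimal, self-contained sketch of the same selection pattern, assuming mlx.optimizers provides Adam and AdamW as used in the diff; the helper name build_optimizer and the example values are illustrative, not part of the repository.

import mlx.optimizers as optim


def build_optimizer(name, learning_rate, config=None):
    # Map the lowercased optimizer name to an MLX optimizer class.
    config = config or {}
    name = name.lower()
    if name == "adam":
        opt_class = optim.Adam
    elif name == "adamw":
        opt_class = optim.AdamW
    else:
        raise ValueError(f"Unsupported optimizer: {name}")
    # Any extra keyword arguments (e.g. weight_decay for AdamW) come from the
    # per-optimizer config dict, mirroring **optimizer_config in the diff.
    return opt_class(learning_rate=learning_rate, **config)


# Hypothetical usage:
opt = build_optimizer("adamw", 1e-5, {"weight_decay": 0.01})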
@@ -299,21 +309,6 @@ def train_model(
         grad_checkpoint=args.grad_checkpoint
     )

-    # Initialize the selected optimizer
-    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-
-    optimizer_name = args.optimizer.lower()
-    optimizer_config = args.optimizer_config.get(optimizer_name, {})
-
-    if optimizer_name == "adam":
-        opt_class = optim.Adam
-    elif optimizer_name == "adamw":
-        opt_class = optim.AdamW
-    else:
-        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
-
-    opt = opt_class(learning_rate=lr, **optimizer_config)
-
     train(
         model=model,
         tokenizer=tokenizer,
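The second hunk removes the same initialization from its old position later in train_model, so the optimizer is now built exactly once, before the training-mode branch. One detail worth noting about the lr value that gets passed in: MLX optimizers accept either a plain float or a callable schedule as learning_rate, which is why lr can hold either args.learning_rate or the result of build_schedule(args.lr_schedule). A small illustration with made-up values, using the cosine_decay schedule from mlx.optimizers (a sketch, not the repository's build_schedule):

import mlx.optimizers as optim

# A fixed learning rate ...
opt_fixed = optim.AdamW(learning_rate=1e-5, weight_decay=0.01)

# ... or a schedule: cosine decay from 1e-5 toward 1e-7 over 1000 steps.
schedule = optim.cosine_decay(1e-5, 1000, 1e-7)
opt_scheduled = optim.AdamW(learning_rate=schedule, weight_decay=0.01)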