mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-16 07:21:12 +08:00
LoRA: add training callbacks (#414)
* LoRA: add training callbacks * LoRA: add trained tokens print & callback
This commit is contained in:
parent
726b1ddec0
commit
0ba466369f
@ -7,7 +7,6 @@ import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from .tuner.lora import LoRALinear
|
||||
from .tuner.trainer import TrainingArgs, evaluate, train
|
||||
from .tuner.utils import linear_to_lora_layers
|
||||
from .utils import generate, load
|
||||
|
@ -119,6 +119,17 @@ def evaluate(
|
||||
return np.sum(all_losses) / ntokens
|
||||
|
||||
|
||||
class TrainingCallback:
|
||||
|
||||
def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
|
||||
"""Called to report training loss at specified intervals."""
|
||||
pass
|
||||
|
||||
def on_val_loss_report(self, steps: int, loss: float, val_time: float):
|
||||
"""Called to report validation loss at specified intervals or the beginning."""
|
||||
pass
|
||||
|
||||
|
||||
def train(
|
||||
model,
|
||||
tokenizer,
|
||||
@ -127,8 +138,11 @@ def train(
|
||||
val_dataset,
|
||||
args: TrainingArgs = TrainingArgs(),
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches
|
||||
iterate_batches: callable = iterate_batches,
|
||||
training_callback=None,
|
||||
):
|
||||
print(f"Starting training..., iters: {args.iters}")
|
||||
|
||||
# Create checkpoints directory if it does not exist
|
||||
if not os.path.exists("checkpoints"):
|
||||
os.makedirs("checkpoints")
|
||||
@ -138,7 +152,7 @@ def train(
|
||||
|
||||
losses = []
|
||||
n_tokens = 0
|
||||
print("Starting training..., iters:", args.iters)
|
||||
trained_tokens = 0
|
||||
# Main training loop
|
||||
start = time.perf_counter()
|
||||
for it, batch in zip(
|
||||
@ -168,11 +182,19 @@ def train(
|
||||
train_loss = np.mean(losses)
|
||||
|
||||
stop = time.perf_counter()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"It/sec {args.steps_per_report / (stop - start):.3f}, "
|
||||
f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
|
||||
f"It/sec {it_sec:.3f}, "
|
||||
f"Tokens/sec {tokens_sec:.3f}, "
|
||||
f"Trained Tokens {trained_tokens}"
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
|
||||
|
||||
losses = []
|
||||
n_tokens = 0
|
||||
start = time.perf_counter()
|
||||
@ -190,12 +212,16 @@ def train(
|
||||
max_seq_length=args.max_seq_length,
|
||||
iterate_batches=iterate_batches
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val took {(time.perf_counter() - stop):.3f}s"
|
||||
f"Val took {val_time:.3f}s"
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
training_callback.on_val_loss_report(it + 1, val_loss, val_time)
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights if needed
|
||||
@ -205,6 +231,7 @@ def train(
|
||||
print(
|
||||
f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
|
||||
)
|
||||
|
||||
# save final adapter weights
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
|
||||
|
Loading…
Reference in New Issue
Block a user