LoRA: add training callbacks (#414)

* LoRA: add training callbacks

* LoRA: add trained tokens print & callback
This commit is contained in:
Madroid Ma 2024-02-16 22:04:57 +08:00 committed by GitHub
parent 726b1ddec0
commit 0ba466369f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 6 deletions

View File

@ -7,7 +7,6 @@ import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten
from .tuner.lora import LoRALinear
from .tuner.trainer import TrainingArgs, evaluate, train
from .tuner.utils import linear_to_lora_layers
from .utils import generate, load

View File

@ -119,6 +119,17 @@ def evaluate(
return np.sum(all_losses) / ntokens
class TrainingCallback:
def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
"""Called to report training loss at specified intervals."""
pass
def on_val_loss_report(self, steps: int, loss: float, val_time: float):
"""Called to report validation loss at specified intervals or the beginning."""
pass
def train(
model,
tokenizer,
@ -127,8 +138,11 @@ def train(
val_dataset,
args: TrainingArgs = TrainingArgs(),
loss: callable = default_loss,
iterate_batches: callable = iterate_batches
iterate_batches: callable = iterate_batches,
training_callback=None,
):
print(f"Starting training..., iters: {args.iters}")
# Create checkpoints directory if it does not exist
if not os.path.exists("checkpoints"):
os.makedirs("checkpoints")
@ -138,7 +152,7 @@ def train(
losses = []
n_tokens = 0
print("Starting training..., iters:", args.iters)
trained_tokens = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
@ -168,11 +182,19 @@ def train(
train_loss = np.mean(losses)
stop = time.perf_counter()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {args.steps_per_report / (stop - start):.3f}, "
f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}"
)
if training_callback is not None:
training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
losses = []
n_tokens = 0
start = time.perf_counter()
@ -190,12 +212,16 @@ def train(
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches
)
val_time = time.perf_counter() - stop
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val took {(time.perf_counter() - stop):.3f}s"
f"Val took {val_time:.3f}s"
)
if training_callback is not None:
training_callback.on_val_loss_report(it + 1, val_loss, val_time)
start = time.perf_counter()
# Save adapter weights if needed
@ -205,6 +231,7 @@ def train(
print(
f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
)
# save final adapter weights
save_adapter(model=model, adapter_file=args.adapter_file)
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")