LoRA: add training callbacks (#414)

* LoRA: add training callbacks

* LoRA: add trained tokens print & callback
Authored by Madroid Ma, 2024-02-16 22:04:57 +08:00; committed by GitHub
parent 726b1ddec0
commit 0ba466369f
2 changed files with 32 additions and 6 deletions

llms/mlx_lm/lora.py

@@ -7,7 +7,6 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
 
-from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import generate, load
llms/mlx_lm/tuner/trainer.py

@@ -119,6 +119,17 @@ def evaluate(
     return np.sum(all_losses) / ntokens
 
 
+class TrainingCallback:
+
+    def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
+        """Called to report training loss at specified intervals."""
+        pass
+
+    def on_val_loss_report(self, steps: int, loss: float, val_time: float):
+        """Called to report validation loss at specified intervals or the beginning."""
+        pass
+
+
 def train(
     model,
     tokenizer,
@@ -127,8 +138,11 @@ def train(
     val_dataset,
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
-    iterate_batches: callable = iterate_batches
+    iterate_batches: callable = iterate_batches,
+    training_callback=None,
 ):
+    print(f"Starting training..., iters: {args.iters}")
+
     # Create checkpoints directory if it does not exist
     if not os.path.exists("checkpoints"):
         os.makedirs("checkpoints")
@@ -138,7 +152,7 @@ def train(
     losses = []
     n_tokens = 0
-    print("Starting training..., iters:", args.iters)
+    trained_tokens = 0
 
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -168,11 +182,19 @@ def train(
             train_loss = np.mean(losses)
             stop = time.perf_counter()
+            it_sec = args.steps_per_report / (stop - start)
+            tokens_sec = float(n_tokens) / (stop - start)
+            trained_tokens += n_tokens
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
-                f"It/sec {args.steps_per_report / (stop - start):.3f}, "
-                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
+                f"It/sec {it_sec:.3f}, "
+                f"Tokens/sec {tokens_sec:.3f}, "
+                f"Trained Tokens {trained_tokens}"
             )
+
+            if training_callback is not None:
+                training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
+
             losses = []
             n_tokens = 0
             start = time.perf_counter()
 
@@ -190,12 +212,16 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches
             )
+            val_time = time.perf_counter() - stop
             print(
                 f"Iter {it + 1}: "
                 f"Val loss {val_loss:.3f}, "
-                f"Val took {(time.perf_counter() - stop):.3f}s"
+                f"Val took {val_time:.3f}s"
             )
 
+            if training_callback is not None:
+                training_callback.on_val_loss_report(it + 1, val_loss, val_time)
+
             start = time.perf_counter()
 
         # Save adapter weights if needed
@@ -205,6 +231,7 @@ def train(
            print(
                f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
            )
+
     # save final adapter weights
     save_adapter(model=model, adapter_file=args.adapter_file)
     print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")