Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-16 23:41:14 +08:00)
LoRA: add training callbacks (#414)
* LoRA: add training callbacks
* LoRA: add trained tokens print & callback
parent 726b1ddec0
commit 0ba466369f
@@ -7,7 +7,6 @@ import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
 
-from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import generate, load
@@ -119,6 +119,17 @@ def evaluate(
     return np.sum(all_losses) / ntokens
 
 
+class TrainingCallback:
+
+    def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
+        """Called to report training loss at specified intervals."""
+        pass
+
+    def on_val_loss_report(self, steps: int, loss: float, val_time: float):
+        """Called to report validation loss at specified intervals or the beginning."""
+        pass
+
+
 def train(
     model,
     tokenizer,
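TrainingCallback above is just a pair of no-op hooks; a subclass decides what to do with the reported values. Below is a minimal sketch of one possible subclass that records every report in memory. Only the two method signatures come from the hunk above; the MetricsCollector name, the history lists, and the import path are assumptions for illustration.

# Illustrative only: the import path is inferred from the relative import
# "from .tuner.trainer import ..." seen in the first hunk of this commit.
from mlx_lm.tuner.trainer import TrainingCallback


class MetricsCollector(TrainingCallback):
    """Keep every train/val report in memory, e.g. for plotting later."""

    def __init__(self):
        self.train_history = []
        self.val_history = []

    def on_train_loss_report(self, steps, loss, it_sec, tokens_sec, trained_tokens):
        # Mirrors the arguments passed by the training loop further down.
        self.train_history.append(
            {"step": steps, "loss": loss, "it_sec": it_sec,
             "tokens_sec": tokens_sec, "trained_tokens": trained_tokens}
        )

    def on_val_loss_report(self, steps, loss, val_time):
        self.val_history.append({"step": steps, "val_loss": loss, "val_time": val_time})

An instance of such a subclass is what the new training_callback parameter in the next hunk expects.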
@@ -127,8 +138,11 @@ def train(
     val_dataset,
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
-    iterate_batches: callable = iterate_batches
+    iterate_batches: callable = iterate_batches,
+    training_callback=None,
 ):
+    print(f"Starting training..., iters: {args.iters}")
+
     # Create checkpoints directory if it does not exist
     if not os.path.exists("checkpoints"):
         os.makedirs("checkpoints")
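Because training_callback defaults to None, existing call sites are unaffected; a callback is opted into by passing it as a keyword argument. A hedged sketch of that wiring follows: the hunk above only shows the tail of train()'s parameter list, so the wrapper forwards everything else untouched rather than guessing argument names, and both the train_with_metrics helper and MetricsCollector (sketched earlier) are illustrative, not part of the commit.

# Hypothetical helper; only the training_callback keyword is taken from the diff.
from mlx_lm.tuner.trainer import train  # import path assumed, as above


def train_with_metrics(*train_args, **train_kwargs):
    """Run train() with a MetricsCollector attached and return the collector."""
    callback = MetricsCollector()
    train(*train_args, training_callback=callback, **train_kwargs)
    return callback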
@@ -138,7 +152,7 @@ def train(
 
     losses = []
     n_tokens = 0
-    print("Starting training..., iters:", args.iters)
+    trained_tokens = 0
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -168,11 +182,19 @@
             train_loss = np.mean(losses)
 
             stop = time.perf_counter()
+            it_sec = args.steps_per_report / (stop - start)
+            tokens_sec = float(n_tokens) / (stop - start)
+            trained_tokens += n_tokens
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
-                f"It/sec {args.steps_per_report / (stop - start):.3f}, "
-                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
+                f"It/sec {it_sec:.3f}, "
+                f"Tokens/sec {tokens_sec:.3f}, "
+                f"Trained Tokens {trained_tokens}"
             )
+
+            if training_callback is not None:
+                training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
+
             losses = []
             n_tokens = 0
             start = time.perf_counter()
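Since it_sec and tokens_sec are computed over the same (stop - start) window, a callback can derive further quantities from them; for instance tokens_sec / it_sec is the average number of tokens per reported iteration. A small illustrative callback (the ThroughputCallback name is not part of the commit):

from mlx_lm.tuner.trainer import TrainingCallback  # import path assumed, as above


class ThroughputCallback(TrainingCallback):
    """Derive tokens-per-iteration from the rates reported above."""

    def on_train_loss_report(self, steps, loss, it_sec, tokens_sec, trained_tokens):
        # Both rates share the same elapsed-time denominator, so their ratio
        # recovers the average tokens processed per iteration in this window.
        tokens_per_iter = tokens_sec / it_sec if it_sec > 0 else 0.0
        print(f"step {steps}: ~{tokens_per_iter:.0f} tokens/iter, "
              f"{trained_tokens} tokens trained so far")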
@@ -190,12 +212,16 @@
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches
             )
+            val_time = time.perf_counter() - stop
             print(
                 f"Iter {it + 1}: "
                 f"Val loss {val_loss:.3f}, "
-                f"Val took {(time.perf_counter() - stop):.3f}s"
+                f"Val took {val_time:.3f}s"
             )
 
+            if training_callback is not None:
+                training_callback.on_val_loss_report(it + 1, val_loss, val_time)
+
             start = time.perf_counter()
 
             # Save adapter weights if needed
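The hook is invoked as a bare statement, so its return value is ignored; in this version a callback cannot stop or alter training, only observe it. It can, however, keep state across reports, as in this illustrative best-loss tracker (the BestValTracker name is not part of the commit):

from mlx_lm.tuner.trainer import TrainingCallback  # import path assumed, as above


class BestValTracker(TrainingCallback):
    """Remember the lowest validation loss reported so far."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_step = None

    def on_val_loss_report(self, steps, loss, val_time):
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_step = steps
            print(f"new best val loss {loss:.3f} at step {steps} "
                  f"(eval took {val_time:.1f}s)")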
@@ -205,6 +231,7 @@
             print(
                 f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
             )
+
     # save final adapter weights
     save_adapter(model=model, adapter_file=args.adapter_file)
     print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")