Pass down TrainingArgs instance to iterate_batches function and TrainingCallback methods

Addresses #1224
Chime Ogbuji 2025-01-28 14:50:38 -05:00
parent 7a83077cd7
commit a928bba375


@@ -5,7 +5,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
     return ce, ntoks


-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
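With the extra args parameter, a drop-in replacement for iterate_batches can read run-level settings straight from the TrainingArgs instance instead of relying only on the fixed positional signature. A minimal sketch of such a replacement, assuming it lives alongside the stock iterator above (the iterate_custom_batches name and its delegating body are illustrative, not part of this commit):

def iterate_custom_batches(
    dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
):
    """Illustrative replacement iterator; not part of this commit."""
    # `args` is the TrainingArgs instance forwarded from train(); fall back to
    # the plain positional values when no instance is supplied.
    if args is not None:
        batch_size = args.batch_size
        max_seq_length = args.max_seq_length
    # Delegate to the stock iterator here; a real replacement would implement
    # its own batching strategy using whatever extra fields it needs from `args`.
    yield from iterate_batches(
        dataset, tokenizer, batch_size, max_seq_length, train=train, args=args
    )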
@@ -167,11 +169,13 @@ def evaluate(
 class TrainingCallback:

-    def on_train_loss_report(self, train_info: dict):
+    def on_train_loss_report(
+        self, train_info: dict, args: Optional[TrainingArgs] = None
+    ):
         """Called to report training loss at specified intervals."""
         pass

-    def on_val_loss_report(self, val_info: dict):
+    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
         """Called to report validation loss at specified intervals or the beginning."""
         pass
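On the callback side, the forwarded TrainingArgs lets a subclass annotate its reports with the run's hyperparameters. A hedged sketch that reuses the Optional import and the TrainingArgs/TrainingCallback classes from the module patched above; the LoggingCallback name and output formatting are illustrative, while the train_info and val_info keys match the dictionaries built later in this diff:

from typing import Optional

# TrainingArgs and TrainingCallback come from the trainer module patched above.

class LoggingCallback(TrainingCallback):
    """Illustrative callback subclass; not part of this commit."""

    def on_train_loss_report(
        self, train_info: dict, args: Optional[TrainingArgs] = None
    ):
        # `args` is the run's TrainingArgs, or None if the caller predates this change.
        prefix = f"[bs={args.batch_size}] " if args is not None else ""
        print(
            prefix
            + f"trained_tokens={train_info['trained_tokens']} "
            + f"peak_memory={train_info['peak_memory']}"
        )

    def on_val_loss_report(
        self, val_info: dict, args: Optional[TrainingArgs] = None
    ):
        print(f"val_loss={val_info['val_loss']:.3f} ({val_info['val_time']:.1f}s)")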
@@ -227,6 +231,7 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            args=args,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -258,7 +263,7 @@ def train(
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
-                training_callback.on_val_loss_report(val_info)
+                training_callback.on_val_loss_report(val_info, args=args)

             start = time.perf_counter()
@@ -301,7 +306,7 @@ def train(
                     "trained_tokens": trained_tokens,
                     "peak_memory": peak_mem,
                 }
-                training_callback.on_train_loss_report(train_info)
+                training_callback.on_train_loss_report(train_info, args=args)

             losses = 0
             n_tokens = 0
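Because args is forwarded as a keyword with a default of None at each new call site, code that ignores the parameter keeps working; a callback or iterator only has to accept the extra argument to start receiving it. A rough end-to-end sketch of the wiring, assuming train() accepts iterate_batches and training_callback keyword arguments and that the model, tokenizer, optimizer, and datasets are already set up (none of those names appear in this diff):

training_args = TrainingArgs(batch_size=4, max_seq_length=2048)

train(
    model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args=training_args,                      # now reaches the iterator and callbacks
    iterate_batches=iterate_custom_batches,  # hypothetical iterator sketched above
    training_callback=LoggingCallback(),     # hypothetical callback sketched above
)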