Pass down TrainingArgs instance to iterate_batches function and TrainingCallback methods

Addresses #1224
Chime Ogbuji 2025-01-28 14:50:38 -05:00
parent 7a83077cd7
commit a928bba375


@@ -5,7 +5,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
     return ce, ntoks


-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
@@ -167,11 +169,13 @@ def evaluate(


 class TrainingCallback:
-    def on_train_loss_report(self, train_info: dict):
+    def on_train_loss_report(
+        self, train_info: dict, args: Optional[TrainingArgs] = None
+    ):
         """Called to report training loss at specified intervals."""
         pass

-    def on_val_loss_report(self, val_info: dict):
+    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
         """Called to report validation loss at specified intervals or the beginning."""
         pass

@@ -227,6 +231,7 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            args=args,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -258,7 +263,7 @@ def train(
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
-                training_callback.on_val_loss_report(val_info)
+                training_callback.on_val_loss_report(val_info, args=args)

             start = time.perf_counter()

@@ -301,7 +306,7 @@ def train(
                     "trained_tokens": trained_tokens,
                     "peak_memory": peak_mem,
                 }
-                training_callback.on_train_loss_report(train_info)
+                training_callback.on_train_loss_report(train_info, args=args)

             losses = 0
             n_tokens = 0
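
Because the new parameter defaults to args=None, existing callbacks keep working unchanged, while subclasses can now read the run configuration directly inside the hooks. A minimal sketch of such a subclass, assuming the mlx_lm.tuner.trainer import path; the LossLoggingCallback name is hypothetical:

from typing import Optional

from mlx_lm.tuner.trainer import TrainingArgs, TrainingCallback


class LossLoggingCallback(TrainingCallback):
    """Hypothetical callback that tags each report with run hyperparameters."""

    def on_train_loss_report(
        self, train_info: dict, args: Optional[TrainingArgs] = None
    ):
        # `args` now arrives alongside the metrics, so values such as
        # batch_size and max_seq_length need no separate plumbing.
        tag = f"[bs={args.batch_size}]" if args is not None else "[bs=?]"
        print(tag, "train:", train_info)

    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
        tag = f"[bs={args.batch_size}]" if args is not None else "[bs=?]"
        print(tag, "val:", val_info)

An instance passed as training_callback to train() would then receive the TrainingArgs at every train and validation report interval.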