Passing parameterized loss and batching to trainer (#391)

This commit is contained in:
Chime Ogbuji 2024-02-13 10:03:25 -05:00 committed by GitHub
parent 954aa50c54
commit e446598f62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -99,6 +99,7 @@ def evaluate(
     num_batches,
     max_seq_length=2048,
     loss: callable = default_loss,
+    iterate_batches: callable = iterate_batches
 ):
     all_losses = []
     ntokens = 0
@@ -126,6 +127,7 @@ def train(
     val_dataset,
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
+    iterate_batches: callable = iterate_batches
 ):
     # Create checkpoints directory if it does not exist
     if not os.path.exists("checkpoints"):
@@ -186,6 +188,7 @@ def train(
             batch_size=args.batch_size,
             num_batches=args.val_batches,
             max_seq_length=args.max_seq_length,
+            iterate_batches=iterate_batches
         )
         print(
             f"Iter {it + 1}: "