Passing parameterized loss and batching to trainer (#391)

Chime Ogbuji authored 2024-02-13 10:03:25 -05:00; committed by GitHub
parent 954aa50c54
commit e446598f62


@@ -99,6 +99,7 @@ def evaluate(
num_batches,
max_seq_length=2048,
loss: callable = default_loss,
iterate_batches: callable = iterate_batches
):
all_losses = []
ntokens = 0
@@ -126,6 +127,7 @@ def train(
val_dataset,
args: TrainingArgs = TrainingArgs(),
loss: callable = default_loss,
iterate_batches: callable = iterate_batches
):
# Create checkpoints directory if it does not exist
if not os.path.exists("checkpoints"):
@@ -186,6 +188,7 @@ def train(
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches
)
print(
f"Iter {it + 1}: "