mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-20 10:20:46 +08:00
Pass down TrainingArgs instance to iterate_batches function and TrainingCallback methods
Addresses #1224
This commit is contained in:
parent
7a83077cd7
commit
a928bba375
@ -5,7 +5,7 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
def iterate_batches(
|
||||
dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
|
||||
):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
if len(dataset) < batch_size:
|
||||
@ -167,11 +169,13 @@ def evaluate(
|
||||
|
||||
class TrainingCallback:
|
||||
|
||||
def on_train_loss_report(self, train_info: dict):
|
||||
def on_train_loss_report(
|
||||
self, train_info: dict, args: Optional[TrainingArgs] = None
|
||||
):
|
||||
"""Called to report training loss at specified intervals."""
|
||||
pass
|
||||
|
||||
def on_val_loss_report(self, val_info: dict):
|
||||
def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
|
||||
"""Called to report validation loss at specified intervals or the beginning."""
|
||||
pass
|
||||
|
||||
@ -227,6 +231,7 @@ def train(
|
||||
batch_size=args.batch_size,
|
||||
max_seq_length=args.max_seq_length,
|
||||
train=True,
|
||||
args=args,
|
||||
),
|
||||
):
|
||||
# Report validation loss if needed, the first validation loss
|
||||
@ -258,7 +263,7 @@ def train(
|
||||
"val_loss": val_loss,
|
||||
"val_time": val_time,
|
||||
}
|
||||
training_callback.on_val_loss_report(val_info)
|
||||
training_callback.on_val_loss_report(val_info, args=args)
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
@ -301,7 +306,7 @@ def train(
|
||||
"trained_tokens": trained_tokens,
|
||||
"peak_memory": peak_mem,
|
||||
}
|
||||
training_callback.on_train_loss_report(train_info)
|
||||
training_callback.on_train_loss_report(train_info, args=args)
|
||||
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
|
Loading…
Reference in New Issue
Block a user