mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00
Passing parameterized loss and batching to trainer (#391)
This commit is contained in:
parent
954aa50c54
commit
e446598f62
@ -99,6 +99,7 @@ def evaluate(
|
|||||||
num_batches,
|
num_batches,
|
||||||
max_seq_length=2048,
|
max_seq_length=2048,
|
||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
|
iterate_batches: callable = iterate_batches
|
||||||
):
|
):
|
||||||
all_losses = []
|
all_losses = []
|
||||||
ntokens = 0
|
ntokens = 0
|
||||||
@ -126,6 +127,7 @@ def train(
|
|||||||
val_dataset,
|
val_dataset,
|
||||||
args: TrainingArgs = TrainingArgs(),
|
args: TrainingArgs = TrainingArgs(),
|
||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
|
iterate_batches: callable = iterate_batches
|
||||||
):
|
):
|
||||||
# Create checkpoints directory if it does not exist
|
# Create checkpoints directory if it does not exist
|
||||||
if not os.path.exists("checkpoints"):
|
if not os.path.exists("checkpoints"):
|
||||||
@ -186,6 +188,7 @@ def train(
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
num_batches=args.val_batches,
|
num_batches=args.val_batches,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
|
iterate_batches=iterate_batches
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"Iter {it + 1}: "
|
f"Iter {it + 1}: "
|
||||||
|
Loading…
Reference in New Issue
Block a user