Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-21 12:06:51 +08:00
Pass down TrainingArgs instance to iterate_batches function and TrainingCallback methods
Addresses #1224
This commit is contained in:
parent 7a83077cd7
commit a928bba375
@@ -5,7 +5,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
     return ce, ntoks
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
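Because train() now forwards args=args into whichever batching function it is given (see the call-site hunk below), a custom iterate_batches replacement must accept the new keyword to stay compatible. A minimal sketch of such a drop-in, assuming the module path mlx_lm.tuner.trainer; the wrapper name and the logging are illustrative, not part of this commit:

from mlx_lm.tuner.trainer import iterate_batches  # assumed module path


def logging_iterate_batches(
    dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
):
    # args is the run's TrainingArgs when called from train(); the None
    # default preserves the old call signature for existing callers.
    if args is not None:
        print(f"batching with batch_size={args.batch_size}")
    # Delegate to the stock implementation, forwarding the new keyword.
    yield from iterate_batches(
        dataset, tokenizer, batch_size, max_seq_length, train=train, args=args
    )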
@@ -167,11 +169,13 @@ def evaluate(
 
 class TrainingCallback:
 
-    def on_train_loss_report(self, train_info: dict):
+    def on_train_loss_report(
+        self, train_info: dict, args: Optional[TrainingArgs] = None
+    ):
         """Called to report training loss at specified intervals."""
         pass
 
-    def on_val_loss_report(self, val_info: dict):
+    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
         """Called to report validation loss at specified intervals or the beginning."""
         pass
 
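The Optional[TrainingArgs] = None default keeps existing callbacks working while letting new ones read run hyperparameters directly from args. A sketch of a subclass using the new parameter, again assuming the mlx_lm.tuner.trainer module path; the class name and print formats are illustrative, and only dict keys visible in this diff (trained_tokens, peak_memory, val_loss, val_time) are used:

from typing import Optional

from mlx_lm.tuner.trainer import TrainingArgs, TrainingCallback  # assumed module path


class PrintingCallback(TrainingCallback):
    """Illustrative subclass; not part of this commit."""

    def on_train_loss_report(
        self, train_info: dict, args: Optional[TrainingArgs] = None
    ):
        # args is the run's TrainingArgs when invoked from train(); it may be
        # None for callers predating this change, so guard before use.
        suffix = f" (batch_size={args.batch_size})" if args is not None else ""
        print(
            f"train: {train_info['trained_tokens']} tokens, "
            f"peak memory {train_info['peak_memory']}{suffix}"
        )

    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
        print(f"val: loss {val_info['val_loss']:.3f} in {val_info['val_time']:.1f}s")


# The Optional default keeps pre-change call sites working:
PrintingCallback().on_train_loss_report({"trained_tokens": 0, "peak_memory": 0.0})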
@@ -227,6 +231,7 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            args=args,
         ),
     ):
         # Report validation loss if needed, the first validation loss
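The call site above reads batch_size and max_seq_length off the TrainingArgs dataclass, so those are exactly the fields a callback or batching function receiving args can recover. A tiny check of that, assuming the mlx_lm.tuner.trainer module path and that both fields carry dataclass defaults; the values are illustrative:

from mlx_lm.tuner.trainer import TrainingArgs  # assumed module path

# Fields shown at the updated call site above.
args = TrainingArgs(batch_size=4, max_seq_length=2048)
print(args.batch_size, args.max_seq_length)  # -> 4 2048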
@@ -258,7 +263,7 @@ def train(
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
-                training_callback.on_val_loss_report(val_info)
+                training_callback.on_val_loss_report(val_info, args=args)
 
             start = time.perf_counter()
 
@@ -301,7 +306,7 @@ def train(
                     "trained_tokens": trained_tokens,
                     "peak_memory": peak_mem,
                 }
-                training_callback.on_train_loss_report(train_info)
+                training_callback.on_train_loss_report(train_info, args=args)
 
             losses = 0
             n_tokens = 0