mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Basic CircleCI (#449)
* basic style checks for circleci * format * fix config
This commit is contained in:
@@ -99,7 +99,7 @@ def evaluate(
|
||||
num_batches,
|
||||
max_seq_length=2048,
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches
|
||||
iterate_batches: callable = iterate_batches,
|
||||
):
|
||||
all_losses = []
|
||||
ntokens = 0
|
||||
@@ -121,7 +121,14 @@ def evaluate(
|
||||
|
||||
class TrainingCallback:
|
||||
|
||||
def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
|
||||
def on_train_loss_report(
|
||||
self,
|
||||
steps: int,
|
||||
loss: float,
|
||||
it_sec: float,
|
||||
tokens_sec: float,
|
||||
trained_tokens: int,
|
||||
):
|
||||
"""Called to report training loss at specified intervals."""
|
||||
pass
|
||||
|
||||
@@ -193,7 +200,9 @@ def train(
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
|
||||
training_callback.on_train_loss_report(
|
||||
it + 1, train_loss, it_sec, tokens_sec, trained_tokens
|
||||
)
|
||||
|
||||
losses = []
|
||||
n_tokens = 0
|
||||
@@ -210,7 +219,7 @@ def train(
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
iterate_batches=iterate_batches
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
print(
|
||||
|
Reference in New Issue
Block a user