Basic CircleCI (#449)

* basic style checks for circleci

* format

* fix config
This commit is contained in:
Awni Hannun
2024-02-16 22:13:55 -08:00
committed by GitHub
parent 21e19b5b5a
commit e4d5630698
2 changed files with 48 additions and 4 deletions

View File

@@ -99,7 +99,7 @@ def evaluate(
num_batches,
max_seq_length=2048,
loss: callable = default_loss,
iterate_batches: callable = iterate_batches
iterate_batches: callable = iterate_batches,
):
all_losses = []
ntokens = 0
@@ -121,7 +121,14 @@ def evaluate(
class TrainingCallback:
def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
def on_train_loss_report(
self,
steps: int,
loss: float,
it_sec: float,
tokens_sec: float,
trained_tokens: int,
):
"""Called to report training loss at specified intervals."""
pass
@@ -193,7 +200,9 @@ def train(
)
if training_callback is not None:
training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
training_callback.on_train_loss_report(
it + 1, train_loss, it_sec, tokens_sec, trained_tokens
)
losses = []
n_tokens = 0
@@ -210,7 +219,7 @@ def train(
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
print(