mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Basic CircleCI (#449)
* basic style checks for circleci * format * fix config
This commit is contained in:
parent
21e19b5b5a
commit
e4d5630698
35
.circleci/config.yml
Normal file
35
.circleci/config.yml
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
version: 2.1
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
linux_build_and_test:
|
||||||
|
docker:
|
||||||
|
- image: cimg/python:3.9
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Run style checks
|
||||||
|
command: |
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit run --all
|
||||||
|
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||||
|
|
||||||
|
workflows:
|
||||||
|
build_and_test:
|
||||||
|
when:
|
||||||
|
matches:
|
||||||
|
pattern: "^(?!pull/)[-\\w]+$"
|
||||||
|
value: << pipeline.git.branch >>
|
||||||
|
jobs:
|
||||||
|
- linux_build_and_test
|
||||||
|
|
||||||
|
prb:
|
||||||
|
when:
|
||||||
|
matches:
|
||||||
|
pattern: "^pull/\\d+(/head)?$"
|
||||||
|
value: << pipeline.git.branch >>
|
||||||
|
jobs:
|
||||||
|
- hold:
|
||||||
|
type: approval
|
||||||
|
- linux_build_and_test:
|
||||||
|
requires: [ hold ]
|
@ -99,7 +99,7 @@ def evaluate(
|
|||||||
num_batches,
|
num_batches,
|
||||||
max_seq_length=2048,
|
max_seq_length=2048,
|
||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
iterate_batches: callable = iterate_batches
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = []
|
all_losses = []
|
||||||
ntokens = 0
|
ntokens = 0
|
||||||
@ -121,7 +121,14 @@ def evaluate(
|
|||||||
|
|
||||||
class TrainingCallback:
|
class TrainingCallback:
|
||||||
|
|
||||||
def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
|
def on_train_loss_report(
|
||||||
|
self,
|
||||||
|
steps: int,
|
||||||
|
loss: float,
|
||||||
|
it_sec: float,
|
||||||
|
tokens_sec: float,
|
||||||
|
trained_tokens: int,
|
||||||
|
):
|
||||||
"""Called to report training loss at specified intervals."""
|
"""Called to report training loss at specified intervals."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -193,7 +200,9 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if training_callback is not None:
|
if training_callback is not None:
|
||||||
training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
|
training_callback.on_train_loss_report(
|
||||||
|
it + 1, train_loss, it_sec, tokens_sec, trained_tokens
|
||||||
|
)
|
||||||
|
|
||||||
losses = []
|
losses = []
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
@ -210,7 +219,7 @@ def train(
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
num_batches=args.val_batches,
|
num_batches=args.val_batches,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
iterate_batches=iterate_batches
|
iterate_batches=iterate_batches,
|
||||||
)
|
)
|
||||||
val_time = time.perf_counter() - stop
|
val_time = time.perf_counter() - stop
|
||||||
print(
|
print(
|
||||||
|
Loading…
Reference in New Issue
Block a user