From f1ef378a58b7ae472ff556ad8d68c4700c2f806c Mon Sep 17 00:00:00 2001
From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com>
Date: Sun, 11 Feb 2024 19:23:27 +0400
Subject: [PATCH] Feat: update pre-commit rev (#432)

---
 .pre-commit-config.yaml                       | 2 +-
 llms/mlx_lm/tuner/trainer.py                  | 6 +++---
 stable_diffusion/stable_diffusion/model_io.py | 8 +++++---
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 34513281..ea21d896 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
 - repo: https://github.com/psf/black-pre-commit-mirror
-  rev: 23.12.1
+  rev: 24.1.1
   hooks:
   - id: black
 - repo: https://github.com/pycqa/isort
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index fcc3e1d0..2d92a98f 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -80,9 +80,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             for j in range(batch_size):
                 truncated_length = min(lengths[j], max_seq_length)
                 batch_arr[j, :truncated_length] = batch[j][:truncated_length]
-                lengths[
-                    j
-                ] = truncated_length  # Update lengths to match truncated lengths
+                lengths[j] = (
+                    truncated_length  # Update lengths to match truncated lengths
+                )
             batch = mx.array(batch_arr)

             yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
diff --git a/stable_diffusion/stable_diffusion/model_io.py b/stable_diffusion/stable_diffusion/model_io.py
index 819754b7..3c35eac4 100644
--- a/stable_diffusion/stable_diffusion/model_io.py
+++ b/stable_diffusion/stable_diffusion/model_io.py
@@ -186,9 +186,11 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
             out_channels=config["out_channels"],
             block_out_channels=config["block_out_channels"],
             layers_per_block=[config["layers_per_block"]] * n_blocks,
-            num_attention_heads=[config["attention_head_dim"]] * n_blocks
-            if isinstance(config["attention_head_dim"], int)
-            else config["attention_head_dim"],
+            num_attention_heads=(
+                [config["attention_head_dim"]] * n_blocks
+                if isinstance(config["attention_head_dim"], int)
+                else config["attention_head_dim"]
+            ),
             cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
             norm_num_groups=config["norm_num_groups"],
         )