From 39fd6d272f2464dd400aa63be1bad0e6ece59c9b Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 7 Nov 2024 12:51:22 +0800 Subject: [PATCH] FLUX: fix pre-commit lints --- flux/mlx_flux/dreambooth.py | 17 ++++++----------- flux/mlx_flux/txt2image.py | 9 +++++---- flux/setup.py | 6 +++--- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/flux/mlx_flux/dreambooth.py b/flux/mlx_flux/dreambooth.py index 8327af09..91049cb1 100644 --- a/flux/mlx_flux/dreambooth.py +++ b/flux/mlx_flux/dreambooth.py @@ -1,16 +1,17 @@ # Copyright © 2024 Apple Inc. import argparse +import time +from functools import partial +from pathlib import Path + import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim import numpy as np -import time -from PIL import Image -from functools import partial from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce -from pathlib import Path +from PIL import Image from .datasets import load_dataset from .flux import FluxPipeline @@ -187,7 +188,6 @@ if __name__ == "__main__": optimizer = optim.Adam(learning_rate=lr_schedule) state = [flux.flow.state, optimizer.state, mx.random.state] - @partial(mx.compile, inputs=state, outputs=state) def single_step(x, t5_feat, clip_feat, guidance): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -198,14 +198,12 @@ if __name__ == "__main__": return loss - @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): return nn.value_and_grad(flux.flow, flux.training_loss)( x, t5_feat, clip_feat, guidance ) - @partial(mx.compile, inputs=state, outputs=state) def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -214,7 +212,6 @@ if __name__ == "__main__": grads = tree_map(lambda a, b: a + b, prev_grads, grads) return loss, grads - @partial(mx.compile, inputs=state, outputs=state) def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( @@ -230,7 +227,6 @@ if __name__ == "__main__": return loss - # We simply route to the appropriate step based on whether we have # gradients from a previous step and whether we should be performing an # update or simply computing and accumulating gradients in this step. @@ -253,7 +249,6 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - dataset = load_dataset(args.dataset) trainer = Trainer(flux, dataset, args) trainer.encode_dataset() @@ -273,7 +268,7 @@ if __name__ == "__main__": if (i + 1) % 10 == 0: toc = time.time() - peak_mem = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem = mx.metal.get_peak_memory() / 1024**3 print( f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " diff --git a/flux/mlx_flux/txt2image.py b/flux/mlx_flux/txt2image.py index 358e8cad..74892d6d 100644 --- a/flux/mlx_flux/txt2image.py +++ b/flux/mlx_flux/txt2image.py @@ -1,6 +1,7 @@ # Copyright © 2024 Apple Inc. import argparse + import mlx.core as mx import mlx.nn as nn import numpy as np @@ -96,7 +97,7 @@ def main(): # First we get and eval the conditioning conditioning = next(latents) mx.eval(conditioning) - peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 mx.metal.reset_peak_memory() # The following is not necessary but it may help in memory constrained @@ -111,15 +112,15 @@ def main(): # The following is not necessary but it may help in memory constrained # systems by reusing the memory kept by the flow transformer. del flux.flow - peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 mx.metal.reset_peak_memory() # Decode them into images decoded = [] for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): - decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size)) + decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) mx.eval(decoded[-1]) - peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3 + peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 peak_mem_overall = max( peak_mem_conditioning, peak_mem_generation, peak_mem_decoding ) diff --git a/flux/setup.py b/flux/setup.py index 7a92f78b..e8235b15 100644 --- a/flux/setup.py +++ b/flux/setup.py @@ -39,14 +39,14 @@ setup( url="https://github.com/ml-explore/mlx-examples", license="MIT", install_requires=requirements, - # Package configuration - packages=find_namespace_packages(include=["mlx_flux", "mlx_flux.*"]), # 明确指定包含的包 + packages=find_namespace_packages( + include=["mlx_flux", "mlx_flux.*"] + ), # 明确指定包含的包 package_data={ "mlx_flux": ["*.py"], }, include_package_data=True, - python_requires=">=3.8", entry_points={ "console_scripts": [