FLUX: fix pre-commit lints

This commit is contained in:
madroid 2024-11-07 12:51:22 +08:00
parent 1c43a83280
commit 39fd6d272f
3 changed files with 14 additions and 18 deletions

View File

@@ -1,16 +1,17 @@
# Copyright © 2024 Apple Inc.
import argparse
import time
from functools import partial
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import time
from PIL import Image
from functools import partial
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from pathlib import Path
from PIL import Image
from .datasets import load_dataset
from .flux import FluxPipeline
@@ -187,7 +188,6 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@@ -198,14 +198,12 @@ if __name__ == "__main__":
return loss
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
    # Compiled forward/backward pass: produce (loss, grads) for one batch
    # without performing an optimizer update (accumulation handled by caller).
    grad_fn = nn.value_and_grad(flux.flow, flux.training_loss)
    return grad_fn(x, t5_feat, clip_feat, guidance)
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@@ -214,7 +212,6 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads
@partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@@ -230,7 +227,6 @@ if __name__ == "__main__":
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
@@ -253,7 +249,6 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads
)
dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()

View File

@@ -1,6 +1,7 @@
# Copyright © 2024 Apple Inc.
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np

View File

@@ -39,14 +39,14 @@ setup(
url="https://github.com/ml-explore/mlx-examples",
license="MIT",
install_requires=requirements,
# Package configuration
packages=find_namespace_packages(include=["mlx_flux", "mlx_flux.*"]), # 明确指定包含的包
packages=find_namespace_packages(
include=["mlx_flux", "mlx_flux.*"]
), # 明确指定包含的包
package_data={
"mlx_flux": ["*.py"],
},
include_package_data=True,
python_requires=">=3.8",
entry_points={
"console_scripts": [