FLUX: fix pre-commit lints

This commit is contained in:
madroid 2024-11-07 12:51:22 +08:00
parent 1c43a83280
commit 39fd6d272f
3 changed files with 14 additions and 18 deletions

View File

@ -1,16 +1,17 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import argparse import argparse
import time
from functools import partial
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import numpy as np import numpy as np
import time
from PIL import Image
from functools import partial
from mlx.nn.utils import average_gradients from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from pathlib import Path from PIL import Image
from .datasets import load_dataset from .datasets import load_dataset
from .flux import FluxPipeline from .flux import FluxPipeline
@ -187,7 +188,6 @@ if __name__ == "__main__":
optimizer = optim.Adam(learning_rate=lr_schedule) optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state] state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance): def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -198,14 +198,12 @@ if __name__ == "__main__":
return loss return loss
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)( return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance x, t5_feat, clip_feat, guidance
) )
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -214,7 +212,6 @@ if __name__ == "__main__":
grads = tree_map(lambda a, b: a + b, prev_grads, grads) grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads return loss, grads
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
@ -230,7 +227,6 @@ if __name__ == "__main__":
return loss return loss
# We simply route to the appropriate step based on whether we have # We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an # gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step. # update or simply computing and accumulating gradients in this step.
@ -253,7 +249,6 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads x, t5_feat, clip_feat, guidance, prev_grads
) )
dataset = load_dataset(args.dataset) dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args) trainer = Trainer(flux, dataset, args)
trainer.encode_dataset() trainer.encode_dataset()

View File

@ -1,6 +1,7 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import argparse import argparse
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np

View File

@ -39,14 +39,14 @@ setup(
url="https://github.com/ml-explore/mlx-examples", url="https://github.com/ml-explore/mlx-examples",
license="MIT", license="MIT",
install_requires=requirements, install_requires=requirements,
# Package configuration # Package configuration
packages=find_namespace_packages(include=["mlx_flux", "mlx_flux.*"]), # 明确指定包含的包 packages=find_namespace_packages(
include=["mlx_flux", "mlx_flux.*"]
), # 明确指定包含的包
package_data={ package_data={
"mlx_flux": ["*.py"], "mlx_flux": ["*.py"],
}, },
include_package_data=True, include_package_data=True,
python_requires=">=3.8", python_requires=">=3.8",
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [