mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
FLUX: fix pre-commit lints
This commit is contained in:
parent
1c43a83280
commit
39fd6d272f
@ -1,16 +1,17 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
import time
|
||||
from PIL import Image
|
||||
from functools import partial
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
from .datasets import load_dataset
|
||||
from .flux import FluxPipeline
|
||||
@ -187,7 +188,6 @@ if __name__ == "__main__":
|
||||
optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def single_step(x, t5_feat, clip_feat, guidance):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
@ -198,14 +198,12 @@ if __name__ == "__main__":
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
@ -214,7 +212,6 @@ if __name__ == "__main__":
|
||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||
return loss, grads
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
@ -230,7 +227,6 @@ if __name__ == "__main__":
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# We simply route to the appropriate step based on whether we have
|
||||
# gradients from a previous step and whether we should be performing an
|
||||
# update or simply computing and accumulating gradients in this step.
|
||||
@ -253,7 +249,6 @@ if __name__ == "__main__":
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
)
|
||||
|
||||
|
||||
dataset = load_dataset(args.dataset)
|
||||
trainer = Trainer(flux, dataset, args)
|
||||
trainer.encode_dataset()
|
||||
@ -273,7 +268,7 @@ if __name__ == "__main__":
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
toc = time.time()
|
||||
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
|
||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||
print(
|
||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"It/s: {10 / (toc - tic):.3f} "
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
@ -96,7 +97,7 @@ def main():
|
||||
# First we get and eval the conditioning
|
||||
conditioning = next(latents)
|
||||
mx.eval(conditioning)
|
||||
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3
|
||||
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
|
||||
mx.metal.reset_peak_memory()
|
||||
|
||||
# The following is not necessary but it may help in memory constrained
|
||||
@ -111,15 +112,15 @@ def main():
|
||||
# The following is not necessary but it may help in memory constrained
|
||||
# systems by reusing the memory kept by the flow transformer.
|
||||
del flux.flow
|
||||
peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3
|
||||
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
|
||||
mx.metal.reset_peak_memory()
|
||||
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
||||
decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size))
|
||||
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
|
||||
mx.eval(decoded[-1])
|
||||
peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3
|
||||
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
|
||||
peak_mem_overall = max(
|
||||
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
|
||||
)
|
||||
|
@ -39,14 +39,14 @@ setup(
|
||||
url="https://github.com/ml-explore/mlx-examples",
|
||||
license="MIT",
|
||||
install_requires=requirements,
|
||||
|
||||
# Package configuration
|
||||
packages=find_namespace_packages(include=["mlx_flux", "mlx_flux.*"]), # 明确指定包含的包
|
||||
packages=find_namespace_packages(
|
||||
include=["mlx_flux", "mlx_flux.*"]
|
||||
), # 明确指定包含的包
|
||||
package_data={
|
||||
"mlx_flux": ["*.py"],
|
||||
},
|
||||
include_package_data=True,
|
||||
|
||||
python_requires=">=3.8",
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
|
Loading…
Reference in New Issue
Block a user