mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
FLUX: move cli to mlx_flux dir
This commit is contained in:
parent
83c92c2a11
commit
e61849a003
@ -1,19 +1,20 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from PIL import Image
|
||||||
|
from functools import partial
|
||||||
from mlx.nn.utils import average_gradients
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||||
from PIL import Image
|
from pathlib import Path
|
||||||
|
|
||||||
from mlx_flux import FluxPipeline, Trainer, load_dataset
|
from .datasets import load_dataset
|
||||||
|
from .flux import FluxPipeline
|
||||||
|
from .trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
def generate_progress_images(iteration, flux, args):
|
||||||
@ -186,6 +187,7 @@ if __name__ == "__main__":
|
|||||||
optimizer = optim.Adam(learning_rate=lr_schedule)
|
optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def single_step(x, t5_feat, clip_feat, guidance):
|
def single_step(x, t5_feat, clip_feat, guidance):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -196,12 +198,14 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
||||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
x, t5_feat, clip_feat, guidance
|
x, t5_feat, clip_feat, guidance
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -210,6 +214,7 @@ if __name__ == "__main__":
|
|||||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||||
return loss, grads
|
return loss, grads
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
@ -225,6 +230,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
# We simply route to the appropriate step based on whether we have
|
# We simply route to the appropriate step based on whether we have
|
||||||
# gradients from a previous step and whether we should be performing an
|
# gradients from a previous step and whether we should be performing an
|
||||||
# update or simply computing and accumulating gradients in this step.
|
# update or simply computing and accumulating gradients in this step.
|
||||||
@ -247,6 +253,7 @@ if __name__ == "__main__":
|
|||||||
x, t5_feat, clip_feat, guidance, prev_grads
|
x, t5_feat, clip_feat, guidance, prev_grads
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
dataset = load_dataset(args.dataset)
|
dataset = load_dataset(args.dataset)
|
||||||
trainer = Trainer(flux, dataset, args)
|
trainer = Trainer(flux, dataset, args)
|
||||||
trainer.encode_dataset()
|
trainer.encode_dataset()
|
||||||
@ -266,7 +273,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
peak_mem = mx.metal.get_peak_memory() / 1024 ** 3
|
||||||
print(
|
print(
|
||||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||||
f"It/s: {10 / (toc - tic):.3f} "
|
f"It/s: {10 / (toc - tic):.3f} "
|
@ -1,14 +1,13 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from mlx_flux import FluxPipeline
|
from .flux import FluxPipeline
|
||||||
|
|
||||||
|
|
||||||
def to_latent_size(image_size):
|
def to_latent_size(image_size):
|
||||||
@ -39,7 +38,7 @@ def load_adapter(flux, adapter_file, fuse=False):
|
|||||||
flux.fuse_lora_layers()
|
flux.fuse_lora_layers()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def build_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Generate images from a textual prompt using stable diffusion"
|
description="Generate images from a textual prompt using stable diffusion"
|
||||||
)
|
)
|
||||||
@ -62,7 +61,11 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--adapter")
|
parser.add_argument("--adapter")
|
||||||
parser.add_argument("--fuse-adapter", action="store_true")
|
parser.add_argument("--fuse-adapter", action="store_true")
|
||||||
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
|
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
|
||||||
args = parser.parse_args()
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = build_parser().parse_args()
|
||||||
|
|
||||||
# Load the models
|
# Load the models
|
||||||
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
|
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
|
||||||
@ -93,7 +96,7 @@ if __name__ == "__main__":
|
|||||||
# First we get and eval the conditioning
|
# First we get and eval the conditioning
|
||||||
conditioning = next(latents)
|
conditioning = next(latents)
|
||||||
mx.eval(conditioning)
|
mx.eval(conditioning)
|
||||||
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
|
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024 ** 3
|
||||||
mx.metal.reset_peak_memory()
|
mx.metal.reset_peak_memory()
|
||||||
|
|
||||||
# The following is not necessary but it may help in memory constrained
|
# The following is not necessary but it may help in memory constrained
|
||||||
@ -108,15 +111,15 @@ if __name__ == "__main__":
|
|||||||
# The following is not necessary but it may help in memory constrained
|
# The following is not necessary but it may help in memory constrained
|
||||||
# systems by reusing the memory kept by the flow transformer.
|
# systems by reusing the memory kept by the flow transformer.
|
||||||
del flux.flow
|
del flux.flow
|
||||||
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
|
peak_mem_generation = mx.metal.get_peak_memory() / 1024 ** 3
|
||||||
mx.metal.reset_peak_memory()
|
mx.metal.reset_peak_memory()
|
||||||
|
|
||||||
# Decode them into images
|
# Decode them into images
|
||||||
decoded = []
|
decoded = []
|
||||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
||||||
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
|
decoded.append(flux.decode(x_t[i: i + args.decoding_batch_size], latent_size))
|
||||||
mx.eval(decoded[-1])
|
mx.eval(decoded[-1])
|
||||||
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
|
peak_mem_decoding = mx.metal.get_peak_memory() / 1024 ** 3
|
||||||
peak_mem_overall = max(
|
peak_mem_overall = max(
|
||||||
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
|
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
|
||||||
)
|
)
|
||||||
@ -148,3 +151,7 @@ if __name__ == "__main__":
|
|||||||
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
|
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
|
||||||
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
|
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
|
||||||
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user