mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Add timeout to generate functions
Add timeout handling to various `generate` functions across multiple files.

* **cvae/main.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.
* **flux/dreambooth.py**
  - Add `timeout` parameter to `generate_progress_images` function.
  - Implement timeout handling using `signal` module in `generate_progress_images` function.
* **musicgen/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **stable_diffusion/txt2image.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **llava/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.
* **llms/gguf_llm/generate.py**
  - Add `timeout` parameter to `generate` function.
  - Implement timeout handling using `signal` module in `generate` function.
* **llms/mlx_lm/generate.py**
  - Add `timeout` parameter to `main` function.
  - Implement timeout handling using `signal` module in `main` function.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/jincdream/mlx-examples?shareId=XXXX-XXXX-XXXX-XXXX).
This commit is contained in:
parent
743763bc2e
commit
f7bbe458ae
27
cvae/main.py
27
cvae/main.py
@ -4,6 +4,7 @@ import argparse
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
import signal
|
||||
|
||||
import dataset
|
||||
import mlx.core as mx
|
||||
@ -67,16 +68,28 @@ def generate(
|
||||
model,
|
||||
out_file,
|
||||
num_samples=128,
|
||||
timeout=None,
|
||||
):
|
||||
# Sample from the latent distribution:
|
||||
z = mx.random.normal([num_samples, model.num_latent_dims])
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
# Decode the latent vectors to images:
|
||||
images = model.decode(z)
|
||||
if timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout)
|
||||
|
||||
# Save all images in a single file
|
||||
grid_image = grid_image_from_batch(images, num_rows=8)
|
||||
grid_image.save(out_file)
|
||||
try:
|
||||
# Sample from the latent distribution:
|
||||
z = mx.random.normal([num_samples, model.num_latent_dims])
|
||||
|
||||
# Decode the latent vectors to images:
|
||||
images = model.decode(z)
|
||||
|
||||
# Save all images in a single file
|
||||
grid_image = grid_image_from_batch(images, num_rows=8)
|
||||
grid_image.save(out_file)
|
||||
finally:
|
||||
if timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
def main(args):
|
||||
|
@ -4,6 +4,7 @@ import argparse
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
import signal
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -16,31 +17,42 @@ from PIL import Image
|
||||
from flux import FluxPipeline, Trainer, load_dataset
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
def generate_progress_images(iteration, flux, args, timeout=None):
|
||||
"""Generate images to monitor the progress of the finetuning."""
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / f"{iteration:07d}_progress.png"
|
||||
print(f"Generating {str(out_file)}", flush=True)
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
# Generate some images and arrange them in a grid
|
||||
n_rows = 2
|
||||
n_images = 4
|
||||
x = flux.generate_images(
|
||||
args.progress_prompt,
|
||||
n_images,
|
||||
args.progress_steps,
|
||||
)
|
||||
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
||||
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
if timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout)
|
||||
|
||||
# Save them to disc
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(out_file)
|
||||
try:
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / f"{iteration:07d}_progress.png"
|
||||
print(f"Generating {str(out_file)}", flush=True)
|
||||
|
||||
# Generate some images and arrange them in a grid
|
||||
n_rows = 2
|
||||
n_images = 4
|
||||
x = flux.generate_images(
|
||||
args.progress_prompt,
|
||||
n_images,
|
||||
args.progress_steps,
|
||||
)
|
||||
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
||||
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(out_file)
|
||||
finally:
|
||||
if timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
def save_adapters(iteration, flux, args):
|
||||
|
@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import codecs
|
||||
from pathlib import Path
|
||||
import signal
|
||||
|
||||
import mlx.core as mx
|
||||
import requests
|
||||
@ -49,6 +50,12 @@ def parse_arguments():
|
||||
default=None,
|
||||
help="End of sequence token for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Timeout in seconds for the generation process.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -119,21 +126,32 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
tokenizer_config = {}
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
processor, model = load_model(args.model, tokenizer_config)
|
||||
if args.timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(args.timeout)
|
||||
|
||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||
try:
|
||||
tokenizer_config = {}
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
|
||||
processor, model = load_model(args.model, tokenizer_config)
|
||||
|
||||
print(prompt)
|
||||
generated_text = generate_text(
|
||||
input_ids, pixel_values, model, processor, args.max_tokens, args.temp
|
||||
)
|
||||
print(generated_text)
|
||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||
|
||||
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
|
||||
|
||||
print(prompt)
|
||||
generated_text = generate_text(
|
||||
input_ids, pixel_values, model, processor, args.max_tokens, args.temp
|
||||
)
|
||||
print(generated_text)
|
||||
finally:
|
||||
if args.timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import signal
|
||||
|
||||
import mlx.core as mx
|
||||
import models
|
||||
@ -13,37 +14,49 @@ def generate(
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
temp: float = 0.0,
|
||||
timeout: int = None,
|
||||
):
|
||||
prompt = tokenizer.encode(prompt)
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, n in zip(
|
||||
models.generate(prompt, model, args.temp),
|
||||
range(args.max_tokens),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
if timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout)
|
||||
|
||||
if n == 0:
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
try:
|
||||
prompt = tokenizer.encode(prompt)
|
||||
|
||||
tokens.append(token.item())
|
||||
s = tokenizer.decode(tokens)
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
||||
gen_time = time.time() - tic
|
||||
print("=" * 10)
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
tic = time.time()
|
||||
tokens = []
|
||||
skip = 0
|
||||
for token, n in zip(
|
||||
models.generate(prompt, model, args.temp),
|
||||
range(args.max_tokens),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
if n == 0:
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
|
||||
tokens.append(token.item())
|
||||
s = tokenizer.decode(tokens)
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
print(tokenizer.decode(tokens)[skip:], flush=True)
|
||||
gen_time = time.time() - tic
|
||||
print("=" * 10)
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
finally:
|
||||
if timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -83,4 +96,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
mx.random.seed(args.seed)
|
||||
model, tokenizer = models.load(args.gguf, args.repo)
|
||||
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)
|
||||
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp, timeout=args.timeout)
|
||||
|
@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import signal
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@ -107,6 +108,12 @@ def setup_arg_parser():
|
||||
default=None,
|
||||
help="A file containing saved KV caches to avoid recomputing them",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Timeout in seconds for the generation process.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -146,90 +153,101 @@ def main():
|
||||
if args.cache_limit_gb is not None:
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
# Load the prompt cache and metadata if a cache file is provided
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
prompt_cache, metadata = load_prompt_cache(
|
||||
args.prompt_cache_file, return_metadata=True
|
||||
)
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = (
|
||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||
)
|
||||
if args.trust_remote_code:
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
if args.timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(args.timeout)
|
||||
|
||||
model_path = args.model
|
||||
if using_cache:
|
||||
if model_path is None:
|
||||
model_path = metadata["model"]
|
||||
elif model_path != metadata["model"]:
|
||||
raise ValueError(
|
||||
f"Providing a different model ({model_path}) than that "
|
||||
f"used to create the prompt cache ({metadata['model']}) "
|
||||
"is an error."
|
||||
)
|
||||
model_path = model_path or DEFAULT_MODEL
|
||||
|
||||
model, tokenizer = load(
|
||||
model_path,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
elif using_cache:
|
||||
tokenizer.chat_template = metadata["chat_template"]
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
|
||||
}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||
# stored kv cache.
|
||||
try:
|
||||
# Load the prompt cache and metadata if a cache file is provided
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
prompt_cache, metadata = load_prompt_cache(
|
||||
args.prompt_cache_file, return_metadata=True
|
||||
)
|
||||
prompt = prompt[test_prompt.index("<query>") :]
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
if args.colorize and not args.verbose:
|
||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = (
|
||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||
)
|
||||
if args.trust_remote_code:
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_tokens,
|
||||
verbose=args.verbose,
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
model_path = args.model
|
||||
if using_cache:
|
||||
if model_path is None:
|
||||
model_path = metadata["model"]
|
||||
elif model_path != metadata["model"]:
|
||||
raise ValueError(
|
||||
f"Providing a different model ({model_path}) than that "
|
||||
f"used to create the prompt cache ({metadata['model']}) "
|
||||
"is an error."
|
||||
)
|
||||
model_path = model_path or DEFAULT_MODEL
|
||||
|
||||
model, tokenizer = load(
|
||||
model_path,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
elif using_cache:
|
||||
tokenizer.chat_template = metadata["chat_template"]
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
|
||||
}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||
# stored kv cache.
|
||||
if using_cache:
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt = prompt[test_prompt.index("<query>") :]
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
if args.colorize and not args.verbose:
|
||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_tokens,
|
||||
verbose=args.verbose,
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
finally:
|
||||
if args.timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import time
|
||||
import signal
|
||||
|
||||
import mlx.core as mx
|
||||
from decoder import SpeculativeDecoder
|
||||
@ -21,27 +22,38 @@ def load_model(model_name: str):
|
||||
def main(args):
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
spec_decoder = SpeculativeDecoder(
|
||||
model=load_model(args.model_name),
|
||||
draft_model=load_model(args.draft_model_name),
|
||||
tokenizer=args.model_name,
|
||||
delta=args.delta,
|
||||
num_draft=args.num_draft,
|
||||
)
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
tic = time.time()
|
||||
print(args.prompt)
|
||||
if args.regular_decode:
|
||||
spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
|
||||
else:
|
||||
stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
|
||||
if args.timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(args.timeout)
|
||||
|
||||
try:
|
||||
spec_decoder = SpeculativeDecoder(
|
||||
model=load_model(args.model_name),
|
||||
draft_model=load_model(args.draft_model_name),
|
||||
tokenizer=args.model_name,
|
||||
delta=args.delta,
|
||||
num_draft=args.num_draft,
|
||||
)
|
||||
|
||||
tic = time.time()
|
||||
print(args.prompt)
|
||||
if args.regular_decode:
|
||||
spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
|
||||
else:
|
||||
stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
|
||||
print("=" * 10)
|
||||
print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
|
||||
print(f"Decoding steps {stats['n_steps']}.")
|
||||
|
||||
toc = time.time()
|
||||
print("=" * 10)
|
||||
print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
|
||||
print(f"Decoding steps {stats['n_steps']}.")
|
||||
|
||||
toc = time.time()
|
||||
print("=" * 10)
|
||||
print(f"Full generation time {toc - tic:.3f}")
|
||||
print(f"Full generation time {toc - tic:.3f}")
|
||||
finally:
|
||||
if args.timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -91,5 +103,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Use regular decoding instead of speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Timeout in seconds for the generation process.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -1,16 +1,28 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
|
||||
from utils import save_audio
|
||||
|
||||
from musicgen import MusicGen
|
||||
|
||||
|
||||
def main(text: str, output_path: str, model_name: str, max_steps: int):
|
||||
model = MusicGen.from_pretrained(model_name)
|
||||
audio = model.generate(text, max_steps=max_steps)
|
||||
save_audio(output_path, audio, model.sampling_rate)
|
||||
def main(text: str, output_path: str, model_name: str, max_steps: int, timeout: int = None):
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
if timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout)
|
||||
|
||||
try:
|
||||
model = MusicGen.from_pretrained(model_name)
|
||||
audio = model.generate(text, max_steps=max_steps)
|
||||
save_audio(output_path, audio, model.sampling_rate)
|
||||
finally:
|
||||
if timeout:
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -19,5 +31,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--text", required=False, default="happy rock")
|
||||
parser.add_argument("--output-path", required=False, default="0.wav")
|
||||
parser.add_argument("--max-steps", required=False, default=500, type=int)
|
||||
parser.add_argument("--timeout", required=False, default=None, type=int)
|
||||
args = parser.parse_args()
|
||||
main(args.text, args.output_path, args.model, args.max_steps)
|
||||
main(args.text, args.output_path, args.model, args.max_steps, args.timeout)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -27,82 +28,97 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--preload-models", action="store_true")
|
||||
parser.add_argument("--output", default="out.png")
|
||||
parser.add_argument("--seed", type=int)
|
||||
parser.add_argument("--timeout", type=int, default=None)
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the models
|
||||
if args.model == "sdxl":
|
||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||
if args.quantize:
|
||||
nn.quantize(
|
||||
sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Generation timed out")
|
||||
|
||||
if args.timeout:
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(args.timeout)
|
||||
|
||||
try:
|
||||
# Load the models
|
||||
if args.model == "sdxl":
|
||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||
if args.quantize:
|
||||
nn.quantize(
|
||||
sd.text_encoder_1,
|
||||
class_predicate=lambda _, m: isinstance(m, nn.Linear),
|
||||
)
|
||||
nn.quantize(
|
||||
sd.text_encoder_2,
|
||||
class_predicate=lambda _, m: isinstance(m, nn.Linear),
|
||||
)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 0.0
|
||||
args.steps = args.steps or 2
|
||||
else:
|
||||
sd = StableDiffusion(
|
||||
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
|
||||
)
|
||||
nn.quantize(
|
||||
sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 0.0
|
||||
args.steps = args.steps or 2
|
||||
else:
|
||||
sd = StableDiffusion(
|
||||
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
|
||||
if args.quantize:
|
||||
nn.quantize(
|
||||
sd.text_encoder,
|
||||
class_predicate=lambda _, m: isinstance(m, nn.Linear),
|
||||
)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 7.5
|
||||
args.steps = args.steps or 50
|
||||
|
||||
# Ensure that models are read in memory if needed
|
||||
if args.preload_models:
|
||||
sd.ensure_models_are_loaded()
|
||||
|
||||
# Generate the latent vectors using diffusion
|
||||
latents = sd.generate_latents(
|
||||
args.prompt,
|
||||
n_images=args.n_images,
|
||||
cfg_weight=args.cfg,
|
||||
num_steps=args.steps,
|
||||
seed=args.seed,
|
||||
negative_text=args.negative_prompt,
|
||||
)
|
||||
if args.quantize:
|
||||
nn.quantize(
|
||||
sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 7.5
|
||||
args.steps = args.steps or 50
|
||||
for x_t in tqdm(latents, total=args.steps):
|
||||
mx.eval(x_t)
|
||||
|
||||
# Ensure that models are read in memory if needed
|
||||
if args.preload_models:
|
||||
sd.ensure_models_are_loaded()
|
||||
# The following is not necessary but it may help in memory
|
||||
# constrained systems by reusing the memory kept by the unet and the text
|
||||
# encoders.
|
||||
if args.model == "sdxl":
|
||||
del sd.text_encoder_1
|
||||
del sd.text_encoder_2
|
||||
else:
|
||||
del sd.text_encoder
|
||||
del sd.unet
|
||||
del sd.sampler
|
||||
peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
|
||||
|
||||
# Generate the latent vectors using diffusion
|
||||
latents = sd.generate_latents(
|
||||
args.prompt,
|
||||
n_images=args.n_images,
|
||||
cfg_weight=args.cfg,
|
||||
num_steps=args.steps,
|
||||
seed=args.seed,
|
||||
negative_text=args.negative_prompt,
|
||||
)
|
||||
for x_t in tqdm(latents, total=args.steps):
|
||||
mx.eval(x_t)
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
||||
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
|
||||
mx.eval(decoded[-1])
|
||||
peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
|
||||
|
||||
# The following is not necessary but it may help in memory
|
||||
# constrained systems by reusing the memory kept by the unet and the text
|
||||
# encoders.
|
||||
if args.model == "sdxl":
|
||||
del sd.text_encoder_1
|
||||
del sd.text_encoder_2
|
||||
else:
|
||||
del sd.text_encoder
|
||||
del sd.unet
|
||||
del sd.sampler
|
||||
peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
|
||||
# Arrange them on a grid
|
||||
x = mx.concatenate(decoded, axis=0)
|
||||
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
||||
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
|
||||
mx.eval(decoded[-1])
|
||||
peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
|
||||
# Save them to disc
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(args.output)
|
||||
|
||||
# Arrange them on a grid
|
||||
x = mx.concatenate(decoded, axis=0)
|
||||
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(args.output)
|
||||
|
||||
# Report the peak memory used during generation
|
||||
if args.verbose:
|
||||
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
|
||||
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
||||
# Report the peak memory used during generation
|
||||
if args.verbose:
|
||||
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
|
||||
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
||||
finally:
|
||||
if args.timeout:
|
||||
signal.alarm(0)
|
||||
|
Loading…
Reference in New Issue
Block a user