feat: stable-diffusion t2i add --seed (#558)

This commit is contained in:
zweifisch 2024-03-10 21:12:54 +08:00 committed by GitHub
parent ad3cf5ed98
commit d0fa6cfcae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 3 deletions

View File

@ -110,7 +110,7 @@ class StableDiffusion:
seed=None, seed=None,
): ):
# Set the PRNG state # Set the PRNG state
seed = seed or int(time.time()) seed = int(time.time()) if seed is None else seed
mx.random.seed(seed) mx.random.seed(seed)
# Get the text conditioning # Get the text conditioning
@ -140,7 +140,7 @@ class StableDiffusion:
seed=None, seed=None,
): ):
# Set the PRNG state # Set the PRNG state
seed = seed or int(time.time()) seed = int(time.time()) if seed is None else seed
mx.random.seed(seed) mx.random.seed(seed)
# Define the num steps and start step # Define the num steps and start step
@ -238,7 +238,7 @@ class StableDiffusionXL(StableDiffusion):
seed=None, seed=None,
): ):
# Set the PRNG state # Set the PRNG state
seed = seed or int(time.time()) seed = int(time.time()) if seed is None else seed
mx.random.seed(seed) mx.random.seed(seed)
# Get the text conditioning # Get the text conditioning

View File

@ -26,6 +26,7 @@ if __name__ == "__main__":
parser.add_argument("--quantize", "-q", action="store_true") parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true") parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png") parser.add_argument("--output", default="out.png")
parser.add_argument("--seed", type=int)
parser.add_argument("--verbose", "-v", action="store_true") parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@ -55,6 +56,7 @@ if __name__ == "__main__":
n_images=args.n_images, n_images=args.n_images,
cfg_weight=args.cfg, cfg_weight=args.cfg,
num_steps=args.steps, num_steps=args.steps,
seed=args.seed,
negative_text=args.negative_prompt, negative_text=args.negative_prompt,
) )
for x_t in tqdm(latents, total=args.steps): for x_t in tqdm(latents, total=args.steps):