mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
feat: stable-diffusion t2i add --seed (#558)
This commit is contained in:
parent
ad3cf5ed98
commit
d0fa6cfcae
@ -110,7 +110,7 @@ class StableDiffusion:
|
|||||||
seed=None,
|
seed=None,
|
||||||
):
|
):
|
||||||
# Set the PRNG state
|
# Set the PRNG state
|
||||||
seed = seed or int(time.time())
|
seed = int(time.time()) if seed is None else seed
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
# Get the text conditioning
|
# Get the text conditioning
|
||||||
@ -140,7 +140,7 @@ class StableDiffusion:
|
|||||||
seed=None,
|
seed=None,
|
||||||
):
|
):
|
||||||
# Set the PRNG state
|
# Set the PRNG state
|
||||||
seed = seed or int(time.time())
|
seed = int(time.time()) if seed is None else seed
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
# Define the num steps and start step
|
# Define the num steps and start step
|
||||||
@ -238,7 +238,7 @@ class StableDiffusionXL(StableDiffusion):
|
|||||||
seed=None,
|
seed=None,
|
||||||
):
|
):
|
||||||
# Set the PRNG state
|
# Set the PRNG state
|
||||||
seed = seed or int(time.time())
|
seed = int(time.time()) if seed is None else seed
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
# Get the text conditioning
|
# Get the text conditioning
|
||||||
|
@ -26,6 +26,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--quantize", "-q", action="store_true")
|
parser.add_argument("--quantize", "-q", action="store_true")
|
||||||
parser.add_argument("--preload-models", action="store_true")
|
parser.add_argument("--preload-models", action="store_true")
|
||||||
parser.add_argument("--output", default="out.png")
|
parser.add_argument("--output", default="out.png")
|
||||||
|
parser.add_argument("--seed", type=int)
|
||||||
parser.add_argument("--verbose", "-v", action="store_true")
|
parser.add_argument("--verbose", "-v", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -55,6 +56,7 @@ if __name__ == "__main__":
|
|||||||
n_images=args.n_images,
|
n_images=args.n_images,
|
||||||
cfg_weight=args.cfg,
|
cfg_weight=args.cfg,
|
||||||
num_steps=args.steps,
|
num_steps=args.steps,
|
||||||
|
seed=args.seed,
|
||||||
negative_text=args.negative_prompt,
|
negative_text=args.negative_prompt,
|
||||||
)
|
)
|
||||||
for x_t in tqdm(latents, total=args.steps):
|
for x_t in tqdm(latents, total=args.steps):
|
||||||
|
Loading…
Reference in New Issue
Block a user