"""Command-line entry point for Wan text/image-to-video generation on MLX."""
import argparse
from datetime import datetime
import logging
import os
import sys
import random

import mlx.core as mx
from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video, cache_image, str2bool

# Fallback prompt (and, for i2v, input image) used when the user passes none.
EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2v-14B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2i-14B": {
        # Chinese prompt, roughly: "A simple, dignified beauty".
        "prompt": "一个朴素端庄的美人",
    },
    "i2v-14B": {
        "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image": "examples/i2v_input.JPG",
    },
}


def _validate_args(args):
    """Validate parsed CLI arguments and fill in task-dependent defaults.

    Mutates ``args`` in place: sets ``sample_steps``, ``sample_shift``,
    ``frame_num`` and ``base_seed`` when they were left unset.

    Raises:
        AssertionError: if the checkpoint dir is missing, the task is unknown,
            ``frame_num`` is not 1 for a t2i task, or the size is not
            supported for the task. (Note: ``assert`` is stripped under
            ``python -O``.)
    """
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for
    # text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 40 if "i2v" in args.task else 50

    if args.sample_shift is None:
        args.sample_shift = 5.0
        # 480p image-to-video uses a smaller flow-matching shift.
        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
            args.sample_shift = 3.0

    # The default number of frames is 1 for text-to-image tasks and 81
    # (= 4*20 + 1, satisfying the 4n+1 rule) for the video tasks.
    # Previously both branches yielded 1, so videos defaulted to one frame.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"

    # A negative seed means "pick one at random".
    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)

    # Size check
    assert args.size in SUPPORTED_SIZES[
        args.task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"


def _parse_args():
    """Build the CLI parser, parse ``sys.argv`` and validate the result.

    Returns:
        argparse.Namespace: validated arguments with defaults filled in.
    """
    parser = argparse.ArgumentParser(
        description="Generate a image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample from a image or video. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")

    args = parser.parse_args()

    _validate_args(args)

    return args


def _init_logging():
    """Configure root logging at INFO level, writing to stdout."""
    # logging
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[logging.StreamHandler(stream=sys.stdout)])


def generate(args):
    """Run the requested Wan pipeline and save its output to disk.

    Text tasks (``t2v``/``t2i``) go through ``wan.WanT2V``; every other task
    goes through ``wan.WanI2V`` with an input image. Missing prompt/image
    fall back to ``EXAMPLE_PROMPT``. Missing ``save_file`` is derived from
    task, size, prompt and current time (``.png`` for t2i, ``.mp4`` otherwise).

    Args:
        args: namespace produced by ``_parse_args``; mutated in place
            (``offload_model``, ``prompt``, ``image``, ``save_file``).
    """
    _init_logging()

    # MLX uses default device automatically
    if args.offload_model is None:
        args.offload_model = True  # Default to True to save memory
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if "t2v" in args.task or "t2i" in args.task:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        logging.info(f"Input prompt: {args.prompt}")

        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
        )

        logging.info(
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
        video = wan_t2v.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    else:
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")

        img = Image.open(args.image).convert("RGB")

        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
        )

        logging.info("Generating video ...")
        # For I2V the output aspect ratio follows the input image; --size
        # only bounds the total pixel area.
        video = wan_i2v.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    # Save output
    if args.save_file is None:
        formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Keep the prompt-derived part of the filename short and free of
        # spaces and path separators.
        formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
        suffix = '.png' if "t2i" in args.task else '.mp4'
        args.save_file = f"{args.task}_{args.size}_{formatted_prompt}_{formatted_time}" + suffix

    if "t2i" in args.task:
        logging.info(f"Saving generated image to {args.save_file}")
        # Note: cache_image might need to handle MLX arrays
        # NOTE(review): squeeze(1) presumably drops the single-frame axis
        # and [None] adds a leading batch axis — confirm against cache_image.
        cache_image(
            tensor=video.squeeze(1)[None],
            save_file=args.save_file,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    else:
        logging.info(f"Saving generated video to {args.save_file}")
        cache_video(
            tensor=video[None],
            save_file=args.save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    logging.info("Finished.")


if __name__ == "__main__":
    args = _parse_args()
    generate(args)