# 2025-07-29 06:51:11 +08:00  (stray VCS timestamp from a bad merge/paste — not code)
import argparse
import logging
import os
import random
import sys
from datetime import datetime

import mlx.core as mx
from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
# 2025-07-29 08:07:26 +08:00  (stray VCS timestamp from a bad merge/paste — not code)
from wan.utils.utils import cache_video, cache_image, str2bool
# 2025-07-29 06:51:11 +08:00  (stray VCS timestamp from a bad merge/paste — not code)
# Default prompts (and input image for i2v) used when the user supplies none.
# Both t2v variants intentionally share the same example prompt.
_T2V_EXAMPLE_PROMPT = (
    "Two anthropomorphic cats in comfy boxing gear and bright gloves fight "
    "intensely on a spotlighted stage.")

EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt": _T2V_EXAMPLE_PROMPT,
    },
    "t2v-14B": {
        "prompt": _T2V_EXAMPLE_PROMPT,
    },
    "t2i-14B": {
        "prompt": "一个朴素端庄的美人",
    },
    "i2v-14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on "
            "a surfboard. The fluffy-furred feline gazes directly at the camera "
            "with a relaxed expression. Blurred beach scenery forms the "
            "background featuring crystal-clear waters, distant green hills, and "
            "a blue sky dotted with white clouds. The cat assumes a naturally "
            "relaxed posture, as if savoring the sea breeze and warm sunlight. A "
            "close-up shot highlights the feline's intricate details and the "
            "refreshing atmosphere of the seaside.",
        "image": "examples/i2v_input.JPG",
    },
}
def _validate_args(args):
    """Fill in unspecified defaults and sanity-check the parsed CLI args in place.

    Mutates ``args`` (sample_steps, sample_shift, frame_num, base_seed).

    Raises:
        AssertionError: if the checkpoint directory is missing, the task or
            size is unsupported, or frame_num != 1 for a t2i task.
            NOTE(review): ``assert`` is stripped under ``python -O``; kept to
            match the project's existing validation style.
    """
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for
    # text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 40 if "i2v" in args.task else 50

    if args.sample_shift is None:
        args.sample_shift = 5.0
        # Lower flow-matching shift for the smaller 480p i2v resolutions.
        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
            args.sample_shift = 3.0

    # The default number of frames is 1 for text-to-image tasks and 81 for
    # other tasks. (BUGFIX: was `1 if "t2i" in args.task else 1`, which made
    # every video task generate a single frame.)
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check: an image is exactly one frame.
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"

    # A negative seed means "pick one at random".
    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)

    # Size check
    assert args.size in SUPPORTED_SIZES[
        args.
        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
def _parse_args():
    """Build the CLI parser, parse sys.argv, validate the result, return it."""
    p = argparse.ArgumentParser(
        description="Generate a image or video from a text prompt or image using Wan")

    # Task / geometry options.
    p.add_argument(
        "--task",
        type=str,
        default="t2v-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    p.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.")
    p.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample from a image or video. The number should be 4n+1")

    # Model / I/O options.
    p.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    p.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.")
    p.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    p.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    p.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    p.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")

    # Sampler options.
    p.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    p.add_argument(
        "--sample_steps",
        type=int,
        default=None,
        help="The sampling steps.")
    p.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    p.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")

    parsed = p.parse_args()
    _validate_args(parsed)
    return parsed
def _init_logging():
    """Configure root logging: INFO level, timestamped records, to stdout."""
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[stdout_handler],
    )
def generate(args):
    """Run the requested Wan task end to end: build the pipeline, sample, save.

    Expects ``args`` to have been produced by ``_parse_args`` (defaults already
    filled in by ``_validate_args``).
    """
    _init_logging()

    # MLX uses its default device automatically; no explicit device setup.
    if args.offload_model is None:
        args.offload_model = True  # Default to True to save memory
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    cfg = WAN_CONFIGS[args.task]
    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if "t2v" in args.task or "t2i" in args.task:
        # Text-driven tasks (t2v and t2i) share the WanT2V pipeline.
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        logging.info(f"Input prompt: {args.prompt}")

        logging.info("Creating WanT2V pipeline.")
        pipeline = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
        )

        logging.info(
            f"Generating {'image' if 't2i' in args.task else 'video'} ...")
        video = pipeline.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    else:
        # Image-to-video path: needs both a prompt and a conditioning image.
        if args.prompt is None:
            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
        if args.image is None:
            args.image = EXAMPLE_PROMPT[args.task]["image"]
        logging.info(f"Input prompt: {args.prompt}")
        logging.info(f"Input image: {args.image}")

        img = Image.open(args.image).convert("RGB")

        logging.info("Creating WanI2V pipeline.")
        pipeline = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
        )

        logging.info("Generating video ...")
        video = pipeline.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    # Save output. Build a default filename from task/size/prompt/timestamp
    # when the user did not supply one.
    if args.save_file is None:
        formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
        suffix = '.png' if "t2i" in args.task else '.mp4'
        args.save_file = f"{args.task}_{args.size}_{formatted_prompt}_{formatted_time}" + suffix

    if "t2i" in args.task:
        logging.info(f"Saving generated image to {args.save_file}")
        # Note: cache_image might need to handle MLX arrays
        cache_image(
            tensor=video.squeeze(1)[None],
            save_file=args.save_file,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    else:
        logging.info(f"Saving generated video to {args.save_file}")
        cache_video(
            tensor=video[None],
            save_file=args.save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

    logging.info("Finished.")
if __name__ == "__main__":
    # CLI entry point: parse and validate arguments, then run generation.
    generate(_parse_args())