mlx-examples/video/Wan2.2/generate.py

388 lines
13 KiB
Python
Raw Normal View History

2025-07-31 17:30:20 +08:00
# MLX implementation of generate.py
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import mlx.core as mx
import mlx.nn as nn
from PIL import Image
import numpy as np
# Note: MLX doesn't have built-in distributed training support like PyTorch
# For distributed training, you would need to implement custom logic or use MPI
import wan # Assuming wan has been converted to MLX
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.utils import save_video, str2bool
EXAMPLE_PROMPT = {
"t2v-A14B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"i2v-A14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
"ti2v-5B": {
"prompt":
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
}
def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
args.image = EXAMPLE_PROMPT[args.task]["image"]
if args.task == "i2v-A14B":
assert args.image is not None, "Please specify the image path for i2v."
cfg = WAN_CONFIGS[args.task]
if args.sample_steps is None:
args.sample_steps = cfg.sample_steps
if args.sample_shift is None:
args.sample_shift = cfg.sample_shift
if args.sample_guide_scale is None:
args.sample_guide_scale = cfg.sample_guide_scale
if args.frame_num is None:
args.frame_num = cfg.frame_num
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
assert args.size in SUPPORTED_SIZES[
args.
task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="t2v-A14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=None,
help="How many frames of video are generated. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--t5_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for T5. (Note: MLX doesn't have built-in FSDP)")
parser.add_argument(
"--t5_cpu",
action="store_true",
default=False,
help="Whether to place T5 model on CPU. (Note: MLX runs on unified memory)")
parser.add_argument(
"--dit_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for DiT. (Note: MLX doesn't have built-in FSDP)")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the video from.")
parser.add_argument(
"--use_prompt_extend",
action="store_true",
default=False,
help="Whether to use prompt extend.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
parser.add_argument(
"--prompt_extend_target_lang",
type=str,
default="zh",
choices=["zh", "en"],
help="The target language of prompt extend.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=None,
help="Classifier free guidance scale.")
parser.add_argument(
"--convert_model_dtype",
action="store_true",
default=False,
help="Whether to convert model parameters dtype.")
args = parser.parse_args()
_validate_args(args)
return args
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def generate(args):
# MLX doesn't have built-in distributed training like PyTorch
# For single-device execution, we'll simulate rank 0
rank = 0
world_size = 1
local_rank = 0
# Check for distributed execution environment variables
# Note: Actual distributed implementation would require custom logic
if "RANK" in os.environ:
logging.warning("MLX doesn't have built-in distributed training. Running on single device.")
_init_logging(rank)
if args.offload_model is None:
args.offload_model = False
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
# MLX doesn't support FSDP or distributed training out of the box
if args.t5_fsdp or args.dit_fsdp:
logging.warning("FSDP is not supported in MLX. Ignoring FSDP flags.")
args.t5_fsdp = False
args.dit_fsdp = False
if args.ulysses_size > 1:
logging.warning("Sequence parallel is not supported in MLX single-device mode. Setting ulysses_size to 1.")
args.ulysses_size = 1
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model,
task=args.task,
is_vl=args.image is not None)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
task=args.task,
is_vl=args.image is not None,
device="mlx") # MLX uses unified memory
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
cfg = WAN_CONFIGS[args.task]
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
logging.info(f"Input prompt: {args.prompt}")
img = None
if args.image is not None:
img = Image.open(args.image).convert("RGB")
logging.info(f"Input image: {args.image}")
# prompt extend
if args.use_prompt_extend:
logging.info("Extending prompt ...")
prompt_output = prompt_expander(
args.prompt,
image=img,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
args.prompt = input_prompt
logging.info(f"Extended prompt: {args.prompt}")
if "t2v" in args.task:
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
convert_model_dtype=args.convert_model_dtype,
)
logging.info(f"Generating video ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
elif "ti2v" in args.task:
logging.info("Creating WanTI2V pipeline.")
wan_ti2v = wan.WanTI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=None, # MLX uses unified memory
rank=rank,
t5_fsdp=False, # Not supported in MLX
dit_fsdp=False, # Not supported in MLX
use_sp=False, # Not supported in MLX
t5_cpu=False, # MLX uses unified memory
convert_model_dtype=args.convert_model_dtype,
)
logging.info(f"Generating video ...")
video = wan_ti2v.generate(
args.prompt,
img=img,
size=SIZE_CONFIGS[args.size],
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=None, # MLX uses unified memory
rank=rank,
t5_fsdp=False, # Not supported in MLX
dit_fsdp=False, # Not supported in MLX
use_sp=False, # Not supported in MLX
t5_cpu=False, # MLX uses unified memory
convert_model_dtype=args.convert_model_dtype,
)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
"_")[:50]
suffix = '.mp4'
args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix
logging.info(f"Saving generated video to {args.save_file}")
# Don't convert to numpy - keep as MLX array
save_video(
tensor=video[None], # Just add batch dimension, keep as MLX array
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1)
)
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)