Implement Wan2.2
video/Wan2.2/generate.py (new file, 388 lines)
@@ -0,0 +1,388 @@
# MLX implementation of generate.py
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime

warnings.filterwarnings('ignore')

import random
import mlx.core as mx
import mlx.nn as nn
from PIL import Image
import numpy as np

# Note: MLX doesn't have built-in distributed training support like PyTorch.
# For distributed training, you would need to implement custom logic or use MPI.
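# Hedged aside (ours, not part of the upstream port): recent MLX releases do
# ship an optional MPI-backed mx.distributed module. A minimal sketch, assuming
# mlx is installed and MPI is available, would be:
#     group = mx.distributed.init()   # size-1 group when launched without MPI
#     rank, world_size = group.rank(), group.size()
# This script intentionally stays single-device, so the sketch is illustrative only.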

import wan  # Assuming wan has been converted to MLX
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.utils import save_video, str2bool
# Fix: DashScopePromptExpander and QwenPromptExpander are used below but were
# never imported (assumes wan.utils.prompt_extend was ported with the package).
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "i2v-A14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
    "ti2v-5B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
}


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"

    if args.prompt is None:
        args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
    if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
        args.image = EXAMPLE_PROMPT[args.task]["image"]

    if args.task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    cfg = WAN_CONFIGS[args.task]

    if args.sample_steps is None:
        args.sample_steps = cfg.sample_steps

    if args.sample_shift is None:
        args.sample_shift = cfg.sample_shift

    if args.sample_guide_scale is None:
        args.sample_guide_scale = cfg.sample_guide_scale

    if args.frame_num is None:
        args.frame_num = cfg.frame_num
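    # Hedged check (ours, not upstream): the --frame_num help text requires
    # 4n+1 frames, so fail fast here rather than deep inside the pipeline.
    assert (args.frame_num - 1) % 4 == 0, \
        f"frame_num should be 4n+1, got {args.frame_num}"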

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
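    # Hedged note (ours): the wan pipelines are assumed to consume this seed,
    # e.g. via MLX's global PRNG:
    #     mx.random.seed(args.base_seed)
    # so a fixed --base_seed should make a run reproducible end to end.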
    # Size check
    assert args.size in SUPPORTED_SIZES[args.task], (
        f"Unsupported size {args.size} for task {args.task}, "
        f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate an image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-A14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video follows that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames of video to generate. The number should be 4n+1."
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each forward pass, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5. (Note: MLX doesn't have built-in FSDP.)")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place the T5 model on CPU. (Note: MLX runs on unified memory.)")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT. (Note: MLX doesn't have built-in FSDP.)")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extension.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extension method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extension model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="The target language of prompt extension.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=None,
        help="Classifier-free guidance scale.")
    parser.add_argument(
        "--convert_model_dtype",
        action="store_true",
        default=False,
        help="Whether to convert the model parameters' dtype.")

    args = parser.parse_args()

    _validate_args(args)

    return args
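
# Example invocation (illustrative; the checkpoint directory below is a
# placeholder, not a path shipped with this commit):
#   python generate.py --task t2v-A14B --size 1280*720 \
#       --ckpt_dir ./Wan2.2-T2V-A14B \
#       --prompt "Two anthropomorphic cats in comfy boxing gear fight on a spotlighted stage."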


def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def generate(args):
    # MLX doesn't have built-in distributed training like PyTorch.
    # For single-device execution, we'll simulate rank 0.
    rank = 0
    world_size = 1
    local_rank = 0

    # Check for distributed execution environment variables.
    # Note: an actual distributed implementation would require custom logic.
    if "RANK" in os.environ:
        logging.warning("MLX doesn't have built-in distributed training. Running on single device.")

    _init_logging(rank)

    if args.offload_model is None:
        args.offload_model = False
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
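    # Hedged note (ours): on Apple silicon MLX uses unified memory, so
    # "offloading" mainly means releasing weights between pipeline stages;
    # cached Metal buffers can also be dropped explicitly if needed, e.g.:
    #     mx.metal.clear_cache()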

    # MLX doesn't support FSDP or distributed training out of the box.
    if args.t5_fsdp or args.dit_fsdp:
        logging.warning("FSDP is not supported in MLX. Ignoring FSDP flags.")
        args.t5_fsdp = False
        args.dit_fsdp = False

    if args.ulysses_size > 1:
        logging.warning("Sequence parallel is not supported in MLX single-device mode. Setting ulysses_size to 1.")
        args.ulysses_size = 1

    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None,
                device="mlx")  # MLX uses unified memory
        else:
            raise NotImplementedError(
                f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    logging.info(f"Input prompt: {args.prompt}")
    img = None
    if args.image is not None:
        img = Image.open(args.image).convert("RGB")
        logging.info(f"Input image: {args.image}")

    # prompt extend
    if args.use_prompt_extend:
        logging.info("Extending prompt ...")
        prompt_output = prompt_expander(
            args.prompt,
            image=img,
            tar_lang=args.prompt_extend_target_lang,
            seed=args.base_seed)
        if not prompt_output.status:
            logging.info(
                f"Extending prompt failed: {prompt_output.message}")
            logging.info("Falling back to original prompt.")
            input_prompt = args.prompt
        else:
            input_prompt = prompt_output.prompt
        args.prompt = input_prompt
        logging.info(f"Extended prompt: {args.prompt}")

    if "t2v" in args.task:
        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_t2v.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    elif "ti2v" in args.task:
        logging.info("Creating WanTI2V pipeline.")
        wan_ti2v = wan.WanTI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=None,  # MLX uses unified memory
            rank=rank,
            t5_fsdp=False,  # Not supported in MLX
            dit_fsdp=False,  # Not supported in MLX
            use_sp=False,  # Not supported in MLX
            t5_cpu=False,  # MLX uses unified memory
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_ti2v.generate(
            args.prompt,
            img=img,
            size=SIZE_CONFIGS[args.size],
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    else:
        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=None,  # MLX uses unified memory
            rank=rank,
            t5_fsdp=False,  # Not supported in MLX
            dit_fsdp=False,  # Not supported in MLX
            use_sp=False,  # Not supported in MLX
            t5_cpu=False,  # MLX uses unified memory
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_i2v.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    if args.save_file is None:
        formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
        suffix = '.mp4'
        args.save_file = f"{args.task}_{args.size.replace('*', 'x') if sys.platform == 'win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix

    logging.info(f"Saving generated video to {args.save_file}")

    # Don't convert to numpy - keep as MLX array
    save_video(
        tensor=video[None],  # Just add batch dimension, keep as MLX array
        save_file=args.save_file,
        fps=cfg.sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1)
    )
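    # Hedged note (ours): if the ported save_video still expects a NumPy
    # array, the MLX array can be materialized first, e.g.
    #     save_video(tensor=np.array(video)[None], ...)
    # The call above assumes save_video was updated to accept mx.array input.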

    logging.info("Finished.")


if __name__ == "__main__":
    args = _parse_args()
    generate(args)