mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Implement Wan2.2
This commit is contained in:
39
video/Wan2.2/wan/configs/__init__.py
Normal file
39
video/Wan2.2/wan/configs/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import copy
|
||||
import os
|
||||
|
||||
# Set TOKENIZERS_PARALLELISM to 'false' in the process environment as an
# import-time side effect, before the config submodules below are imported.
# NOTE(review): presumably meant to turn off Hugging Face tokenizers'
# parallelism (and its fork warning) — confirm against the tokenizer in use.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
|
||||
from .wan_i2v_A14B import i2v_A14B
|
||||
from .wan_t2v_A14B import t2v_A14B
|
||||
from .wan_ti2v_5B import ti2v_5B
|
||||
|
||||
# Registry of the available model configs, keyed by task name.
# The same task names are used as keys of SUPPORTED_SIZES.
WAN_CONFIGS = {
    't2v-A14B': t2v_A14B,
    'i2v-A14B': i2v_A14B,
    'ti2v-5B': ti2v_5B,
}
|
||||
|
||||
# Maps a size string of the form 'A*B' to the integer pair (A, B).
# NOTE(review): presumably A is the width and B the height — confirm with the
# code that consumes these tuples.
SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '704*1280': (704, 1280),
    '1280*704': (1280, 704),
}
|
||||
|
||||
# Maps each size string 'A*B' to the integer product A * B.
# The keys are the same size strings used by SIZE_CONFIGS.
MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
    '704*1280': 704 * 1280,
    '1280*704': 1280 * 704,
}
|
||||
|
||||
# Size strings listed for each task name. The task names match the keys of
# WAN_CONFIGS and the size strings match the keys of SIZE_CONFIGS.
SUPPORTED_SIZES = {
    't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'ti2v-5B': ('704*1280', '1280*704'),
}
|
||||
20
video/Wan2.2/wan/configs/shared_config.py
Normal file
20
video/Wan2.2/wan/configs/shared_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
from easydict import EasyDict
|
||||
|
||||
#------------------------ Wan shared config ------------------------#
|
||||
# Settings common to every Wan task config. The per-task config modules copy
# these entries in with ``update(wan_shared_cfg)`` and then add or override
# individual keys.
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Negative prompt text is a runtime string; keep it byte-for-byte.
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
wan_shared_cfg.frame_num = 81
|
||||
37
video/Wan2.2/wan/configs/wan_i2v_A14B.py
Normal file
37
video/Wan2.2/wan/configs/wan_i2v_A14B.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
from easydict import EasyDict
|
||||
|
||||
from .shared_config import wan_shared_cfg
|
||||
|
||||
#------------------------ Wan I2V A14B ------------------------#
|
||||
|
||||
# Config for the 'i2v-A14B' task: start from the shared settings, then add the
# task-specific entries below.
i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)

# t5
# NOTE(review): the checkpoints here use '.pth' while the T2V A14B config
# references '.safetensors' files — confirm which format the loader expects.
i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_A14B.vae_stride = (4, 8, 8)

# transformer
i2v_A14B.patch_size = (1, 2, 2)
i2v_A14B.dim = 5120
i2v_A14B.ffn_dim = 13824
i2v_A14B.freq_dim = 256
i2v_A14B.num_heads = 40
i2v_A14B.num_layers = 40
i2v_A14B.window_size = (-1, -1)
i2v_A14B.qk_norm = True
i2v_A14B.cross_attn_norm = True
i2v_A14B.eps = 1e-6
i2v_A14B.low_noise_checkpoint = 'low_noise_model'
i2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
i2v_A14B.sample_shift = 5.0
i2v_A14B.sample_steps = 40
i2v_A14B.boundary = 0.900
i2v_A14B.sample_guide_scale = (3.5, 3.5)  # low noise, high noise
|
||||
37
video/Wan2.2/wan/configs/wan_t2v_A14B.py
Normal file
37
video/Wan2.2/wan/configs/wan_t2v_A14B.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from easydict import EasyDict
|
||||
|
||||
from .shared_config import wan_shared_cfg
|
||||
|
||||
#------------------------ Wan T2V A14B ------------------------#
|
||||
|
||||
# Config for the 't2v-A14B' task: start from the shared settings, then add the
# task-specific entries below.
t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
t2v_A14B.update(wan_shared_cfg)

# t5
t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.safetensors'
t2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.safetensors'
t2v_A14B.vae_stride = (4, 8, 8)

# transformer
t2v_A14B.patch_size = (1, 2, 2)
t2v_A14B.dim = 5120
t2v_A14B.ffn_dim = 13824
t2v_A14B.freq_dim = 256
t2v_A14B.num_heads = 40
t2v_A14B.num_layers = 40
t2v_A14B.window_size = (-1, -1)
t2v_A14B.qk_norm = True
t2v_A14B.cross_attn_norm = True
t2v_A14B.eps = 1e-6
t2v_A14B.low_noise_checkpoint = 'low_noise_model'
t2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
t2v_A14B.sample_shift = 12.0
t2v_A14B.sample_steps = 40
t2v_A14B.boundary = 0.875
t2v_A14B.sample_guide_scale = (3.0, 4.0)  # low noise, high noise
|
||||
36
video/Wan2.2/wan/configs/wan_ti2v_5B.py
Normal file
36
video/Wan2.2/wan/configs/wan_ti2v_5B.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from easydict import EasyDict
|
||||
|
||||
from .shared_config import wan_shared_cfg
|
||||
|
||||
#------------------------ Wan TI2V 5B ------------------------#
|
||||
|
||||
# Config for the 'ti2v-5B' task: start from the shared settings, then add the
# task-specific entries below. Unlike the A14B configs, this one overrides the
# shared sample_fps and frame_num values and uses a single guide scale.
ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
ti2v_5B.update(wan_shared_cfg)

# t5
# NOTE(review): the checkpoints here use '.pth' while the T2V A14B config
# references '.safetensors' files — confirm which format the loader expects.
ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
ti2v_5B.t5_tokenizer = 'google/umt5-xxl'

# vae
ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
ti2v_5B.vae_stride = (4, 16, 16)

# transformer
ti2v_5B.patch_size = (1, 2, 2)
ti2v_5B.dim = 3072
ti2v_5B.ffn_dim = 14336
ti2v_5B.freq_dim = 256
ti2v_5B.num_heads = 24
ti2v_5B.num_layers = 30
ti2v_5B.window_size = (-1, -1)
ti2v_5B.qk_norm = True
ti2v_5B.cross_attn_norm = True
ti2v_5B.eps = 1e-6

# inference
ti2v_5B.sample_fps = 24  # overrides the shared value of 16
ti2v_5B.sample_shift = 5.0
ti2v_5B.sample_steps = 50
ti2v_5B.sample_guide_scale = 5.0
ti2v_5B.frame_num = 121  # overrides the shared value of 81
|
||||
Reference in New Issue
Block a user