Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Implement Wan2.2
2
video/Wan2.2/wan/__init__.py
Normal file
@@ -0,0 +1,2 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .text2video import WanT2V
39
video/Wan2.2/wan/configs/__init__.py
Normal file
@@ -0,0 +1,39 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_A14B import i2v_A14B
from .wan_t2v_A14B import t2v_A14B
from .wan_ti2v_5B import ti2v_5B

WAN_CONFIGS = {
    't2v-A14B': t2v_A14B,
    'i2v-A14B': i2v_A14B,
    'ti2v-5B': ti2v_5B,
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '704*1280': (704, 1280),
    '1280*704': (1280, 704)
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
    '704*1280': 704 * 1280,
    '1280*704': 1280 * 704,
}

SUPPORTED_SIZES = {
    't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'ti2v-5B': ('704*1280', '1280*704'),
}
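These registries are what a generation entry point is expected to key off; a minimal lookup sketch (the helper name and argument defaults are illustrative, not part of this commit):

from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS

def resolve_task(task='ti2v-5B', size='1280*704'):
    # Hypothetical helper: validate the requested size, then fetch the model config.
    if size not in SUPPORTED_SIZES[task]:
        raise ValueError(f'{size} is not supported for {task}, choose from {SUPPORTED_SIZES[task]}')
    cfg = WAN_CONFIGS[task]
    target_size = SIZE_CONFIGS[size]     # tuple layout follows SIZE_CONFIGS above
    max_area = MAX_AREA_CONFIGS[size]    # area budget associated with this size
    return cfg, target_size, max_area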
20
video/Wan2.2/wan/configs/shared_config.py
Normal file
@@ -0,0 +1,20 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'  # Chinese quality-control negative prompt: garish colors, overexposure, static/blurred detail, subtitles, worst quality, JPEG artifacts, extra or fused fingers, deformed limbs, cluttered background, etc.
wan_shared_cfg.frame_num = 81
37
video/Wan2.2/wan/configs/wan_i2v_A14B.py
Normal file
@@ -0,0 +1,37 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V A14B ------------------------#

i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)

i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_A14B.vae_stride = (4, 8, 8)

# transformer
i2v_A14B.patch_size = (1, 2, 2)
i2v_A14B.dim = 5120
i2v_A14B.ffn_dim = 13824
i2v_A14B.freq_dim = 256
i2v_A14B.num_heads = 40
i2v_A14B.num_layers = 40
i2v_A14B.window_size = (-1, -1)
i2v_A14B.qk_norm = True
i2v_A14B.cross_attn_norm = True
i2v_A14B.eps = 1e-6
i2v_A14B.low_noise_checkpoint = 'low_noise_model'
i2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
i2v_A14B.sample_shift = 5.0
i2v_A14B.sample_steps = 40
i2v_A14B.boundary = 0.900
i2v_A14B.sample_guide_scale = (3.5, 3.5)  # low noise, high noise
37
video/Wan2.2/wan/configs/wan_t2v_A14B.py
Normal file
@@ -0,0 +1,37 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V A14B ------------------------#

t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
t2v_A14B.update(wan_shared_cfg)

# t5
t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.safetensors'
t2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.safetensors'
t2v_A14B.vae_stride = (4, 8, 8)

# transformer
t2v_A14B.patch_size = (1, 2, 2)
t2v_A14B.dim = 5120
t2v_A14B.ffn_dim = 13824
t2v_A14B.freq_dim = 256
t2v_A14B.num_heads = 40
t2v_A14B.num_layers = 40
t2v_A14B.window_size = (-1, -1)
t2v_A14B.qk_norm = True
t2v_A14B.cross_attn_norm = True
t2v_A14B.eps = 1e-6
t2v_A14B.low_noise_checkpoint = 'low_noise_model'
t2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
t2v_A14B.sample_shift = 12.0
t2v_A14B.sample_steps = 40
t2v_A14B.boundary = 0.875
t2v_A14B.sample_guide_scale = (3.0, 4.0)  # low noise, high noise
36
video/Wan2.2/wan/configs/wan_ti2v_5B.py
Normal file
@@ -0,0 +1,36 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan TI2V 5B ------------------------#

ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
ti2v_5B.update(wan_shared_cfg)

# t5
ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
ti2v_5B.t5_tokenizer = 'google/umt5-xxl'

# vae
ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
ti2v_5B.vae_stride = (4, 16, 16)

# transformer
ti2v_5B.patch_size = (1, 2, 2)
ti2v_5B.dim = 3072
ti2v_5B.ffn_dim = 14336
ti2v_5B.freq_dim = 256
ti2v_5B.num_heads = 24
ti2v_5B.num_layers = 30
ti2v_5B.window_size = (-1, -1)
ti2v_5B.qk_norm = True
ti2v_5B.cross_attn_norm = True
ti2v_5B.eps = 1e-6

# inference
ti2v_5B.sample_fps = 24
ti2v_5B.sample_shift = 5.0
ti2v_5B.sample_steps = 50
ti2v_5B.sample_guide_scale = 5.0
ti2v_5B.frame_num = 121
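For a sense of scale, the settings above fix the transformer's sequence length. A back-of-the-envelope helper, assuming the usual Wan convention that the VAE keeps the first frame and strides the remaining ones temporally (illustrative only, not part of the commit):

def ti2v_5b_token_count(frame_num=121, height=704, width=1280,
                        vae_stride=(4, 16, 16), patch_size=(1, 2, 2)):
    lat_f = (frame_num - 1) // vae_stride[0] + 1   # (121 - 1) / 4 + 1 = 31 latent frames
    lat_h = height // vae_stride[1]                # 704 / 16  = 44
    lat_w = width // vae_stride[2]                 # 1280 / 16 = 80
    # Patchifying with (1, 2, 2) gives 31 * 22 * 40 = 27,280 tokens per video
    return (lat_f // patch_size[0]) * (lat_h // patch_size[1]) * (lat_w // patch_size[2])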
17
video/Wan2.2/wan/modules/__init__.py
Normal file
@@ -0,0 +1,17 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_1 import Wan2_1_VAE

__all__ = [
    'Wan2_1_VAE',
    'Wan2_2_VAE',
    'WanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'mlx_attention',
]
660
video/Wan2.2/wan/modules/model.py
Normal file
@@ -0,0 +1,660 @@
|
||||
# MLX implementation of model.py
|
||||
import math
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['WanModel']
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
# preprocess
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
position = position.astype(mx.float32)
|
||||
|
||||
# calculation
|
||||
arange_vals = mx.arange(half).astype(mx.float32)
|
||||
div_term = mx.power(10000, -arange_vals / half)
|
||||
sinusoid = position[:, None] @ div_term[None, :]
|
||||
x = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
|
||||
return x
|
||||
|
||||
|
||||
def rope_params(max_seq_len, dim, theta=10000):
|
||||
assert dim % 2 == 0
|
||||
positions = mx.arange(max_seq_len).astype(mx.float32)
|
||||
freqs = mx.arange(0, dim, 2).astype(mx.float32) / dim
|
||||
freqs = 1.0 / mx.power(theta, freqs)
|
||||
angles = positions[:, None] @ freqs[None, :]
|
||||
# Store as [max_seq_len, dim//2, 2] where last dimension is [real, imag]
|
||||
freqs_complex = mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1)
|
||||
return freqs_complex
|
||||
|
||||
|
||||
def rope_apply(x, grid_sizes, freqs):
|
||||
n, c = x.shape[2], x.shape[3] // 2
|
||||
|
||||
# split freqs based on dimension allocation
|
||||
split_sizes = [c - 2 * (c // 3), c // 3, c // 3]
|
||||
freqs_splits = []
|
||||
start = 0
|
||||
for size in split_sizes:
|
||||
freqs_splits.append(freqs[:, start:start+size, :])
|
||||
start += size
|
||||
|
||||
# loop over samples
|
||||
output = []
|
||||
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
||||
seq_len = f * h * w
|
||||
|
||||
# reshape x_i to complex representation
|
||||
x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
|
||||
|
||||
# precompute frequency multipliers for each dimension
|
||||
freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
|
||||
freqs_f = mx.tile(freqs_f, (1, h, w, 1, 1)).reshape(f * h * w, -1, 2)
|
||||
|
||||
freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
|
||||
freqs_h = mx.tile(freqs_h, (f, 1, w, 1, 1)).reshape(f * h * w, -1, 2)
|
||||
|
||||
freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
|
||||
freqs_w = mx.tile(freqs_w, (f, h, 1, 1, 1)).reshape(f * h * w, -1, 2)
|
||||
|
||||
# Concatenate frequency components
|
||||
freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=1)
|
||||
freqs_i = freqs_i[:seq_len].reshape(seq_len, 1, c, 2)
|
||||
|
||||
# apply rotary embedding (complex multiplication)
|
||||
x_real = x_i[..., 0]
|
||||
x_imag = x_i[..., 1]
|
||||
freqs_real = freqs_i[..., 0]
|
||||
freqs_imag = freqs_i[..., 1]
|
||||
|
||||
out_real = x_real * freqs_real - x_imag * freqs_imag
|
||||
out_imag = x_real * freqs_imag + x_imag * freqs_real
|
||||
|
||||
x_i = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, -1)
|
||||
|
||||
# Handle remaining sequence
|
||||
if x.shape[1] > seq_len:
|
||||
x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)
|
||||
|
||||
output.append(x_i)
|
||||
|
||||
return mx.stack(output)
|
||||
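Because the port avoids complex dtypes, rope_params stores [cos, sin] pairs and rope_apply multiplies them out by hand. A standalone sanity check of that convention (toy shapes; assumes rope_params above is in scope):

import mlx.core as mx

freqs = rope_params(max_seq_len=16, dim=8)   # [16, 4, 2] -> (cos, sin) per frequency
x = mx.random.normal((16, 4, 2))             # fake (real, imag) feature pairs
cos, sin = freqs[..., 0], freqs[..., 1]
out_real = x[..., 0] * cos - x[..., 1] * sin
out_imag = x[..., 0] * sin + x[..., 1] * cos
# A rotation must preserve the magnitude of every (real, imag) pair
norm_in = mx.sqrt(x[..., 0] ** 2 + x[..., 1] ** 2)
norm_out = mx.sqrt(out_real ** 2 + out_imag ** 2)
assert mx.allclose(norm_in, norm_out, atol=1e-5)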
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L, C]
|
||||
"""
|
||||
return self._norm(x) * self.weight
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.LayerNorm):
|
||||
|
||||
def __init__(self, dim, eps=1e-6, affine=False):
|
||||
super().__init__(dims=dim, eps=eps, affine=affine)
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L, C]
|
||||
"""
|
||||
return super().__call__(x)
|
||||
|
||||
|
||||
def mlx_attention(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
v: mx.array,
|
||||
q_lens: Optional[mx.array] = None,
|
||||
k_lens: Optional[mx.array] = None,
|
||||
dropout_p: float = 0.,
|
||||
softmax_scale: Optional[float] = None,
|
||||
q_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
deterministic: bool = False,
|
||||
dtype: Optional[type] = None,
|
||||
) -> mx.array:
|
||||
# Get shapes
|
||||
b, lq, n, d = q.shape
|
||||
_, lk, _, _ = k.shape
|
||||
|
||||
# Scale queries if needed
|
||||
if q_scale is not None:
|
||||
q = q * q_scale
|
||||
|
||||
# Compute attention scores
|
||||
q = q.transpose(0, 2, 1, 3) # [b, n, lq, d]
|
||||
k = k.transpose(0, 2, 1, 3) # [b, n, lk, d]
|
||||
v = v.transpose(0, 2, 1, 3) # [b, n, lk, d]
|
||||
|
||||
# Compute attention scores
|
||||
scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) # [b, n, lq, lk]
|
||||
|
||||
# Apply softmax scale if provided
|
||||
if softmax_scale is not None:
|
||||
scores = scores * softmax_scale
|
||||
else:
|
||||
# Default scaling by sqrt(d)
|
||||
scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = None
|
||||
|
||||
# Apply window size masking if specified
|
||||
if window_size != (-1, -1):
|
||||
left_window, right_window = window_size
|
||||
window_mask = mx.zeros((lq, lk))
|
||||
for i in range(lq):
|
||||
start = max(0, i - left_window)
|
||||
end = min(lk, i + right_window + 1)
|
||||
window_mask[i, start:end] = 1
|
||||
attn_mask = window_mask
|
||||
|
||||
# Apply causal masking if needed
|
||||
if causal:
|
||||
causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
|
||||
if attn_mask is None:
|
||||
attn_mask = causal_mask
|
||||
else:
|
||||
attn_mask = mx.logical_and(attn_mask, causal_mask)
|
||||
|
||||
# Apply attention mask if present
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.astype(scores.dtype)
|
||||
scores = scores * attn_mask + (1 - attn_mask) * -1e4
|
||||
|
||||
# Apply attention mask if lengths are provided
|
||||
if q_lens is not None or k_lens is not None:
|
||||
if q_lens is not None:
|
||||
mask = mx.arange(lq)[None, :] < q_lens[:, None]
|
||||
mask = mask.astype(scores.dtype)
|
||||
scores = scores * mask[:, None, :, None] + (1 - mask[:, None, :, None]) * -1e4
|
||||
if k_lens is not None:
|
||||
mask = mx.arange(lk)[None, :] < k_lens[:, None]
|
||||
mask = mask.astype(scores.dtype)
|
||||
scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4
|
||||
|
||||
# Apply softmax
|
||||
max_scores = mx.max(scores, axis=-1, keepdims=True)
|
||||
scores = scores - max_scores
|
||||
exp_scores = mx.exp(scores)
|
||||
sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
|
||||
attn = exp_scores / (sum_exp + 1e-6)
|
||||
|
||||
# Apply dropout if needed
|
||||
if dropout_p > 0 and not deterministic:
|
||||
raise NotImplementedError("Dropout not implemented in MLX version")
|
||||
|
||||
# Compute output
|
||||
out = mx.matmul(attn, v) # [b, n, lq, d]
|
||||
out = out.transpose(0, 2, 1, 3) # [b, lq, n, d]
|
||||
|
||||
return out
|
||||
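mlx_attention materialises the full score matrix and applies masking with a large negative bias rather than using a fused kernel, which keeps the variable-length handling explicit. A shape-level usage sketch with made-up sizes:

import mlx.core as mx

q = mx.random.normal((2, 16, 4, 32))   # [batch, query_len, heads, head_dim]
k = mx.random.normal((2, 20, 4, 32))
v = mx.random.normal((2, 20, 4, 32))
k_lens = mx.array([20, 12])            # second sample only has 12 valid keys
out = mlx_attention(q, k, v, k_lens=k_lens)
assert out.shape == (2, 16, 4, 32)     # output stays in [batch, query_len, heads, head_dim]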
|
||||
class WanSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def __call__(self, x, seq_lens, grid_sizes, freqs):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L, C]
|
||||
seq_lens(Array): Shape [B]
|
||||
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
|
||||
"""
|
||||
b, s, n, d = x.shape[0], x.shape[1], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
q = self.norm_q(self.q(x)).reshape(b, s, n, d)
|
||||
k = self.norm_k(self.k(x)).reshape(b, s, n, d)
|
||||
v = self.v(x).reshape(b, s, n, d)
|
||||
|
||||
x = mlx_attention(
|
||||
q=rope_apply(q, grid_sizes, freqs),
|
||||
k=rope_apply(k, grid_sizes, freqs),
|
||||
v=v,
|
||||
k_lens=seq_lens,
|
||||
window_size=self.window_size)
|
||||
|
||||
# output
|
||||
x = x.reshape(b, s, -1)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanCrossAttention(WanSelfAttention):
|
||||
|
||||
def __call__(self, x, context, context_lens):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L1, C]
|
||||
context(Array): Shape [B, L2, C]
|
||||
context_lens(Array): Shape [B]
|
||||
"""
|
||||
b, n, d = x.shape[0], self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.norm_q(self.q(x)).reshape(b, -1, n, d)
|
||||
k = self.norm_k(self.k(context)).reshape(b, -1, n, d)
|
||||
v = self.v(context).reshape(b, -1, n, d)
|
||||
|
||||
# compute attention
|
||||
x = mlx_attention(q, k, v, k_lens=context_lens)
|
||||
|
||||
# output
|
||||
x = x.reshape(b, -1, self.dim)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
|
||||
eps)
|
||||
self.norm3 = WanLayerNorm(
|
||||
dim, eps,
|
||||
affine=True) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
|
||||
eps)
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(dim, ffn_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(ffn_dim, dim))
|
||||
|
||||
# modulation
|
||||
self.modulation = mx.random.normal((1, 6, dim)) / dim**0.5
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
context,
|
||||
context_lens,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L, C]
|
||||
e(Array): Shape [B, L1, 6, C]
|
||||
seq_lens(Array): Shape [B], length of each sequence in batch
|
||||
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
|
||||
"""
|
||||
e = mx.split(self.modulation + e, 6, axis=2)
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2),
|
||||
seq_lens, grid_sizes, freqs)
|
||||
x = x + y * mx.squeeze(e[2], axis=2)
|
||||
|
||||
# cross-attention & ffn function
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
||||
y = self.ffn(
|
||||
self.norm2(x) * (1 + mx.squeeze(e[4], axis=2)) + mx.squeeze(e[3], axis=2))
|
||||
x = x + y * mx.squeeze(e[5], axis=2)
|
||||
|
||||
return x
|
||||
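The six-way split above is AdaLN-style modulation: a shift, scale, and gate for the self-attention branch, followed by a shift, scale, and gate for the FFN branch. A small reading aid with toy shapes (the names are mine, not the commit's):

import mlx.core as mx

dim = 8
modulation = mx.random.normal((1, 6, dim)) / dim ** 0.5
e = mx.random.normal((1, 4, 6, dim))           # [B, L, 6, C] projected timestep signal
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
    mx.squeeze(t, axis=2) for t in mx.split(modulation + e, 6, axis=2)
]
# The block then computes, per token:
#   x <- x + gate_msa * SelfAttn(norm1(x) * (1 + scale_msa) + shift_msa)
#   x <- x + gate_mlp * FFN(norm2(x) * (1 + scale_mlp) + shift_mlp)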
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
out_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, out_dim)
|
||||
|
||||
# modulation
|
||||
self.modulation = mx.random.normal((1, 2, dim)) / dim**0.5
|
||||
|
||||
def __call__(self, x, e):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L1, C]
|
||||
e(Array): Shape [B, L1, C]
|
||||
"""
|
||||
e = mx.split(self.modulation + mx.expand_dims(e, axis=2), 2, axis=2)
|
||||
x = self.head(
|
||||
self.norm(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2))
|
||||
return x
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6):
|
||||
"""
|
||||
Initialize the diffusion model backbone.
|
||||
|
||||
Args:
|
||||
model_type (`str`, *optional*, defaults to 't2v'):
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
||||
text_len (`int`, *optional*, defaults to 512):
|
||||
Fixed length for text embeddings
|
||||
in_dim (`int`, *optional*, defaults to 16):
|
||||
Input video channels (C_in)
|
||||
dim (`int`, *optional*, defaults to 2048):
|
||||
Hidden dimension of the transformer
|
||||
ffn_dim (`int`, *optional*, defaults to 8192):
|
||||
Intermediate dimension in feed-forward network
|
||||
freq_dim (`int`, *optional*, defaults to 256):
|
||||
Dimension for sinusoidal time embeddings
|
||||
text_dim (`int`, *optional*, defaults to 4096):
|
||||
Input dimension for text embeddings
|
||||
out_dim (`int`, *optional*, defaults to 16):
|
||||
Output video channels (C_out)
|
||||
num_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads
|
||||
num_layers (`int`, *optional*, defaults to 32):
|
||||
Number of transformer blocks
|
||||
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
||||
Window size for local attention (-1 indicates global attention)
|
||||
qk_norm (`bool`, *optional*, defaults to True):
|
||||
Enable query/key normalization
|
||||
cross_attn_norm (`bool`, *optional*, defaults to True):
|
||||
Enable cross-attention normalization
|
||||
eps (`float`, *optional*, defaults to 1e-6):
|
||||
Epsilon value for normalization layers
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
assert model_type in ['t2v', 'i2v', 'ti2v']
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.text_len = text_len
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.text_dim = text_dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, dim))
|
||||
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim))
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6))
|
||||
|
||||
# blocks
|
||||
self.blocks = [
|
||||
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
|
||||
cross_attn_norm, eps) for _ in range(num_layers)
|
||||
]
|
||||
|
||||
# head
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
|
||||
# buffers
|
||||
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
||||
d = dim // num_heads
|
||||
self.freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6))
|
||||
], axis=1)
|
||||
|
||||
# initialize weights
|
||||
self.init_weights()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
y=None,
|
||||
):
|
||||
"""
|
||||
Forward pass through the diffusion model
|
||||
|
||||
Args:
|
||||
x (List[Array]):
|
||||
List of input video tensors, each with shape [C_in, F, H, W]
|
||||
t (Array):
|
||||
Diffusion timesteps tensor of shape [B]
|
||||
context (List[Array]):
|
||||
List of text embeddings each with shape [L, C]
|
||||
seq_len (`int`):
|
||||
Maximum sequence length for positional encoding
|
||||
y (List[Array], *optional*):
|
||||
Conditional video inputs for image-to-video mode, same shape as x
|
||||
|
||||
Returns:
|
||||
List[Array]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
if self.model_type == 'i2v':
|
||||
assert y is not None
|
||||
|
||||
if y is not None:
|
||||
x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]
|
||||
|
||||
# embeddings
|
||||
x = [self.patch_embedding(mx.expand_dims(mx.transpose(u, (1, 2, 3, 0)), axis=0)) for u in x]
|
||||
|
||||
grid_sizes = mx.stack(
|
||||
[mx.array(u.shape[1:4], dtype=mx.int32) for u in x])
|
||||
|
||||
|
||||
x = [u.reshape(u.shape[0], -1, u.shape[-1]) for u in x]
|
||||
|
||||
seq_lens = mx.array([u.shape[1] for u in x], dtype=mx.int32)
|
||||
assert seq_lens.max() <= seq_len
|
||||
|
||||
# Pad sequences
|
||||
x_padded = []
|
||||
for u in x:
|
||||
pad_len = seq_len - u.shape[1]
|
||||
if pad_len > 0:
|
||||
padding = mx.zeros((u.shape[0], pad_len, u.shape[2]))
|
||||
u = mx.concatenate([u, padding], axis=1)
|
||||
x_padded.append(u)
|
||||
x = mx.concatenate(x_padded, axis=0)
|
||||
|
||||
# time embeddings
|
||||
if t.ndim == 1:
|
||||
t = mx.broadcast_to(t[:, None], (t.shape[0], seq_len))
|
||||
bt = t.shape[0]
|
||||
t = t.flatten()
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).reshape(bt, seq_len, -1))
|
||||
e0 = self.time_projection(e).reshape(bt, seq_len, 6, self.dim)
|
||||
|
||||
# context
|
||||
context_lens = None
|
||||
context_padded = []
|
||||
for u in context:
|
||||
pad_len = self.text_len - u.shape[0]
|
||||
if pad_len > 0:
|
||||
padding = mx.zeros((pad_len, u.shape[1]))
|
||||
u = mx.concatenate([u, padding], axis=0)
|
||||
context_padded.append(u)
|
||||
context = self.text_embedding(mx.stack(context_padded))
|
||||
|
||||
# arguments
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=self.freqs,
|
||||
context=context,
|
||||
context_lens=context_lens)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
"""
|
||||
Reconstruct video tensors from patch embeddings.
|
||||
|
||||
Args:
|
||||
x (List[Array]):
|
||||
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
||||
grid_sizes (Array):
|
||||
Original spatial-temporal grid dimensions before patching,
|
||||
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
||||
|
||||
Returns:
|
||||
List[Array]:
|
||||
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
|
||||
c = self.out_dim
|
||||
out = []
|
||||
for i, v in enumerate(grid_sizes):
|
||||
v = v.tolist()
|
||||
seq_len = math.prod(v)
|
||||
u = x[i, :seq_len].reshape(*v, *self.patch_size, c)
|
||||
# Rearrange dimensions: (f, h, w, p, q, r, c) -> (c, f*p, h*q, w*r)
|
||||
u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))
|
||||
u = u.reshape(c, v[0] * self.patch_size[0],
|
||||
v[1] * self.patch_size[1],
|
||||
v[2] * self.patch_size[2])
|
||||
out.append(u)
|
||||
return out
|
||||
|
||||
def init_weights(self):
|
||||
"""
|
||||
Initialize model parameters using Xavier initialization.
|
||||
"""
|
||||
# Initialize patch embedding
|
||||
fan_in = self.in_dim * math.prod(self.patch_size)
|
||||
fan_out = self.dim
|
||||
std = math.sqrt(2.0 / (fan_in + fan_out))
|
||||
self.patch_embedding.weight = mx.random.uniform(
|
||||
low=-std, high=std, shape=self.patch_embedding.weight.shape)
|
||||
|
||||
# Initialize text embedding layers with normal distribution
|
||||
text_layers = list(self.text_embedding.layers)
|
||||
for i in [0, 2]: # First and third layers
|
||||
layer = text_layers[i]
|
||||
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
|
||||
if hasattr(layer, 'bias') and layer.bias is not None:
|
||||
layer.bias = mx.zeros(layer.bias.shape)
|
||||
|
||||
# Initialize time embedding layers
|
||||
time_layers = list(self.time_embedding.layers)
|
||||
for i in [0, 2]: # First and third layers
|
||||
layer = time_layers[i]
|
||||
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
|
||||
if hasattr(layer, 'bias') and layer.bias is not None:
|
||||
layer.bias = mx.zeros(layer.bias.shape)
|
||||
|
||||
# Initialize output head to zeros
|
||||
self.head.head.weight = mx.zeros(self.head.head.weight.shape)
|
||||
if hasattr(self.head.head, 'bias') and self.head.head.bias is not None:
|
||||
self.head.head.bias = mx.zeros(self.head.head.bias.shape)
|
||||
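A smoke test with deliberately tiny hyperparameters (nothing like the released A14B/5B sizes) to show the calling convention of WanModel; it assumes the definitions above are importable and is a sketch of the interface, not a reference configuration:

import mlx.core as mx

model = WanModel(model_type='t2v', in_dim=16, dim=64, ffn_dim=256, freq_dim=32,
                 text_dim=48, out_dim=16, num_heads=4, num_layers=2)
x = [mx.random.normal((16, 5, 16, 16))]   # per-sample latents, [C_in, F, H, W]
t = mx.array([500])                       # one diffusion timestep per sample
context = [mx.random.normal((7, 48))]     # per-sample text embeddings, [L, text_dim]
seq_len = 5 * (16 // 2) * (16 // 2)       # F * H/2 * W/2 = 320 patches
out = model(x, t, context, seq_len=seq_len)
assert out[0].shape == (16, 5, 16, 16)    # unpatchify restores the latent shape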
616
video/Wan2.2/wan/modules/t5.py
Normal file
@@ -0,0 +1,616 @@
|
||||
# MLX implementation for t5.py
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
|
||||
__all__ = [
|
||||
'T5Model',
|
||||
'T5Encoder',
|
||||
'T5Decoder',
|
||||
'T5EncoderModel',
|
||||
]
|
||||
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == mx.float16:
|
||||
# Use same clamping as PyTorch for consistency
|
||||
clamp = 65504.0 # max value for float16
|
||||
return mx.clip(x, -clamp, clamp)
|
||||
return x
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __call__(self, x):
|
||||
return 0.5 * x * (1.0 + mx.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x):
|
||||
# Match PyTorch's approach: convert to float32 for stability
|
||||
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
|
||||
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
|
||||
x_norm = x_float * mx.rsqrt(variance + self.eps)
|
||||
# Convert back to original dtype
|
||||
if x.dtype == mx.float16:
|
||||
x_norm = x_norm.astype(mx.float16)
|
||||
return self.weight * x_norm
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
|
||||
assert dim_attn % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def __call__(self, x, context=None, mask=None, pos_bias=None):
|
||||
"""
|
||||
x: [B, L1, C].
|
||||
context: [B, L2, C] or None.
|
||||
mask: [B, L2] or [B, L1, L2] or None.
|
||||
"""
|
||||
# check inputs
|
||||
context = x if context is None else context
|
||||
b, l1, _ = x.shape
|
||||
_, l2, _ = context.shape
|
||||
n, c = self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).reshape(b, l1, n, c)
|
||||
k = self.k(context).reshape(b, l2, n, c)
|
||||
v = self.v(context).reshape(b, l2, n, c)
|
||||
|
||||
# transpose for attention: [B, N, L, C]
|
||||
q = mx.transpose(q, (0, 2, 1, 3))
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
# compute attention (T5 does not use scaling)
|
||||
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
|
||||
|
||||
# add position bias if provided
|
||||
if pos_bias is not None:
|
||||
attn = attn + pos_bias
|
||||
|
||||
# apply mask
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
# [B, L2] -> [B, 1, 1, L2]
|
||||
mask = mask[:, None, None, :]
|
||||
elif mask.ndim == 3:
|
||||
# [B, L1, L2] -> [B, 1, L1, L2]
|
||||
mask = mask[:, None, :, :]
|
||||
# Use very negative value that works well with float16
|
||||
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
|
||||
attn = mx.where(mask == 0, min_value, attn)
|
||||
|
||||
# softmax and apply attention
|
||||
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# apply attention to values
|
||||
x = mx.matmul(attn, v) # [B, N, L1, C]
|
||||
|
||||
# transpose back and reshape
|
||||
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
|
||||
x = x.reshape(b, l1, -1)
|
||||
|
||||
# output projection
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_ffn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
|
||||
# layers
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = GELU()
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def __call__(self, x):
|
||||
gate = self.gate_act(self.gate_proj(x))
|
||||
x = self.fc1(x) * gate
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True)
|
||||
|
||||
def __call__(self, x, mask=None, pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5CrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm3 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False)
|
||||
|
||||
def __call__(self,
|
||||
x,
|
||||
mask=None,
|
||||
encoder_states=None,
|
||||
encoder_mask=None,
|
||||
pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.cross_attn(
|
||||
self.norm2(x), context=encoder_states, mask=encoder_mask))
|
||||
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
|
||||
# layers
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def __call__(self, lq, lk):
|
||||
# Create relative position matrix
|
||||
positions_q = mx.arange(lq)[:, None]
|
||||
positions_k = mx.arange(lk)[None, :]
|
||||
rel_pos = positions_k - positions_q
|
||||
|
||||
# Apply bucketing
|
||||
rel_pos = self._relative_position_bucket(rel_pos)
|
||||
|
||||
# Get embeddings
|
||||
rel_pos_embeds = self.embedding(rel_pos)
|
||||
|
||||
# Reshape to [1, N, Lq, Lk]
|
||||
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
|
||||
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
|
||||
|
||||
return rel_pos_embeds
|
||||
|
||||
def _relative_position_bucket(self, rel_pos):
|
||||
# preprocess
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
|
||||
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
# embeddings for small and large positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
# For large positions, use log scale
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
|
||||
math.log(self.max_dist / max_exact) *
|
||||
(num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
|
||||
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
|
||||
|
||||
# Combine small and large position buckets
|
||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
|
||||
|
||||
return rel_buckets
|
||||
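The bucketing follows the T5 scheme: short relative distances get their own exact buckets, longer ones share logarithmically spaced buckets up to max_dist, and the bidirectional case reserves half the buckets for each sign. An illustrative probe (values picked arbitrarily):

import mlx.core as mx

emb = T5RelativeEmbedding(num_buckets=32, num_heads=8, bidirectional=True)
rel = mx.array([[-20, -3, 0, 3, 20, 100]])
buckets = emb._relative_position_bucket(rel)   # small |rel| -> exact bucket, large -> log bucket
bias = emb(lq=6, lk=6)                         # [1, num_heads, 6, 6] additive attention bias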
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = [
|
||||
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids, mask=None):
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5Decoder(nn.Module):
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = [
|
||||
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
||||
b, s = ids.shape
|
||||
|
||||
# causal mask
|
||||
if mask is None:
|
||||
mask = mx.tril(mx.ones((1, s, s)))
|
||||
elif mask.ndim == 2:
|
||||
# Expand mask properly
|
||||
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
|
||||
|
||||
# layers
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5Model(nn.Module):
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
encoder_layers,
|
||||
decoder_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.encoder_layers = encoder_layers
|
||||
self.decoder_layers = decoder_layers
|
||||
self.num_buckets = num_buckets
|
||||
|
||||
# layers
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
||||
num_heads, encoder_layers, num_buckets,
|
||||
shared_pos, dropout)
|
||||
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
||||
num_heads, decoder_layers, num_buckets,
|
||||
shared_pos, dropout)
|
||||
self.head = nn.Linear(dim, vocab_size, bias=False)
|
||||
|
||||
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
||||
x = self.encoder(encoder_ids, encoder_mask)
|
||||
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def init_mlx_weights(module, key):
|
||||
"""Initialize weights for T5 model components to match PyTorch initialization"""
|
||||
|
||||
def normal(key, shape, std=1.0):
|
||||
return mx.random.normal(shape=shape, key=key) * std
|
||||
|
||||
if isinstance(module, T5LayerNorm):
|
||||
module.weight = mx.ones_like(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.weight = normal(key, module.weight.shape, std=1.0)
|
||||
elif isinstance(module, T5FeedForward):
|
||||
# Match PyTorch initialization
|
||||
key1, key2, key3 = mx.random.split(key, 3)
|
||||
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc1.weight = normal(key2, module.fc1.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc2.weight = normal(key3, module.fc2.weight.shape,
|
||||
std=module.dim_ffn**-0.5)
|
||||
elif isinstance(module, T5Attention):
|
||||
# Match PyTorch initialization
|
||||
key1, key2, key3, key4 = mx.random.split(key, 4)
|
||||
module.q.weight = normal(key1, module.q.weight.shape,
|
||||
std=(module.dim * module.dim_attn)**-0.5)
|
||||
module.k.weight = normal(key2, module.k.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.v.weight = normal(key3, module.v.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.o.weight = normal(key4, module.o.weight.shape,
|
||||
std=(module.num_heads * module.dim_attn)**-0.5)
|
||||
elif isinstance(module, T5RelativeEmbedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.embedding.weight = normal(key, module.embedding.weight.shape,
|
||||
std=(2 * module.num_buckets * module.num_heads)**-0.5)
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Generic linear layer initialization
|
||||
key = mx.random.split(key, 1)[0]
|
||||
fan_in = module.weight.shape[1]
|
||||
bound = 1.0 / math.sqrt(fan_in)
|
||||
module.weight = mx.random.uniform(low=-bound, high=bound, shape=module.weight.shape, key=key)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _t5(name,
|
||||
encoder_only=False,
|
||||
decoder_only=False,
|
||||
return_tokenizer=False,
|
||||
tokenizer_kwargs={},
|
||||
**kwargs):
|
||||
# sanity check
|
||||
assert not (encoder_only and decoder_only)
|
||||
|
||||
# params
|
||||
if encoder_only:
|
||||
model_cls = T5Encoder
|
||||
kwargs['vocab'] = kwargs.pop('vocab_size')
|
||||
kwargs['num_layers'] = kwargs.pop('encoder_layers')
|
||||
_ = kwargs.pop('decoder_layers')
|
||||
elif decoder_only:
|
||||
model_cls = T5Decoder
|
||||
kwargs['vocab'] = kwargs.pop('vocab_size')
|
||||
kwargs['num_layers'] = kwargs.pop('decoder_layers')
|
||||
_ = kwargs.pop('encoder_layers')
|
||||
else:
|
||||
model_cls = T5Model
|
||||
|
||||
# init model
|
||||
model = model_cls(**kwargs)
|
||||
|
||||
# Initialize weights properly
|
||||
key = mx.random.key(0)
|
||||
model = init_mlx_weights(model, key)
|
||||
|
||||
# init tokenizer
|
||||
if return_tokenizer:
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
|
||||
return model, tokenizer
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def umt5_xxl(**kwargs):
|
||||
cfg = dict(
|
||||
vocab_size=256384,
|
||||
dim=4096,
|
||||
dim_attn=4096,
|
||||
dim_ffn=10240,
|
||||
num_heads=64,
|
||||
encoder_layers=24,
|
||||
decoder_layers=24,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
return _t5('umt5-xxl', **cfg)
|
||||
|
||||
|
||||
class T5EncoderModel:
|
||||
def __init__(
|
||||
self,
|
||||
text_len,
|
||||
checkpoint_path=None,
|
||||
tokenizer_path=None,
|
||||
):
|
||||
self.text_len = text_len
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
# init model
|
||||
model = umt5_xxl(
|
||||
encoder_only=True,
|
||||
return_tokenizer=False)
|
||||
|
||||
if checkpoint_path:
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
# Load weights - assuming MLX format checkpoint
|
||||
weights = mx.load(checkpoint_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
self.model = model
|
||||
|
||||
# init tokenizer
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
self.tokenizer = HuggingfaceTokenizer(
|
||||
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
|
||||
seq_len=text_len,
|
||||
clean='whitespace')
|
||||
|
||||
def __call__(self, texts):
|
||||
# Handle single string input
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
# Tokenize texts
|
||||
tokenizer_output = self.tokenizer(
|
||||
texts, return_mask=True, add_special_tokens=True)
|
||||
|
||||
# Handle different tokenizer output formats
|
||||
if isinstance(tokenizer_output, tuple):
|
||||
ids, mask = tokenizer_output
|
||||
else:
|
||||
# Assuming dict output with 'input_ids' and 'attention_mask'
|
||||
ids = tokenizer_output['input_ids']
|
||||
mask = tokenizer_output['attention_mask']
|
||||
|
||||
# Convert to MLX arrays if not already
|
||||
if not isinstance(ids, mx.array):
|
||||
ids = mx.array(ids)
|
||||
if not isinstance(mask, mx.array):
|
||||
mask = mx.array(mask)
|
||||
|
||||
# Get sequence lengths
|
||||
seq_lens = mx.sum(mask > 0, axis=1)
|
||||
|
||||
# Run encoder
|
||||
context = self.model(ids, mask)
|
||||
|
||||
# Return variable length outputs
|
||||
# Convert seq_lens to Python list for indexing
|
||||
if seq_lens.ndim == 0: # Single value
|
||||
seq_lens_list = [seq_lens.item()]
|
||||
else:
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
|
||||
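End to end, the wrapper is meant to be used roughly as below; the checkpoint path is a placeholder for a locally converted file, not something shipped with this commit:

text_encoder = T5EncoderModel(
    text_len=512,
    checkpoint_path='models_t5_umt5-xxl-enc-bf16.safetensors',
    tokenizer_path='google/umt5-xxl')
embeddings = text_encoder(['a corgi surfing a small wave at golden hour'])
print(embeddings[0].shape)   # one [L_i, 4096] array per prompt, trimmed to its token count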
|
||||
|
||||
# Utility function to convert PyTorch checkpoint to MLX
|
||||
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
|
||||
"""Convert PyTorch checkpoint to MLX format"""
|
||||
import torch
|
||||
|
||||
# Load PyTorch checkpoint
|
||||
pytorch_state = torch.load(pytorch_path, map_location='cpu')
|
||||
|
||||
# Convert to numpy then to MLX
|
||||
mlx_state = {}
|
||||
for key, value in pytorch_state.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Handle the key mapping if needed
|
||||
mlx_key = key
|
||||
# Convert tensor to MLX array
|
||||
# .numpy() cannot handle bfloat16 directly; round-trip through float32 and restore bf16
mlx_state[mlx_key] = mx.array(value.float().numpy()).astype(mx.bfloat16) if value.dtype == torch.bfloat16 else mx.array(value.numpy())
|
||||
|
||||
# Save MLX checkpoint
|
||||
mx.save_safetensors(mlx_path, mlx_state)  # mx.save handles a single array; a dict of weights goes to safetensors
|
||||
|
||||
return mlx_state
|
||||
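A one-off conversion would then look like this; both paths are placeholders, with the .safetensors name chosen to match what wan_t2v_A14B.py expects:

convert_pytorch_checkpoint(
    pytorch_path='models_t5_umt5-xxl-enc-bf16.pth',
    mlx_path='models_t5_umt5-xxl-enc-bf16.safetensors')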
82
video/Wan2.2/wan/modules/tokenizers.py
Normal file
@@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ['HuggingfaceTokenizer']


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
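Typical use, mirroring how T5EncoderModel drives it; note that with return_tensors set to 'pt' the raw outputs are PyTorch tensors, which the caller converts to mx.array:

tok = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
ids, mask = tok(['A red fox runs through fresh snow.'], return_mask=True)
print(ids.shape, mask.shape)   # both (1, 512) after padding/truncation to seq_len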
703
video/Wan2.2/wan/modules/vae2_1.py
Normal file
@@ -0,0 +1,703 @@
|
||||
# MLX implementation of vae2_1.py
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
__all__ = [
|
||||
'Wan2_1_VAE',
|
||||
]
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
debug_line = 0
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3d convolution for MLX.
|
||||
Expects input in BTHWC format (batch, time, height, width, channels).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Padding order: (W, W, H, H, T, 0)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def __call__(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
|
||||
padding[4] -= cache_x.shape[1]
|
||||
|
||||
# Pad in BTHWC format
|
||||
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
|
||||
(padding[0], padding[1]), (0, 0)]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
result = super().__call__(x)
|
||||
return result
|
||||
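The left-only temporal padding is what makes the convolution causal, and cache_x lets the VAE process a video in chunks while keeping temporal continuity. A toy example with made-up shapes:

import mlx.core as mx

conv = CausalConv3d(8, 8, (3, 3, 3), padding=(1, 1, 1))
chunk1 = mx.random.normal((1, 4, 16, 16, 8))   # BTHWC
out1 = conv(chunk1)                            # frame t only sees frames <= t
cache = chunk1[:, -CACHE_T:]                   # keep the last 2 frames
chunk2 = mx.random.normal((1, 4, 16, 16, 8))
out2 = conv(chunk2, cache_x=cache)             # continues the stream without a causal break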
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=False, images=True, bias=False):
|
||||
super().__init__()
|
||||
self.channel_first = channel_first
|
||||
self.images = images
|
||||
self.scale = dim**0.5
|
||||
|
||||
# Just keep as 1D - let broadcasting do its magic
|
||||
self.gamma = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,)) if bias else 0.
|
||||
|
||||
def __call__(self, x):
|
||||
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
|
||||
x = x / norm
|
||||
return x * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
Upsampling layer that matches PyTorch's behavior.
|
||||
"""
|
||||
def __init__(self, scale_factor, mode='nearest-exact'):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode # mode is now unused, but kept for signature consistency
|
||||
|
||||
def __call__(self, x):
|
||||
scale_h, scale_w = self.scale_factor
|
||||
|
||||
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
|
||||
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
|
||||
|
||||
return out
|
||||
|
||||
class AsymmetricPad(nn.Module):
|
||||
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
|
||||
def __init__(self, pad_width: tuple):
|
||||
super().__init__()
|
||||
self.pad_width = pad_width
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.pad(x, self.pad_width)
|
||||
|
||||
# Resample uses the 'nearest-exact' style Upsample defined above
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
# Asymmetric (0, 1) spatial padding matches PyTorch's stride-2 downsampling
|
||||
elif mode == 'downsample2d':
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
elif mode == 'downsample3d':
|
||||
# The spatial downsampling part uses the same logic
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# The __call__ method logic remains unchanged from your original code
|
||||
b, t, h, w, c = x.shape
|
||||
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
mx.zeros_like(cache_x), cache_x
|
||||
], axis=1)
|
||||
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, t, h, w, 2, c)
|
||||
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
|
||||
x = x.reshape(b, t * 2, h, w, c)
|
||||
|
||||
t = x.shape[1]
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
x = self.resample(x)
|
||||
|
||||
_, h_new, w_new, c_new = x.shape
|
||||
x = x.reshape(b, t, h_new, w_new, c_new)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -1:, :, :, :]
|
||||
x = self.time_conv(
|
||||
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
|
||||
for i, layer in enumerate(self.residual.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
self.proj.weight = mx.zeros_like(self.proj.weight)
|
||||
|
||||
def __call__(self, x):
|
||||
# x is in BTHWC format
|
||||
identity = x
|
||||
b, t, h, w, c = x.shape
|
||||
x = x.reshape(b * t, h, w, c) # Combine batch and time
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
|
||||
qkv = qkv.reshape(b * t, h * w, 3 * c)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Reshape for attention
|
||||
q = q.reshape(b * t, h * w, c)
|
||||
k = k.reshape(b * t, h * w, c)
|
||||
v = v.reshape(b * t, h * w, c)
|
||||
|
||||
# Scaled dot product attention
|
||||
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
|
||||
scores = (q @ k.transpose(0, 2, 1)) * scale
|
||||
weights = mx.softmax(scores, axis=-1)
|
||||
x = weights @ v
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = x.reshape(b, t, h, w, c)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[-1], dims[-1], dropout),
|
||||
AttentionBlock(dims[-1]),
|
||||
ResidualBlock(dims[-1], dims[-1], dropout)
|
||||
)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], z_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for i, layer in enumerate(self.downsamples.layers):
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout)
|
||||
)
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1)
|
||||
)
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples.layers:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class WanVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x, scale):
|
||||
# x is in BTHWC format
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[1]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## Split encode input x by time into 1, 4, 4, 4....
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :1, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
|
||||
z = self.conv1(out)
|
||||
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
|
||||
|
||||
if isinstance(scale[0], mx.array):
|
||||
# Reshape scale for broadcasting in BTHWC format
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
mu = (mu - scale_mean) * scale_std
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
|
||||
return mu, log_var
|
||||
|
||||
def decode(self, z, scale):
|
||||
# z is in BTHWC format
|
||||
self.clear_cache()
|
||||
if isinstance(scale[0], mx.array):
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
z = z / scale_std + scale_mean
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[1]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = mx.exp(0.5 * log_var)
|
||||
eps = mx.random.normal(std.shape)
|
||||
return eps * std + mu
|
||||
|
||||
def __call__(self, x):
|
||||
mu, log_var = self.encode(x, self.scale)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z, self.scale)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs, self.scale)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
|
||||
return mu + std * mx.random.normal(std.shape)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
|
||||
# params
|
||||
cfg = dict(
|
||||
dim=96,
|
||||
z_dim=z_dim,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
|
||||
# init model
|
||||
model = WanVAE_(**cfg)
|
||||
|
||||
# load checkpoint
|
||||
if pretrained_path:
|
||||
logging.info(f'loading {pretrained_path}')
|
||||
weights = mx.load(pretrained_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Wan2_1_VAE:
|
||||
|
||||
def __init__(self,
|
||||
z_dim=16,
|
||||
vae_pth='cache/vae_step_411000.pth',
|
||||
dtype=mx.float32):
|
||||
self.dtype = dtype
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
std = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = mx.array(mean, dtype=dtype)
|
||||
self.std = mx.array(std, dtype=dtype)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = _video_vae(
|
||||
pretrained_path=vae_pth,
|
||||
z_dim=z_dim,
|
||||
)
|
||||
|
||||
def encode(self, videos):
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
Returns: List of encoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
encoded = []
|
||||
for video in videos:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(video, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Encode
|
||||
z = self.model.encode(x, self.scale)[0] # Get mu only
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
z = z.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
encoded.append(z.astype(mx.float32))
|
||||
|
||||
return encoded
|
||||
|
||||
def decode(self, zs):
|
||||
"""
|
||||
zs: A list of latent codes each with shape [C, T, H, W].
|
||||
Returns: List of decoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
decoded = []
|
||||
for z in zs:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(z, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Decode
|
||||
x = self.model.decode(x, self.scale)
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
x = x.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
# Clamp values
|
||||
x = mx.clip(x, -1, 1)
|
||||
|
||||
decoded.append(x.astype(mx.float32))
|
||||
|
||||
return decoded
|
||||
265
video/Wan2.2/wan/t5_model_io.py
Normal file
265
video/Wan2.2/wan/t5_model_io.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import json
|
||||
from typing import Optional, List, Tuple
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_unflatten
|
||||
from safetensors import safe_open
|
||||
import torch
|
||||
|
||||
|
||||
def check_safetensors_dtypes(safetensors_path: str):
|
||||
"""
|
||||
Check what dtypes are in the safetensors file.
|
||||
Useful for debugging dtype issues.
|
||||
"""
|
||||
print(f"🔍 Checking dtypes in: {safetensors_path}")
|
||||
|
||||
dtype_counts = {}
|
||||
|
||||
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
dtype_str = str(tensor.dtype)
|
||||
|
||||
if dtype_str not in dtype_counts:
|
||||
dtype_counts[dtype_str] = []
|
||||
dtype_counts[dtype_str].append(key)
|
||||
|
||||
print("📊 Dtype summary:")
|
||||
for dtype, keys in dtype_counts.items():
|
||||
print(f" {dtype}: {len(keys)} parameters")
|
||||
if dtype == "torch.bfloat16":
|
||||
print(f" ⚠️ BFloat16 detected - will convert to float32")
|
||||
print(f" Examples: {keys[:3]}")
|
||||
|
||||
return dtype_counts
|
||||
|
||||
|
||||
def convert_tensor_dtype(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert tensor to MLX-compatible dtype.
|
||||
"""
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
# Convert BFloat16 to float32
|
||||
return tensor.float()
|
||||
elif tensor.dtype == torch.float64:
|
||||
# Convert float64 to float32 for efficiency
|
||||
return tensor.float()
|
||||
else:
|
||||
# Keep other dtypes as-is
|
||||
return tensor
|
||||
|
||||
|
||||
def map_t5_encoder_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
|
||||
"""
|
||||
Map T5 encoder weights from PyTorch format to MLX format.
|
||||
Following the pattern used in MLX Stable Diffusion.
|
||||
|
||||
Args:
|
||||
key: Parameter name from PyTorch model
|
||||
value: Parameter tensor
|
||||
|
||||
Returns:
|
||||
List of (key, value) tuples for MLX model
|
||||
"""
|
||||
|
||||
# Handle the main structural difference: FFN gate layer
|
||||
if ".ffn.gate.0.weight" in key:
|
||||
# PyTorch has Sequential(Linear, GELU) but MLX has separate gate_proj + gate_act
|
||||
key = key.replace(".ffn.gate.0.weight", ".ffn.gate_proj.weight")
|
||||
return [(key, value)]
|
||||
|
||||
elif ".ffn.gate.0.bias" in key:
|
||||
# Handle bias if it exists
|
||||
key = key.replace(".ffn.gate.0.bias", ".ffn.gate_proj.bias")
|
||||
return [(key, value)]
|
||||
|
||||
elif ".ffn.gate.1" in key:
|
||||
# Skip GELU activation parameters - MLX handles this separately
|
||||
print(f"Skipping GELU parameter: {key}")
|
||||
return []
|
||||
|
||||
# Handle any other potential FFN mappings
|
||||
elif ".ffn.fc1.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".ffn.fc2.weight" in key:
|
||||
return [(key, value)]
|
||||
|
||||
# Handle attention layers (should be direct mapping)
|
||||
elif ".attn.q.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".attn.k.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".attn.v.weight" in key:
|
||||
return [(key, value)]
|
||||
elif ".attn.o.weight" in key:
|
||||
return [(key, value)]
|
||||
|
||||
# Handle embeddings and norms (direct mapping)
|
||||
elif "token_embedding.weight" in key:
|
||||
return [(key, value)]
|
||||
elif "pos_embedding.embedding.weight" in key:
|
||||
return [(key, value)]
|
||||
elif "norm1.weight" in key or "norm2.weight" in key or "norm.weight" in key:
|
||||
return [(key, value)]
|
||||
|
||||
# Default: direct mapping for any other parameters
|
||||
else:
|
||||
return [(key, value)]
|
||||
|
||||
|
||||
def _flatten(params: List[List[Tuple[str, mx.array]]]) -> List[Tuple[str, mx.array]]:
|
||||
"""Flatten nested list of parameter tuples"""
|
||||
return [(k, v) for p in params for (k, v) in p]
|
||||
|
||||
|
||||
def _load_safetensor_weights(
|
||||
mapper_func,
|
||||
model,
|
||||
weight_file: str,
|
||||
float16: bool = False
|
||||
):
|
||||
"""
|
||||
Load safetensor weights using the mapping function.
|
||||
Based on MLX SD pattern.
|
||||
"""
|
||||
dtype = mx.float16 if float16 else mx.float32
|
||||
|
||||
# Load weights from safetensors file
|
||||
weights = {}
|
||||
with safe_open(weight_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
# Handle BFloat16 - convert to float32 first
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
print(f"Converting BFloat16 to float32 for: {key}")
|
||||
tensor = tensor.float() # Convert to float32
|
||||
|
||||
weights[key] = mx.array(tensor.numpy()).astype(dtype)
|
||||
|
||||
# Apply mapping function
|
||||
mapped_weights = _flatten([mapper_func(k, v) for k, v in weights.items()])
|
||||
|
||||
# Update model with mapped weights
|
||||
model.update(tree_unflatten(mapped_weights))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_t5_encoder_from_safetensors(
|
||||
safetensors_path: str,
|
||||
model, # Your MLX T5Encoder instance
|
||||
float16: bool = False
|
||||
):
|
||||
"""
|
||||
Load T5 encoder weights from safetensors file into MLX model.
|
||||
|
||||
Args:
|
||||
safetensors_path: Path to the safetensors file
|
||||
model: Your MLX T5Encoder model instance
|
||||
float16: Whether to use float16 precision
|
||||
|
||||
Returns:
|
||||
Model with loaded weights
|
||||
"""
|
||||
print(f"Loading T5 encoder weights from: {safetensors_path}")
|
||||
|
||||
# Load and map weights
|
||||
model = _load_safetensor_weights(
|
||||
map_t5_encoder_weights,
|
||||
model,
|
||||
safetensors_path,
|
||||
float16
|
||||
)
|
||||
|
||||
print("T5 encoder weights loaded successfully!")
|
||||
return model
|
||||
|
||||
|
||||
def debug_weight_mapping(safetensors_path: str, float16: bool = False):
|
||||
"""
|
||||
Debug function to see how weights are being mapped.
|
||||
Useful for troubleshooting conversion issues.
|
||||
"""
|
||||
dtype = mx.float16 if float16 else mx.float32
|
||||
|
||||
print("=== T5 Weight Mapping Debug ===")
|
||||
|
||||
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
# Handle BFloat16
|
||||
original_dtype = tensor.dtype
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
print(f"Converting BFloat16 to float32 for: {key}")
|
||||
tensor = tensor.float()
|
||||
|
||||
value = mx.array(tensor.numpy()).astype(dtype)
|
||||
|
||||
# Apply mapping
|
||||
mapped = map_t5_encoder_weights(key, value)
|
||||
|
||||
if len(mapped) == 0:
|
||||
print(f"SKIPPED: {key} ({original_dtype}) -> (no mapping)")
|
||||
elif len(mapped) == 1:
|
||||
new_key, new_value = mapped[0]
|
||||
if new_key == key:
|
||||
print(f"DIRECT: {key} ({original_dtype}) [{tensor.shape}]")
|
||||
else:
|
||||
print(f"MAPPED: {key} ({original_dtype}) -> {new_key} [{tensor.shape}]")
|
||||
else:
|
||||
print(f"SPLIT: {key} ({original_dtype}) -> {len(mapped)} parameters")
|
||||
for new_key, new_value in mapped:
|
||||
print(f" -> {new_key} [{new_value.shape}]")
|
||||
|
||||
|
||||
def convert_safetensors_to_mlx_weights(
|
||||
safetensors_path: str,
|
||||
output_path: str,
|
||||
float16: bool = False
|
||||
):
|
||||
"""
|
||||
Convert safetensors file to MLX weights file (.npz format).
|
||||
|
||||
Args:
|
||||
safetensors_path: Input safetensors file
|
||||
output_path: Output MLX weights file (.npz)
|
||||
float16: Whether to use float16 precision
|
||||
"""
|
||||
dtype = mx.float16 if float16 else mx.float32
|
||||
|
||||
print(f"Converting safetensors to MLX format...")
|
||||
print(f"Input: {safetensors_path}")
|
||||
print(f"Output: {output_path}")
|
||||
print(f"Target dtype: {dtype}")
|
||||
|
||||
# Load and convert weights
|
||||
weights = {}
|
||||
bfloat16_count = 0
|
||||
|
||||
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
# Handle BFloat16
|
||||
# if tensor.dtype == torch.bfloat16:
|
||||
# bfloat16_count += 1
|
||||
# tensor = tensor.float() # Convert to float32 first
|
||||
|
||||
value = mx.array(tensor.numpy())#.astype(dtype)
|
||||
|
||||
# Apply mapping
|
||||
mapped = map_t5_encoder_weights(key, value)
|
||||
|
||||
for new_key, new_value in mapped:
|
||||
weights[new_key] = new_value
|
||||
|
||||
if bfloat16_count > 0:
|
||||
print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to float32")
|
||||
|
||||
# Save as MLX format
|
||||
print(f"Saving {len(weights)} parameters to: {output_path}")
|
||||
mx.save_safetensors(output_path, weights)
|
||||
|
||||
return weights
|
||||
233
video/Wan2.2/wan/t5_torch_to_sf.py
Normal file
233
video/Wan2.2/wan/t5_torch_to_sf.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from wan.modules.t5 import T5Model
|
||||
|
||||
|
||||
def convert_pickle_to_safetensors(
|
||||
pickle_path: str,
|
||||
safetensors_path: str,
|
||||
model_class=None,
|
||||
model_kwargs=None,
|
||||
load_method: str = "weights_only" # Changed default to weights_only
|
||||
):
|
||||
"""Convert PyTorch pickle file to safetensors format."""
|
||||
|
||||
print(f"Loading PyTorch weights from: {pickle_path}")
|
||||
|
||||
# Try multiple loading methods in order of preference
|
||||
methods_to_try = [load_method, "weights_only", "state_dict", "full_model"]
|
||||
|
||||
for method in methods_to_try:
|
||||
try:
|
||||
if method == "weights_only":
|
||||
state_dict = torch.load(pickle_path, map_location='cpu', weights_only=True)
|
||||
|
||||
elif method == "state_dict":
|
||||
checkpoint = torch.load(pickle_path, map_location='cpu')
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif isinstance(checkpoint, dict) and 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
elif method == "full_model":
|
||||
model = torch.load(pickle_path, map_location='cpu')
|
||||
if hasattr(model, 'state_dict'):
|
||||
state_dict = model.state_dict()
|
||||
else:
|
||||
state_dict = model
|
||||
|
||||
print(f"✅ Successfully loaded with method: {method}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Method {method} failed: {e}")
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(f"All loading methods failed for {pickle_path}")
|
||||
|
||||
# Clean up the state dict
|
||||
state_dict = clean_state_dict(state_dict)
|
||||
|
||||
print(f"Found {len(state_dict)} parameters")
|
||||
|
||||
# Convert BF16 to FP32 if needed
|
||||
for key, tensor in state_dict.items():
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
state_dict[key] = tensor.to(torch.float32)
|
||||
print(f"Converted {key} from bfloat16 to float32")
|
||||
|
||||
# Save as safetensors
|
||||
print(f"Saving to safetensors: {safetensors_path}")
|
||||
os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
|
||||
save_file(state_dict, safetensors_path)
|
||||
|
||||
print("✅ T5 conversion complete!")
|
||||
return state_dict
|
||||
|
||||
|
||||
def clean_state_dict(state_dict):
|
||||
"""
|
||||
Clean up state dict by removing unwanted prefixes or keys.
|
||||
"""
|
||||
cleaned = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
# Remove common prefixes that might interfere
|
||||
clean_key = key
|
||||
|
||||
if clean_key.startswith('module.'):
|
||||
clean_key = clean_key[7:]
|
||||
|
||||
if clean_key.startswith('model.'):
|
||||
clean_key = clean_key[6:]
|
||||
|
||||
cleaned[clean_key] = value
|
||||
|
||||
if clean_key != key:
|
||||
print(f"Cleaned key: {key} -> {clean_key}")
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def load_with_your_torch_model(pickle_path: str, model_class, **model_kwargs):
|
||||
"""
|
||||
Load pickle weights into your specific PyTorch T5 model implementation.
|
||||
|
||||
Args:
|
||||
pickle_path: Path to pickle file
|
||||
model_class: Your T5Encoder class
|
||||
**model_kwargs: Arguments for your model constructor
|
||||
"""
|
||||
|
||||
print("Method 1: Loading into your PyTorch T5 model")
|
||||
|
||||
# Initialize your model
|
||||
model = model_class(**model_kwargs)
|
||||
|
||||
# Load checkpoint
|
||||
checkpoint = torch.load(pickle_path, map_location='cpu')
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if isinstance(checkpoint, dict):
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
elif 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
else:
|
||||
# Assume the dict IS the state dict
|
||||
state_dict = checkpoint
|
||||
else:
|
||||
# Assume it's a model object
|
||||
state_dict = checkpoint.state_dict()
|
||||
|
||||
# Clean and load
|
||||
state_dict = clean_state_dict(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False) # Use strict=False to ignore missing keys
|
||||
|
||||
return model, state_dict
|
||||
|
||||
|
||||
def explore_pickle_file(pickle_path: str):
|
||||
"""
|
||||
Explore the contents of a pickle file to understand its structure.
|
||||
"""
|
||||
print(f"🔍 Exploring pickle file: {pickle_path}")
|
||||
|
||||
try:
|
||||
# Try loading with weights_only first (safer)
|
||||
print("\n--- Trying weights_only=True ---")
|
||||
try:
|
||||
data = torch.load(pickle_path, map_location='cpu', weights_only=True)
|
||||
print(f"✅ Loaded with weights_only=True")
|
||||
print(f"Type: {type(data)}")
|
||||
|
||||
if isinstance(data, dict):
|
||||
print(f"Dictionary with {len(data)} keys:")
|
||||
for i, key in enumerate(data.keys()):
|
||||
print(f" {key}: {type(data[key])}")
|
||||
if hasattr(data[key], 'shape'):
|
||||
print(f" Shape: {data[key].shape}")
|
||||
if i >= 9: # Show first 10 keys
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ weights_only=True failed: {e}")
|
||||
|
||||
# Try regular loading
|
||||
print("\n--- Trying regular loading ---")
|
||||
data = torch.load(pickle_path, map_location='cpu')
|
||||
print(f"✅ Loaded successfully")
|
||||
print(f"Type: {type(data)}")
|
||||
|
||||
if hasattr(data, 'state_dict'):
|
||||
print("📋 Found state_dict method")
|
||||
state_dict = data.state_dict()
|
||||
print(f"State dict has {len(state_dict)} parameters")
|
||||
|
||||
elif isinstance(data, dict):
|
||||
print(f"📋 Dictionary with keys: {list(data.keys())}")
|
||||
|
||||
# Check for common checkpoint keys
|
||||
if 'state_dict' in data:
|
||||
print("Found 'state_dict' key")
|
||||
print(f"state_dict has {len(data['state_dict'])} parameters")
|
||||
elif 'model' in data:
|
||||
print("Found 'model' key")
|
||||
print(f"model has {len(data['model'])} parameters")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load: {e}")
|
||||
|
||||
|
||||
def full_conversion_pipeline(
|
||||
pickle_path: str,
|
||||
safetensors_path: str,
|
||||
torch_model_class=None,
|
||||
model_kwargs=None
|
||||
):
|
||||
"""
|
||||
Complete pipeline: pickle -> safetensors -> ready for MLX conversion
|
||||
"""
|
||||
|
||||
print("🚀 Starting full conversion pipeline")
|
||||
print("="*50)
|
||||
|
||||
# Step 1: Explore the pickle file
|
||||
print("Step 1: Exploring pickle file structure")
|
||||
explore_pickle_file(pickle_path)
|
||||
|
||||
# Step 2: Convert to safetensors
|
||||
print(f"\nStep 2: Converting to safetensors")
|
||||
|
||||
# Try different loading methods
|
||||
for method in ["weights_only", "state_dict", "full_model"]:
|
||||
try:
|
||||
print(f"\nTrying load method: {method}")
|
||||
state_dict = convert_pickle_to_safetensors(
|
||||
pickle_path,
|
||||
safetensors_path,
|
||||
model_class=torch_model_class,
|
||||
model_kwargs=model_kwargs,
|
||||
load_method=method
|
||||
)
|
||||
print(f"✅ Success with method: {method}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Method {method} failed: {e}")
|
||||
continue
|
||||
else:
|
||||
print("❌ All methods failed!")
|
||||
return None
|
||||
|
||||
print(f"\n🎉 Conversion complete!")
|
||||
print(f"Safetensors file saved to: {safetensors_path}")
|
||||
print(f"Ready for MLX conversion using the previous script!")
|
||||
|
||||
return state_dict
|
||||
401
video/Wan2.2/wan/text2video.py
Normal file
401
video/Wan2.2/wan/text2video.py
Normal file
@@ -0,0 +1,401 @@
|
||||
# MLX implementation of text2video.py
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, List, Dict, Any, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from .modules.model import WanModel
|
||||
from .modules.t5 import T5EncoderModel
|
||||
from .modules.vae2_1 import Wan2_1_VAE
|
||||
from .utils.fm_solvers import (
|
||||
FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas,
|
||||
retrieve_timesteps,
|
||||
)
|
||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
|
||||
from .wan_model_io import convert_wan_2_2_safetensors_to_mlx, convert_multiple_wan_2_2_safetensors_to_mlx, load_wan_2_2_from_safetensors
|
||||
|
||||
class WanT2V:
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir: str,
|
||||
device_id: int = 0,
|
||||
convert_model_dtype: bool = False,
|
||||
):
|
||||
r"""
|
||||
Initializes the Wan text-to-video generation model components for MLX.
|
||||
|
||||
Args:
|
||||
config (EasyDict):
|
||||
Object containing model parameters initialized from config.py
|
||||
checkpoint_dir (`str`):
|
||||
Path to directory containing model checkpoints
|
||||
device_id (`int`, *optional*, defaults to 0):
|
||||
Device id (kept for compatibility, MLX handles device automatically)
|
||||
convert_model_dtype (`bool`, *optional*, defaults to False):
|
||||
Convert DiT model parameters dtype to 'config.param_dtype'.
|
||||
"""
|
||||
self.config = config
|
||||
self.vae_stride = config.vae_stride
|
||||
self.patch_size = config.patch_size
|
||||
self.num_train_timesteps = config.num_train_timesteps
|
||||
self.boundary = config.boundary
|
||||
# Convert PyTorch dtype to MLX dtype
|
||||
if str(config.param_dtype) == 'torch.bfloat16':
|
||||
self.param_dtype = mx.bfloat16
|
||||
elif str(config.param_dtype) == 'torch.float16':
|
||||
self.param_dtype = mx.float16
|
||||
elif str(config.param_dtype) == 'torch.float32':
|
||||
self.param_dtype = mx.float32
|
||||
else:
|
||||
self.param_dtype = mx.float32 # default
|
||||
|
||||
# Initialize T5 text encoder
|
||||
print(f"checkpoint_dir is: {checkpoint_dir}")
|
||||
t5_checkpoint_path = os.path.join(checkpoint_dir, config.t5_checkpoint)
|
||||
mlx_t5_path = t5_checkpoint_path.replace('.safetensors', '_mlx.safetensors')
|
||||
if not os.path.exists(mlx_t5_path):
|
||||
# Check if it's a .pth file that needs conversion
|
||||
pth_path = t5_checkpoint_path.replace('.safetensors', '.pth')
|
||||
if os.path.exists(pth_path):
|
||||
logging.info(f"Converting T5 PyTorch model to safetensors: {pth_path}")
|
||||
from .t5_torch_to_sf import convert_pickle_to_safetensors
|
||||
convert_pickle_to_safetensors(pth_path, t5_checkpoint_path, load_method="weights_only")
|
||||
# Convert torch safetensors to MLX safetensors
|
||||
from .t5_model_io import convert_safetensors_to_mlx_weights
|
||||
convert_safetensors_to_mlx_weights(t5_checkpoint_path, mlx_t5_path, float16=(self.param_dtype == mx.float16))
|
||||
else:
|
||||
raise FileNotFoundError(f"T5 checkpoint not found: {t5_checkpoint_path} or {pth_path}")
|
||||
|
||||
t5_checkpoint_path = mlx_t5_path # Use the MLX version
|
||||
logging.info(f"Loading T5 text encoder... from {t5_checkpoint_path}")
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
checkpoint_path=t5_checkpoint_path,
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer))
|
||||
|
||||
# Initialize VAE - with automatic conversion
|
||||
vae_path = os.path.join(checkpoint_dir, config.vae_checkpoint)
|
||||
if not os.path.exists(vae_path):
|
||||
# Check for PyTorch VAE file to convert
|
||||
pth_vae_path = vae_path.replace('_mlx.safetensors', '.pth')
|
||||
if not os.path.exists(pth_vae_path):
|
||||
# Try alternative naming
|
||||
pth_vae_path = os.path.join(checkpoint_dir, 'Wan2.1_VAE.pth')
|
||||
|
||||
if os.path.exists(pth_vae_path):
|
||||
logging.info(f"Converting VAE PyTorch model to MLX: {pth_vae_path}")
|
||||
from .vae_model_io import convert_pytorch_to_mlx
|
||||
convert_pytorch_to_mlx(pth_vae_path, vae_path, float16=(self.param_dtype == mx.float16))
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE checkpoint not found: {vae_path} or {pth_vae_path}")
|
||||
|
||||
logging.info("Loading VAE...")
|
||||
self.vae = Wan2_1_VAE(vae_pth=vae_path)
|
||||
|
||||
# Load low and high noise models
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
|
||||
# Helper function to load model with automatic conversion
|
||||
def load_model_with_conversion(checkpoint_dir, subfolder, config, param_dtype):
|
||||
"""Load model with automatic PyTorch to MLX conversion if needed."""
|
||||
|
||||
# Look for existing MLX files
|
||||
mlx_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model_mlx.safetensors")
|
||||
mlx_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_mlx_model*.safetensors")
|
||||
mlx_files = glob.glob(mlx_pattern)
|
||||
|
||||
# If no MLX files, convert PyTorch files
|
||||
if not os.path.exists(mlx_single) and not mlx_files:
|
||||
pytorch_single = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model.safetensors")
|
||||
pytorch_pattern = os.path.join(checkpoint_dir, subfolder, "diffusion_pytorch_model-*.safetensors")
|
||||
pytorch_files = glob.glob(pytorch_pattern)
|
||||
|
||||
if os.path.exists(pytorch_single):
|
||||
logging.info(f"Converting PyTorch model to MLX: {pytorch_single}")
|
||||
convert_wan_2_2_safetensors_to_mlx(
|
||||
pytorch_single,
|
||||
mlx_single,
|
||||
float16=(param_dtype == mx.float16)
|
||||
)
|
||||
elif pytorch_files:
|
||||
logging.info(f"Converting {len(pytorch_files)} PyTorch files to MLX")
|
||||
convert_multiple_wan_2_2_safetensors_to_mlx(
|
||||
os.path.join(checkpoint_dir, subfolder),
|
||||
float16=(param_dtype == mx.float16)
|
||||
)
|
||||
mlx_files = glob.glob(mlx_pattern) # Update file list
|
||||
else:
|
||||
raise FileNotFoundError(f"No model files found in {os.path.join(checkpoint_dir, subfolder)}")
|
||||
|
||||
# Create model
|
||||
model = WanModel(
|
||||
model_type='t2v',
|
||||
patch_size=config.patch_size,
|
||||
text_len=config.text_len,
|
||||
in_dim=16,
|
||||
dim=config.dim,
|
||||
ffn_dim=config.ffn_dim,
|
||||
freq_dim=config.freq_dim,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=config.num_heads,
|
||||
num_layers=config.num_layers,
|
||||
window_size=getattr(config, 'window_size', (-1, -1)),
|
||||
qk_norm=getattr(config, 'qk_norm', True),
|
||||
cross_attn_norm=getattr(config, 'cross_attn_norm', True),
|
||||
eps=getattr(config, 'eps', 1e-6)
|
||||
)
|
||||
|
||||
# Load weights
|
||||
if os.path.exists(mlx_single):
|
||||
logging.info(f"Loading single MLX file: {mlx_single}")
|
||||
model = load_wan_2_2_from_safetensors(mlx_single, model, float16=(param_dtype == mx.float16))
|
||||
else:
|
||||
logging.info(f"Loading multiple MLX files from: {os.path.join(checkpoint_dir, subfolder)}")
|
||||
model = load_wan_2_2_from_safetensors(
|
||||
os.path.join(checkpoint_dir, subfolder),
|
||||
model,
|
||||
float16=(param_dtype == mx.float16)
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
# Load both models
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
logging.info("Loading low noise model")
|
||||
self.low_noise_model = load_model_with_conversion(
|
||||
checkpoint_dir,
|
||||
config.low_noise_checkpoint,
|
||||
self.config,
|
||||
self.param_dtype
|
||||
)
|
||||
self.low_noise_model = self._configure_model(self.low_noise_model, convert_model_dtype)
|
||||
|
||||
logging.info("Loading high noise model")
|
||||
self.high_noise_model = load_model_with_conversion(
|
||||
checkpoint_dir,
|
||||
config.high_noise_checkpoint,
|
||||
self.config,
|
||||
self.param_dtype
|
||||
)
|
||||
self.high_noise_model = self._configure_model(self.high_noise_model, convert_model_dtype)
|
||||
|
||||
self.sp_size = 1 # No sequence parallel in single device
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
def _configure_model(self, model: nn.Module, convert_model_dtype: bool) -> nn.Module:
|
||||
"""
|
||||
Configures a model object for MLX.
|
||||
|
||||
Args:
|
||||
model (nn.Module):
|
||||
The model instance to configure.
|
||||
convert_model_dtype (`bool`):
|
||||
Convert model parameters dtype to 'config.param_dtype'.
|
||||
|
||||
Returns:
|
||||
nn.Module:
|
||||
The configured model.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
if convert_model_dtype:
|
||||
# In MLX, we would need to manually convert parameters
|
||||
# This would be implemented in the actual model class
|
||||
pass
|
||||
|
||||
return model
|
||||
|
||||
def _prepare_model_for_timestep(self, t, boundary, offload_model):
|
||||
"""
|
||||
Prepares and returns the required model for the current timestep.
|
||||
"""
|
||||
if t.item() >= boundary:
|
||||
required_model_name = 'high_noise_model'
|
||||
offload_model_name = 'low_noise_model'
|
||||
else:
|
||||
required_model_name = 'low_noise_model'
|
||||
offload_model_name = 'high_noise_model'
|
||||
|
||||
# MLX doesn't need the CPU offloading logic, just return the right model
|
||||
return getattr(self, required_model_name)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_prompt: str,
|
||||
size: Tuple[int, int] = (1280, 720),
|
||||
frame_num: int = 81,
|
||||
shift: float = 5.0,
|
||||
sample_solver: str = 'unipc',
|
||||
sampling_steps: int = 50,
|
||||
guide_scale: Union[float, Tuple[float, float]] = 5.0,
|
||||
n_prompt: str = "",
|
||||
seed: int = -1,
|
||||
offload_model: bool = True
|
||||
) -> Optional[mx.array]:
|
||||
r"""
|
||||
Generates video frames from text prompt using diffusion process.
|
||||
|
||||
Args:
|
||||
input_prompt (`str`):
|
||||
Text prompt for content generation
|
||||
size (`tuple[int]`, *optional*, defaults to (1280,720)):
|
||||
Controls video resolution, (width,height).
|
||||
frame_num (`int`, *optional*, defaults to 81):
|
||||
How many frames to sample from a video. The number should be 4n+1
|
||||
shift (`float`, *optional*, defaults to 5.0):
|
||||
Noise schedule shift parameter. Affects temporal dynamics
|
||||
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
||||
Solver used to sample the video.
|
||||
sampling_steps (`int`, *optional*, defaults to 50):
|
||||
Number of diffusion sampling steps.
|
||||
guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
|
||||
Classifier-free guidance scale.
|
||||
n_prompt (`str`, *optional*, defaults to ""):
|
||||
Negative prompt for content exclusion.
|
||||
seed (`int`, *optional*, defaults to -1):
|
||||
Random seed for noise generation. If -1, use random seed.
|
||||
offload_model (`bool`, *optional*, defaults to True):
|
||||
Not used in MLX version (kept for compatibility)
|
||||
|
||||
Returns:
|
||||
mx.array:
|
||||
Generated video frames tensor. Dimensions: (C, N, H, W)
|
||||
"""
|
||||
# Preprocess
|
||||
guide_scale = (guide_scale, guide_scale) if isinstance(
|
||||
guide_scale, float) else guide_scale
|
||||
|
||||
F = frame_num
|
||||
target_shape = (
|
||||
self.vae.model.z_dim,
|
||||
(F - 1) // self.vae_stride[0] + 1,
|
||||
size[1] // self.vae_stride[1],
|
||||
size[0] // self.vae_stride[2]
|
||||
)
|
||||
|
||||
seq_len = math.ceil(
|
||||
(target_shape[2] * target_shape[3]) /
|
||||
(self.patch_size[1] * self.patch_size[2]) *
|
||||
target_shape[1] / self.sp_size
|
||||
) * self.sp_size
|
||||
|
||||
if n_prompt == "":
|
||||
n_prompt = self.sample_neg_prompt
|
||||
|
||||
# Set random seed
|
||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Encode text prompts
|
||||
context = self.text_encoder([input_prompt])
|
||||
context_null = self.text_encoder([n_prompt])
|
||||
|
||||
# Generate initial noise
|
||||
noise = [
|
||||
mx.random.normal(
|
||||
shape=target_shape,
|
||||
dtype=mx.float32
|
||||
)
|
||||
]
|
||||
|
||||
# Set boundary
|
||||
boundary = self.boundary * self.num_train_timesteps
|
||||
|
||||
# Initialize scheduler
|
||||
if sample_solver == 'unipc':
|
||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False
|
||||
)
|
||||
sample_scheduler.set_timesteps(
|
||||
sampling_steps, shift=shift
|
||||
)
|
||||
timesteps = sample_scheduler.timesteps
|
||||
elif sample_solver == 'dpm++':
|
||||
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False
|
||||
)
|
||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
||||
timesteps, _ = retrieve_timesteps(
|
||||
sample_scheduler,
|
||||
sigmas=sampling_sigmas
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported solver.")
|
||||
|
||||
# Sample videos
|
||||
latents = noise
|
||||
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Denoising loop
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
timestep = mx.array([t])
|
||||
|
||||
# Select model based on timestep
|
||||
model = self._prepare_model_for_timestep(
|
||||
t, boundary, offload_model
|
||||
)
|
||||
sample_guide_scale = guide_scale[1] if t.item() >= boundary else guide_scale[0]
|
||||
|
||||
# Model predictions
|
||||
noise_pred_cond = model(
|
||||
latent_model_input, t=timestep, **arg_c
|
||||
)[0]
|
||||
mx.eval(noise_pred_cond) # Force evaluation
|
||||
|
||||
noise_pred_uncond = model(
|
||||
latent_model_input, t=timestep, **arg_null
|
||||
)[0]
|
||||
mx.eval(noise_pred_uncond) # Force evaluation
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred = noise_pred_uncond + sample_guide_scale * (
|
||||
noise_pred_cond - noise_pred_uncond
|
||||
)
|
||||
mx.eval(noise_pred) # Force evaluation
|
||||
|
||||
# Scheduler step
|
||||
temp_x0 = sample_scheduler.step(
|
||||
mx.expand_dims(noise_pred, axis=0),
|
||||
t,
|
||||
mx.expand_dims(latents[0], axis=0),
|
||||
return_dict=False
|
||||
)[0]
|
||||
latents = [mx.squeeze(temp_x0, axis=0)]
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Decode final latents
|
||||
x0 = latents
|
||||
videos = self.vae.decode(x0)
|
||||
|
||||
# Cleanup
|
||||
del noise, latents
|
||||
del sample_scheduler
|
||||
if offload_model:
|
||||
gc.collect()
|
||||
|
||||
return videos[0]
|
||||
12
video/Wan2.2/wan/utils/__init__.py
Normal file
12
video/Wan2.2/wan/utils/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .fm_solvers import (
|
||||
FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas,
|
||||
retrieve_timesteps,
|
||||
)
|
||||
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
|
||||
__all__ = [
|
||||
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
|
||||
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
|
||||
]
|
||||
562
video/Wan2.2/wan/utils/fm_solvers.py
Normal file
562
video/Wan2.2/wan/utils/fm_solvers.py
Normal file
@@ -0,0 +1,562 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_sampling_sigmas(sampling_steps, shift):
|
||||
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
|
||||
sigma = (shift * sigma / (1 + (shift - 1) * sigma))
|
||||
return sigma
|
||||
|
||||
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps=None,
|
||||
device=None,
|
||||
timesteps=None,
|
||||
sigmas=None,
|
||||
**kwargs,
|
||||
):
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError(
|
||||
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
||||
)
|
||||
if timesteps is not None:
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class SchedulerOutput:
|
||||
"""Output class for scheduler step results."""
|
||||
def __init__(self, prev_sample: mx.array):
|
||||
self.prev_sample = prev_sample
|
||||
|
||||
|
||||
class FlowDPMSolverMultistepScheduler:
|
||||
"""
|
||||
MLX implementation of FlowDPMSolverMultistepScheduler.
|
||||
A fast dedicated high-order solver for diffusion ODEs.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "flow_prediction",
|
||||
shift: Optional[float] = 1.0,
|
||||
use_dynamic_shifting: bool = False,
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
final_sigmas_type: Optional[str] = "zero",
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
invert_sigmas: bool = False,
|
||||
):
|
||||
# Store configuration
|
||||
self.config = {
|
||||
'num_train_timesteps': num_train_timesteps,
|
||||
'solver_order': solver_order,
|
||||
'prediction_type': prediction_type,
|
||||
'shift': shift,
|
||||
'use_dynamic_shifting': use_dynamic_shifting,
|
||||
'thresholding': thresholding,
|
||||
'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
|
||||
'sample_max_value': sample_max_value,
|
||||
'algorithm_type': algorithm_type,
|
||||
'solver_type': solver_type,
|
||||
'lower_order_final': lower_order_final,
|
||||
'euler_at_final': euler_at_final,
|
||||
'final_sigmas_type': final_sigmas_type,
|
||||
'lambda_min_clipped': lambda_min_clipped,
|
||||
'variance_type': variance_type,
|
||||
'invert_sigmas': invert_sigmas,
|
||||
}
|
||||
|
||||
# Validate algorithm type
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
if algorithm_type == "deis":
|
||||
self.config['algorithm_type'] = "dpmsolver++"
|
||||
else:
|
||||
raise NotImplementedError(f"{algorithm_type} is not implemented")
|
||||
|
||||
# Validate solver type
|
||||
if solver_type not in ["midpoint", "heun"]:
|
||||
if solver_type in ["logrho", "bh1", "bh2"]:
|
||||
self.config['solver_type'] = "midpoint"
|
||||
else:
|
||||
raise NotImplementedError(f"{solver_type} is not implemented")
|
||||
|
||||
# Initialize scheduling
|
||||
self.num_inference_steps = None
|
||||
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = mx.array(sigmas, dtype=mx.float32)
|
||||
|
||||
if not use_dynamic_shifting:
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
self.sigmas = sigmas
|
||||
self.timesteps = sigmas * num_train_timesteps
|
||||
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
self.sigma_min = float(self.sigmas[-1])
|
||||
self.sigma_max = float(self.sigmas[0])
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
return self._begin_index
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Union[int, None] = None,
|
||||
device: Union[str, None] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[Union[float, None]] = None,
|
||||
shift: Optional[Union[float, None]] = None,
|
||||
):
|
||||
"""Sets the discrete timesteps used for the diffusion chain."""
|
||||
if self.config['use_dynamic_shifting'] and mu is None:
|
||||
raise ValueError(
|
||||
"you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
|
||||
)
|
||||
|
||||
if sigmas is None:
|
||||
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
|
||||
|
||||
if self.config['use_dynamic_shifting']:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
if shift is None:
|
||||
shift = self.config['shift']
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
if self.config['final_sigmas_type'] == "sigma_min":
|
||||
sigma_last = self.sigma_min
|
||||
elif self.config['final_sigmas_type'] == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
|
||||
)
|
||||
|
||||
timesteps = sigmas * self.config['num_train_timesteps']
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(timesteps, dtype=mx.int64)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
|
||||
self.model_outputs = [None] * self.config['solver_order']
|
||||
self.lower_order_nums = 0
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def _threshold_sample(self, sample: mx.array) -> mx.array:
|
||||
"""Dynamic thresholding method."""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, *remaining_dims = sample.shape
|
||||
|
||||
# Flatten sample for quantile calculation
|
||||
sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
||||
|
||||
abs_sample = mx.abs(sample_flat)
|
||||
|
||||
# Compute quantile
|
||||
s = mx.quantile(
|
||||
abs_sample,
|
||||
self.config['dynamic_thresholding_ratio'],
|
||||
axis=1,
|
||||
keepdims=True
|
||||
)
|
||||
s = mx.clip(s, 1, self.config['sample_max_value'])
|
||||
|
||||
# Threshold and normalize
|
||||
sample_flat = mx.clip(sample_flat, -s, s) / s
|
||||
|
||||
sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
|
||||
return sample.astype(dtype)
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config['num_train_timesteps']
|
||||
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
return 1 - sigma, sigma
|
||||
|
||||
def time_shift(self, mu: float, sigma: float, t: mx.array):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
sample: mx.array,
|
||||
**kwargs,
|
||||
) -> mx.array:
|
||||
"""Convert model output to the corresponding type the algorithm needs."""
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model
|
||||
if self.config['algorithm_type'] in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
if self.config['prediction_type'] == "flow_prediction":
|
||||
sigma_t = self.sigmas[self.step_index]
|
||||
x0_pred = sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config['prediction_type']} must be "
|
||||
f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config['thresholding']:
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
return x0_pred
|
||||
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model
|
||||
elif self.config['algorithm_type'] in ["dpmsolver", "sde-dpmsolver"]:
|
||||
if self.config['prediction_type'] == "flow_prediction":
|
||||
sigma_t = self.sigmas[self.step_index]
|
||||
epsilon = sample - (1 - sigma_t) * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config['prediction_type']} must be "
|
||||
f"'flow_prediction' for the FlowDPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config['thresholding']:
|
||||
sigma_t = self.sigmas[self.step_index]
|
||||
x0_pred = sample - sigma_t * model_output
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = model_output + x0_pred
|
||||
|
||||
return epsilon
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
sample: mx.array,
|
||||
noise: Optional[mx.array] = None,
|
||||
**kwargs,
|
||||
) -> mx.array:
|
||||
"""One step for the first-order DPMSolver (equivalent to DDIM)."""
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
|
||||
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
|
||||
lambda_s = mx.log(alpha_s) - mx.log(sigma_s)
|
||||
|
||||
h = lambda_t - lambda_s
|
||||
|
||||
if self.config['algorithm_type'] == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mx.exp(-h) - 1.0)) * model_output
|
||||
elif self.config['algorithm_type'] == "dpmsolver":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (mx.exp(h) - 1.0)) * model_output
|
||||
elif self.config['algorithm_type'] == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(sigma_t / sigma_s * mx.exp(-h)) * sample +
|
||||
(alpha_t * (1 - mx.exp(-2.0 * h))) * model_output +
|
||||
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config['algorithm_type'] == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
x_t = (
|
||||
(alpha_t / alpha_s) * sample -
|
||||
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * model_output +
|
||||
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[mx.array],
|
||||
sample: mx.array,
|
||||
noise: Optional[mx.array] = None,
|
||||
**kwargs,
|
||||
) -> mx.array:
|
||||
"""One step for the second-order multistep DPMSolver."""
|
||||
sigma_t, sigma_s0, sigma_s1 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
|
||||
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
|
||||
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
|
||||
lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
|
||||
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
|
||||
if self.config['algorithm_type'] == "dpmsolver++":
|
||||
if self.config['solver_type'] == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample -
|
||||
(alpha_t * (mx.exp(-h) - 1.0)) * D0 -
|
||||
0.5 * (alpha_t * (mx.exp(-h) - 1.0)) * D1
|
||||
)
|
||||
elif self.config['solver_type'] == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample -
|
||||
(alpha_t * (mx.exp(-h) - 1.0)) * D0 +
|
||||
(alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1
|
||||
)
|
||||
elif self.config['algorithm_type'] == "dpmsolver":
|
||||
if self.config['solver_type'] == "midpoint":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample -
|
||||
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
|
||||
0.5 * (sigma_t * (mx.exp(h) - 1.0)) * D1
|
||||
)
|
||||
elif self.config['solver_type'] == "heun":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample -
|
||||
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
|
||||
(sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
)
|
||||
elif self.config['algorithm_type'] == "sde-dpmsolver++":
|
||||
assert noise is not None
|
||||
if self.config['solver_type'] == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * mx.exp(-h)) * sample +
|
||||
(alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
|
||||
0.5 * (alpha_t * (1 - mx.exp(-2.0 * h))) * D1 +
|
||||
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config['solver_type'] == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0 * mx.exp(-h)) * sample +
|
||||
(alpha_t * (1 - mx.exp(-2.0 * h))) * D0 +
|
||||
(alpha_t * ((1.0 - mx.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 +
|
||||
sigma_t * mx.sqrt(1.0 - mx.exp(-2 * h)) * noise
|
||||
)
|
||||
elif self.config['algorithm_type'] == "sde-dpmsolver":
|
||||
assert noise is not None
|
||||
if self.config['solver_type'] == "midpoint":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample -
|
||||
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
|
||||
(sigma_t * (mx.exp(h) - 1.0)) * D1 +
|
||||
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
elif self.config['solver_type'] == "heun":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample -
|
||||
2.0 * (sigma_t * (mx.exp(h) - 1.0)) * D0 -
|
||||
2.0 * (sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 +
|
||||
sigma_t * mx.sqrt(mx.exp(2 * h) - 1.0) * noise
|
||||
)
|
||||
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[mx.array],
|
||||
sample: mx.array,
|
||||
**kwargs,
|
||||
) -> mx.array:
|
||||
"""One step for the third-order multistep DPMSolver."""
|
||||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
self.sigmas[self.step_index - 1],
|
||||
self.sigmas[self.step_index - 2],
|
||||
)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
||||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
||||
|
||||
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
|
||||
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
|
||||
lambda_s1 = mx.log(alpha_s1) - mx.log(sigma_s1)
|
||||
lambda_s2 = mx.log(alpha_s2) - mx.log(sigma_s2)
|
||||
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
||||
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
||||
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
|
||||
|
||||
if self.config['algorithm_type'] == "dpmsolver++":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample -
|
||||
(alpha_t * (mx.exp(-h) - 1.0)) * D0 +
|
||||
(alpha_t * ((mx.exp(-h) - 1.0) / h + 1.0)) * D1 -
|
||||
(alpha_t * ((mx.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
||||
)
|
||||
elif self.config['algorithm_type'] == "dpmsolver":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample -
|
||||
(sigma_t * (mx.exp(h) - 1.0)) * D0 -
|
||||
(sigma_t * ((mx.exp(h) - 1.0) / h - 1.0)) * D1 -
|
||||
(sigma_t * ((mx.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
||||
)
|
||||
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
# mx.where expects (condition, x, y) in MLX, so gather matching indices via NumPy.
indices = np.nonzero(np.array(schedule_timesteps == timestep))[0]

pos = 1 if len(indices) > 1 else 0

return int(indices[pos])
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
"""Initialize the step_index counter for the scheduler."""
|
||||
if self.begin_index is None:
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep: Union[int, mx.array],
|
||||
sample: mx.array,
|
||||
generator=None,
|
||||
variance_noise: Optional[mx.array] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""Predict the sample from the previous timestep."""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# Improve numerical stability for small number of steps
|
||||
lower_order_final = (
|
||||
(self.step_index == len(self.timesteps) - 1) and
|
||||
(self.config['euler_at_final'] or
|
||||
(self.config['lower_order_final'] and len(self.timesteps) < 15) or
|
||||
self.config['final_sigmas_type'] == "zero")
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and
|
||||
self.config['lower_order_final'] and
|
||||
len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, sample=sample)
|
||||
for i in range(self.config['solver_order'] - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
# Upcast to avoid precision issues
|
||||
sample = sample.astype(mx.float32)
|
||||
|
||||
# Generate noise if needed for SDE variants
|
||||
if self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
||||
noise = mx.random.normal(model_output.shape, dtype=mx.float32)
|
||||
elif self.config['algorithm_type'] in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
noise = variance_noise.astype(mx.float32)
|
||||
else:
|
||||
noise = None
|
||||
|
||||
if self.config['solver_order'] == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(
|
||||
model_output, sample=sample, noise=noise
|
||||
)
|
||||
elif self.config['solver_order'] == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, sample=sample, noise=noise
|
||||
)
|
||||
else:
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, sample=sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config['solver_order']:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# Cast sample back to expected dtype
|
||||
prev_sample = prev_sample.astype(model_output.dtype)
|
||||
|
||||
# Increase step index
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
|
||||
"""Scale model input - no scaling needed for this scheduler."""
|
||||
return sample
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: mx.array,
|
||||
noise: mx.array,
|
||||
timesteps: mx.array,
|
||||
) -> mx.array:
|
||||
"""Add noise to original samples."""
|
||||
sigmas = self.sigmas.astype(original_samples.dtype)
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
# Get step indices
|
||||
if self.begin_index is None:
|
||||
step_indices = [
|
||||
self.index_for_timestep(t, schedule_timesteps)
|
||||
for t in timesteps
|
||||
]
|
||||
elif self.step_index is not None:
|
||||
step_indices = [self.step_index] * timesteps.shape[0]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices]
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = mx.expand_dims(sigma, -1)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config['num_train_timesteps']
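# Hedged usage sketch (not part of the original file): with the flow-matching
# convention used above, alpha_t = 1 - sigma and sigma_t = sigma, so `add_noise`
# reduces to the linear interpolation (1 - sigma) * x0 + sigma * noise. The helper
# below only illustrates that call pattern; `x0` is any clean latent array.
def _example_add_noise(scheduler, x0):
    noise = mx.random.normal(x0.shape, dtype=x0.dtype)
    t = scheduler.timesteps[0:1]  # noise level of the first scheduled step
    return scheduler.add_noise(x0, noise, t)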
|
||||
546
video/Wan2.2/wan/utils/fm_solvers_unipc.py
Normal file
@@ -0,0 +1,546 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SchedulerOutput:
|
||||
"""Output class for scheduler step results."""
|
||||
def __init__(self, prev_sample: mx.array):
|
||||
self.prev_sample = prev_sample
|
||||
|
||||
|
||||
class FlowUniPCMultistepScheduler:
|
||||
"""
|
||||
MLX implementation of UniPCMultistepScheduler.
|
||||
A training-free framework designed for the fast sampling of diffusion models.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "flow_prediction",
|
||||
shift: Optional[float] = 1.0,
|
||||
use_dynamic_shifting: bool = False,
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
predict_x0: bool = True,
|
||||
solver_type: str = "bh2",
|
||||
lower_order_final: bool = True,
|
||||
disable_corrector: List[int] = [],
|
||||
solver_p = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
final_sigmas_type: Optional[str] = "zero",
|
||||
):
|
||||
# Store configuration
|
||||
self.config = {
|
||||
'num_train_timesteps': num_train_timesteps,
|
||||
'solver_order': solver_order,
|
||||
'prediction_type': prediction_type,
|
||||
'shift': shift,
|
||||
'use_dynamic_shifting': use_dynamic_shifting,
|
||||
'thresholding': thresholding,
|
||||
'dynamic_thresholding_ratio': dynamic_thresholding_ratio,
|
||||
'sample_max_value': sample_max_value,
|
||||
'predict_x0': predict_x0,
|
||||
'solver_type': solver_type,
|
||||
'lower_order_final': lower_order_final,
|
||||
'disable_corrector': disable_corrector,
|
||||
'solver_p': solver_p,
|
||||
'timestep_spacing': timestep_spacing,
|
||||
'steps_offset': steps_offset,
|
||||
'final_sigmas_type': final_sigmas_type,
|
||||
}
|
||||
|
||||
# Validate solver type
|
||||
if solver_type not in ["bh1", "bh2"]:
|
||||
if solver_type in ["midpoint", "heun", "logrho"]:
|
||||
self.config['solver_type'] = "bh2"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{solver_type} is not implemented for {self.__class__}"
|
||||
)
|
||||
|
||||
self.predict_x0 = predict_x0
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = mx.array(sigmas, dtype=mx.float32)
|
||||
|
||||
if not use_dynamic_shifting:
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
self.sigmas = sigmas
|
||||
self.timesteps = sigmas * num_train_timesteps
|
||||
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.timestep_list = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self.disable_corrector = disable_corrector
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
self.sigma_min = float(self.sigmas[-1])
|
||||
self.sigma_max = float(self.sigmas[0])
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""The index counter for current timestep."""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""The index for the first timestep."""
|
||||
return self._begin_index
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""Sets the begin index for the scheduler."""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Union[int, None] = None,
|
||||
device: Union[str, None] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[Union[float, None]] = None,
|
||||
shift: Optional[Union[float, None]] = None,
|
||||
):
|
||||
"""Sets the discrete timesteps used for the diffusion chain."""
|
||||
if self.config['use_dynamic_shifting'] and mu is None:
|
||||
raise ValueError(
|
||||
"you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
|
||||
)
|
||||
|
||||
if sigmas is None:
|
||||
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
|
||||
|
||||
if self.config['use_dynamic_shifting']:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
if shift is None:
|
||||
shift = self.config['shift']
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
if self.config['final_sigmas_type'] == "sigma_min":
|
||||
sigma_last = self.sigma_min
|
||||
elif self.config['final_sigmas_type'] == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config['final_sigmas_type']}"
|
||||
)
|
||||
|
||||
timesteps = sigmas * self.config['num_train_timesteps']
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(timesteps, dtype=mx.int64)
|
||||
|
||||
self.num_inference_steps = len(timesteps)
|
||||
|
||||
self.model_outputs = [None] * self.config['solver_order']
|
||||
self.lower_order_nums = 0
|
||||
self.last_sample = None
|
||||
if self.solver_p:
|
||||
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
||||
|
||||
# add an index counter for schedulers
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def _threshold_sample(self, sample: mx.array) -> mx.array:
|
||||
"""Dynamic thresholding method."""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, *remaining_dims = sample.shape
|
||||
|
||||
# Flatten sample for quantile calculation
|
||||
sample_flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
||||
|
||||
abs_sample = mx.abs(sample_flat)
|
||||
|
||||
# Compute the per-row quantile with a nearest-rank approach via mx.sort
# (a dedicated quantile op is not assumed to be available in mlx.core).
sorted_abs = mx.sort(abs_sample, axis=1)
q_index = int(self.config['dynamic_thresholding_ratio'] * (sorted_abs.shape[1] - 1))
s = sorted_abs[:, q_index:q_index + 1]
|
||||
s = mx.clip(s, 1, self.config['sample_max_value'])
|
||||
|
||||
# Threshold and normalize
|
||||
sample_flat = mx.clip(sample_flat, -s, s) / s
|
||||
|
||||
sample = sample_flat.reshape(batch_size, channels, *remaining_dims)
|
||||
return sample.astype(dtype)
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config['num_train_timesteps']
|
||||
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
return 1 - sigma, sigma
|
||||
|
||||
def time_shift(self, mu: float, sigma: float, t: mx.array):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
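# Illustrative note (derived from the two formulas above, not additional behavior):
# the static shift maps sigma -> shift * sigma / (1 + (shift - 1) * sigma); with
# shift = 5.0, sigma = 0.5 becomes 2.5 / 3.0 ≈ 0.833, biasing the schedule toward
# higher noise levels. Dynamic shifting with sigma = 1.0 and mu = log(shift) gives
# exp(mu) / (exp(mu) + (1 / t - 1)) = shift * t / (1 + (shift - 1) * t), i.e. the
# same mapping, while other values of `mu` let `set_timesteps` adapt the schedule
# per call (for example to the target resolution).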
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
sample: mx.array = None,
|
||||
**kwargs,
|
||||
) -> mx.array:
|
||||
"""Convert the model output to the corresponding type the UniPC algorithm needs."""
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
|
||||
if self.predict_x0:
|
||||
if self.config['prediction_type'] == "flow_prediction":
|
||||
sigma_t = self.sigmas[self.step_index]
|
||||
x0_pred = sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
|
||||
f"for the UniPCMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config['thresholding']:
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
return x0_pred
|
||||
else:
|
||||
if self.config['prediction_type'] == "flow_prediction":
|
||||
sigma_t = self.sigmas[self.step_index]
|
||||
epsilon = sample - (1 - sigma_t) * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config['prediction_type']} must be 'flow_prediction' "
|
||||
f"for the UniPCMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config['thresholding']:
|
||||
sigma_t = self.sigmas[self.step_index]
|
||||
x0_pred = sample - sigma_t * model_output
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
epsilon = model_output + x0_pred
|
||||
|
||||
return epsilon
|
||||
|
||||
def multistep_uni_p_bh_update(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
sample: mx.array = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
) -> mx.array:
|
||||
"""One step for the UniP (B(h) version)."""
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0 = self.timestep_list[-1]
|
||||
m0 = model_output_list[-1]
|
||||
x = sample
|
||||
|
||||
if self.solver_p:
|
||||
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
||||
return x_t
|
||||
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
|
||||
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - i
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = mx.array(rks)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = mx.exp(hh) - 1 # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.config['solver_type'] == "bh1":
|
||||
B_h = hh
|
||||
elif self.config['solver_type'] == "bh2":
|
||||
B_h = mx.exp(hh) - 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(mx.power(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= i + 1
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = mx.stack(R)
|
||||
b = mx.array(b)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = mx.stack(D1s, axis=1) # (B, K)
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = mx.array([0.5], dtype=x.dtype)
|
||||
else:
|
||||
rhos_p = mx.linalg.solve(R[:-1, :-1], b[:-1], stream=mx.cpu).astype(x.dtype)
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - alpha_t * B_h * pred_res
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = mx.sum(rhos_p[:, None, None, None] * D1s, axis=0)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - sigma_t * B_h * pred_res
|
||||
|
||||
x_t = x_t.astype(x.dtype)
|
||||
return x_t
|
||||
|
||||
def multistep_uni_c_bh_update(
|
||||
self,
|
||||
this_model_output: mx.array,
|
||||
last_sample: mx.array = None,
|
||||
this_sample: mx.array = None,
|
||||
order: int = None,
|
||||
**kwargs,
|
||||
) -> mx.array:
|
||||
"""One step for the UniC (B(h) version)."""
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
m0 = model_output_list[-1]
|
||||
x = last_sample
|
||||
x_t = this_sample
|
||||
model_t = this_model_output
|
||||
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = mx.log(alpha_t) - mx.log(sigma_t)
|
||||
lambda_s0 = mx.log(alpha_s0) - mx.log(sigma_s0)
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - (i + 1)
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = mx.log(alpha_si) - mx.log(sigma_si)
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = mx.array(rks)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = mx.exp(hh) - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.config['solver_type'] == "bh1":
|
||||
B_h = hh
|
||||
elif self.config['solver_type'] == "bh2":
|
||||
B_h = mx.exp(hh) - 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(mx.power(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= i + 1
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = mx.stack(R)
|
||||
b = mx.array(b)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = mx.stack(D1s, axis=1)
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
# for order 1, we use a simplified version
|
||||
if order == 1:
|
||||
rhos_c = mx.array([0.5], dtype=x.dtype)
|
||||
else:
|
||||
rhos_c = mx.linalg.solve(R, b, stream=mx.cpu).astype(x.dtype)
|
||||
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = mx.sum(rhos_c[:-1, None, None, None] * D1s, axis=0)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
|
||||
x_t = x_t.astype(x.dtype)
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
condition = schedule_timesteps == timestep
|
||||
indices = mx.argmax(condition.astype(mx.int32))
|
||||
|
||||
# Convert scalar to int and return
|
||||
return int(indices)
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
"""Initialize the step_index counter for the scheduler."""
|
||||
if self.begin_index is None:
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep: Union[int, mx.array],
|
||||
sample: mx.array,
|
||||
return_dict: bool = True,
|
||||
generator=None
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""Predict the sample from the previous timestep."""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
use_corrector = (
|
||||
self.step_index > 0 and
|
||||
self.step_index - 1 not in self.disable_corrector and
|
||||
self.last_sample is not None
|
||||
)
|
||||
|
||||
model_output_convert = self.convert_model_output(
|
||||
model_output, sample=sample
|
||||
)
|
||||
if use_corrector:
|
||||
sample = self.multistep_uni_c_bh_update(
|
||||
this_model_output=model_output_convert,
|
||||
last_sample=self.last_sample,
|
||||
this_sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
|
||||
for i in range(self.config['solver_order'] - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.timestep_list[i] = self.timestep_list[i + 1]
|
||||
|
||||
self.model_outputs[-1] = model_output_convert
|
||||
self.timestep_list[-1] = timestep
|
||||
|
||||
if self.config['lower_order_final']:
|
||||
this_order = min(
|
||||
self.config['solver_order'],
|
||||
len(self.timesteps) - self.step_index
|
||||
)
|
||||
else:
|
||||
this_order = self.config['solver_order']
|
||||
|
||||
self.this_order = min(this_order, self.lower_order_nums + 1)
|
||||
assert self.this_order > 0
|
||||
|
||||
self.last_sample = sample
|
||||
prev_sample = self.multistep_uni_p_bh_update(
|
||||
model_output=model_output,
|
||||
sample=sample,
|
||||
order=self.this_order,
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config['solver_order']:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# Increase step index
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def scale_model_input(self, sample: mx.array, *args, **kwargs) -> mx.array:
|
||||
"""Scale model input - no scaling needed for this scheduler."""
|
||||
return sample
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: mx.array,
|
||||
noise: mx.array,
|
||||
timesteps: mx.array,
|
||||
) -> mx.array:
|
||||
"""Add noise to original samples."""
|
||||
sigmas = self.sigmas.astype(original_samples.dtype)
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
# Get step indices
|
||||
if self.begin_index is None:
|
||||
step_indices = [
|
||||
self.index_for_timestep(t, schedule_timesteps)
|
||||
for t in timesteps
|
||||
]
|
||||
elif self.step_index is not None:
|
||||
step_indices = [self.step_index] * timesteps.shape[0]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices]
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = mx.expand_dims(sigma, -1)
|
||||
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config['num_train_timesteps']
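# Hedged usage sketch (not part of the original file): a minimal denoising loop
# driving this scheduler. `model` stands for any callable mapping
# (latents, timestep) -> flow prediction; the shift value is only an example.
def _example_unipc_sampling_loop(model, latents, num_steps=50):
    scheduler = FlowUniPCMultistepScheduler(shift=5.0)
    scheduler.set_timesteps(num_inference_steps=num_steps)
    for t in scheduler.timesteps:
        pred = model(latents, t)
        latents = scheduler.step(pred, t, latents).prev_sample
    return latents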
|
||||
363
video/Wan2.2/wan/utils/qwen_vl_utils.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# Copied from https://github.com/kq-chen/qwen-vl-utils
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torchvision
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torchvision import io, transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 4 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
VIDEO_MIN_PIXELS = 128 * 28 * 28
|
||||
VIDEO_MAX_PIXELS = 768 * 28 * 28
|
||||
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
|
||||
FRAME_FACTOR = 2
|
||||
FPS = 2.0
|
||||
FPS_MIN_FRAMES = 4
|
||||
FPS_MAX_FRAMES = 768
|
||||
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(height: int,
|
||||
width: int,
|
||||
factor: int = IMAGE_FACTOR,
|
||||
min_pixels: int = MIN_PIXELS,
|
||||
max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
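# Worked example (illustrative): smart_resize(1080, 1920) with the defaults snaps both
# sides to multiples of 28 while keeping the aspect ratio:
#   h_bar = round(1080 / 28) * 28 = 1092, w_bar = round(1920 / 28) * 28 = 1932
# and 1092 * 1932 pixels already lies within [MIN_PIXELS, MAX_PIXELS], so no further
# rescaling is applied and (1092, 1932) is returned.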
|
||||
|
||||
|
||||
def fetch_image(ele: dict[str, str | Image.Image],
|
||||
size_factor: int = IMAGE_FACTOR) -> Image.Image:
|
||||
if "image" in ele:
|
||||
image = ele["image"]
|
||||
else:
|
||||
image = ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
image_obj = Image.open(requests.get(image, stream=True).raw)
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = image_obj.convert("RGB")
|
||||
## resize
|
||||
if "resized_height" in ele and "resized_width" in ele:
|
||||
resized_height, resized_width = smart_resize(
|
||||
ele["resized_height"],
|
||||
ele["resized_width"],
|
||||
factor=size_factor,
|
||||
)
|
||||
else:
|
||||
width, height = image.size
|
||||
min_pixels = ele.get("min_pixels", MIN_PIXELS)
|
||||
max_pixels = ele.get("max_pixels", MAX_PIXELS)
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=size_factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
image = image.resize((resized_width, resized_height))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def smart_nframes(
|
||||
ele: dict,
|
||||
total_frames: int,
|
||||
video_fps: int | float,
|
||||
) -> int:
|
||||
"""calculate the number of frames for video used for model inputs.
|
||||
|
||||
Args:
|
||||
ele (dict): a dict containing the video configuration.
|
||||
support either `fps` or `nframes`:
|
||||
- nframes: the number of frames to extract for model inputs.
|
||||
- fps: the fps to extract frames for model inputs.
|
||||
- min_frames: the minimum number of frames of the video, only used when fps is provided.
|
||||
- max_frames: the maximum number of frames of the video, only used when fps is provided.
|
||||
total_frames (int): the original total number of frames of the video.
|
||||
video_fps (int | float): the original fps of the video.
|
||||
|
||||
Raises:
|
||||
ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
|
||||
|
||||
Returns:
|
||||
int: the number of frames for video used for model inputs.
|
||||
"""
|
||||
assert not ("fps" in ele and
|
||||
"nframes" in ele), "Only accept either `fps` or `nframes`"
|
||||
if "nframes" in ele:
|
||||
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
|
||||
else:
|
||||
fps = ele.get("fps", FPS)
|
||||
min_frames = ceil_by_factor(
|
||||
ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
|
||||
max_frames = floor_by_factor(
|
||||
ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
|
||||
FRAME_FACTOR)
|
||||
nframes = total_frames / video_fps * fps
|
||||
nframes = min(max(nframes, min_frames), max_frames)
|
||||
nframes = round_by_factor(nframes, FRAME_FACTOR)
|
||||
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
|
||||
raise ValueError(
|
||||
f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
|
||||
)
|
||||
return nframes
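# Worked example (illustrative): for a 10 s clip with total_frames=300 and
# video_fps=30, the default fps=2.0 path gives nframes = 300 / 30 * 2 = 20; 20 lies
# within [FPS_MIN_FRAMES, total_frames] and is divisible by FRAME_FACTOR, so 20
# frames are sampled.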
|
||||
|
||||
|
||||
def _read_video_torchvision(ele: dict,) -> torch.Tensor:
|
||||
"""read video using torchvision.io.read_video
|
||||
|
||||
Args:
|
||||
ele (dict): a dict containing the video configuration.
|
||||
support keys:
|
||||
- video: the path of video. support "file://", "http://", "https://" and local path.
|
||||
- video_start: the start time of video.
|
||||
- video_end: the end time of video.
|
||||
Returns:
|
||||
torch.Tensor: the video tensor with shape (T, C, H, W).
|
||||
"""
|
||||
video_path = ele["video"]
|
||||
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
|
||||
if "http://" in video_path or "https://" in video_path:
|
||||
warnings.warn(
|
||||
"torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0."
|
||||
)
|
||||
if "file://" in video_path:
|
||||
video_path = video_path[7:]
|
||||
st = time.time()
|
||||
video, audio, info = io.read_video(
|
||||
video_path,
|
||||
start_pts=ele.get("video_start", 0.0),
|
||||
end_pts=ele.get("video_end", None),
|
||||
pts_unit="sec",
|
||||
output_format="TCHW",
|
||||
)
|
||||
total_frames, video_fps = video.size(0), info["video_fps"]
|
||||
logger.info(
|
||||
f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
|
||||
)
|
||||
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
|
||||
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
|
||||
video = video[idx]
|
||||
return video
|
||||
|
||||
|
||||
def is_decord_available() -> bool:
|
||||
import importlib.util
|
||||
|
||||
return importlib.util.find_spec("decord") is not None
|
||||
|
||||
|
||||
def _read_video_decord(ele: dict,) -> torch.Tensor:
|
||||
"""read video using decord.VideoReader
|
||||
|
||||
Args:
|
||||
ele (dict): a dict containing the video configuration.
|
||||
support keys:
|
||||
- video: the path of video. support "file://", "http://", "https://" and local path.
|
||||
- video_start: the start time of video.
|
||||
- video_end: the end time of video.
|
||||
Returns:
|
||||
torch.Tensor: the video tensor with shape (T, C, H, W).
|
||||
"""
|
||||
import decord
|
||||
video_path = ele["video"]
|
||||
st = time.time()
|
||||
vr = decord.VideoReader(video_path)
|
||||
# TODO: support start_pts and end_pts
|
||||
if 'video_start' in ele or 'video_end' in ele:
|
||||
raise NotImplementedError(
|
||||
"not support start_pts and end_pts in decord for now.")
|
||||
total_frames, video_fps = len(vr), vr.get_avg_fps()
|
||||
logger.info(
|
||||
f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
|
||||
)
|
||||
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
|
||||
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
|
||||
video = vr.get_batch(idx).asnumpy()
|
||||
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
|
||||
return video
|
||||
|
||||
|
||||
VIDEO_READER_BACKENDS = {
|
||||
"decord": _read_video_decord,
|
||||
"torchvision": _read_video_torchvision,
|
||||
}
|
||||
|
||||
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_video_reader_backend() -> str:
|
||||
if FORCE_QWENVL_VIDEO_READER is not None:
|
||||
video_reader_backend = FORCE_QWENVL_VIDEO_READER
|
||||
elif is_decord_available():
|
||||
video_reader_backend = "decord"
|
||||
else:
|
||||
video_reader_backend = "torchvision"
|
||||
logger.info(
|
||||
f"qwen-vl-utils using {video_reader_backend} to read video.",
|
||||
file=sys.stderr)
|
||||
return video_reader_backend
|
||||
|
||||
|
||||
def fetch_video(
|
||||
ele: dict,
|
||||
image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
|
||||
if isinstance(ele["video"], str):
|
||||
video_reader_backend = get_video_reader_backend()
|
||||
video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
|
||||
nframes, _, height, width = video.shape
|
||||
|
||||
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
|
||||
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
|
||||
max_pixels = max(
|
||||
min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
|
||||
int(min_pixels * 1.05))
|
||||
max_pixels = ele.get("max_pixels", max_pixels)
|
||||
if "resized_height" in ele and "resized_width" in ele:
|
||||
resized_height, resized_width = smart_resize(
|
||||
ele["resized_height"],
|
||||
ele["resized_width"],
|
||||
factor=image_factor,
|
||||
)
|
||||
else:
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=image_factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
video = transforms.functional.resize(
|
||||
video,
|
||||
[resized_height, resized_width],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
antialias=True,
|
||||
).float()
|
||||
return video
|
||||
else:
|
||||
assert isinstance(ele["video"], (list, tuple))
|
||||
process_info = ele.copy()
|
||||
process_info.pop("type", None)
|
||||
process_info.pop("video", None)
|
||||
images = [
|
||||
fetch_image({
|
||||
"image": video_element,
|
||||
**process_info
|
||||
},
|
||||
size_factor=image_factor)
|
||||
for video_element in ele["video"]
|
||||
]
|
||||
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
|
||||
if len(images) < nframes:
|
||||
images.extend([images[-1]] * (nframes - len(images)))
|
||||
return images
|
||||
|
||||
|
||||
def extract_vision_info(
|
||||
conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
||||
vision_infos = []
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if ("image" in ele or "image_url" in ele or
|
||||
"video" in ele or
|
||||
ele["type"] in ("image", "image_url", "video")):
|
||||
vision_infos.append(ele)
|
||||
return vision_infos
|
||||
|
||||
|
||||
def process_vision_info(
|
||||
conversations: list[dict] | list[list[dict]],
|
||||
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
|
||||
None]:
|
||||
vision_infos = extract_vision_info(conversations)
|
||||
## Read images or videos
|
||||
image_inputs = []
|
||||
video_inputs = []
|
||||
for vision_info in vision_infos:
|
||||
if "image" in vision_info or "image_url" in vision_info:
|
||||
image_inputs.append(fetch_image(vision_info))
|
||||
elif "video" in vision_info:
|
||||
video_inputs.append(fetch_video(vision_info))
|
||||
else:
|
||||
raise ValueError("image, image_url or video should in content.")
|
||||
if len(image_inputs) == 0:
|
||||
image_inputs = None
|
||||
if len(video_inputs) == 0:
|
||||
video_inputs = None
|
||||
return image_inputs, video_inputs
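# Hedged usage sketch (not part of the original file): the message layout follows the
# qwen-vl chat format; the image path is a placeholder, not a file in this repository.
def _example_process_vision_info():
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/frame.png"},
            {"type": "text", "text": "Describe the first frame."},
        ],
    }]
    return process_vision_info(messages)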
|
||||
147
video/Wan2.2/wan/utils/system_prompt.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
|
||||
T2V_A14B_ZH_SYS_PROMPT = \
|
||||
''' 你是一位电影导演,旨在为用户输入的原始prompt添加电影元素,改写为优质Prompt,使其完整、具有表现力。
|
||||
任务要求:
|
||||
1. 对于用户输入的prompt,在不改变prompt的原意(如主体、动作)前提下,从下列电影美学设定中选择部分合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中,让画面变得更美,注意,可以任选,不必每项都有
|
||||
时间:["白天", "夜晚", "黎明", "日出"], 可以不选, 如果prompt没有特别说明则选白天 !
|
||||
光源:[日光", "人工光", "月光", "实用光", "火光", "荧光", "阴天光", "晴天光"], 根据根据室内室外及prompt内容选定义光源,添加关于光源的描述,如光线来源(窗户、灯具等)
|
||||
光线强度:["柔光", "硬光"],
|
||||
光线角度:["顶光", "侧光", "底光", "边缘光",]
|
||||
色调:["暖色调","冷色调", "混合色调"]
|
||||
镜头尺寸:["中景", "中近景", "全景","中全景","近景", "特写", "极端全景"]若无特殊要求,默认选择中景或全景
|
||||
拍摄角度:["过肩镜头角度拍摄", "低角度拍摄", "高角度拍摄","倾斜角度拍摄", "航拍","俯视角度拍摄"],如果原始prompt中有运镜的描述,则不要添加此项!
|
||||
构图:["中心构图","平衡构图","右侧重构图", "左侧重构图", "对称构图", "短边构图"] 若无特殊要求,默认选择中心构图
|
||||
2. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)等内容,确保不要添加原始prompt中不存在的主体,(如prompt是对风景或物体的描述,但添加了人),增加背景元素的细节;
|
||||
3. 不要输出关于氛围、感觉等文学描写,如(画面充满运动感与生活张力,突出正式氛围)。;
|
||||
4. 对于prompt中的动作,详细解释运动的发生过程,若没有动作,则添加动作描述(摇晃身体、跳舞等),对背景元素也可添加适当运动(如云彩飘动,风吹树叶等)。
|
||||
5. 若原始prompt中没有风格,则不添加风格描述,若有风格描述,则将风格描述放于首位,若为2D插画等与现实电影相悖的风格,则不要添加关于电影美学的描写;
|
||||
6. 若prompt出现天空的描述,则改为湛蓝色的天空相关描述,避免曝光;
|
||||
7. 改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出
|
||||
8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容,请将原prompt替换为内容完全不同的高美感prompt,而不要拒答或者提醒用户。
|
||||
9. 如果用户输入存在比基尼穿着内容,请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt,而不要拒答或者提醒用户。
|
||||
10. 如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近(两个男性亲吻或两个女性亲吻)/性交等同性恋内容,请将原prompt替换为不同动作(不能是亲吻和其他亲密动作)的高美感prompt,而不要拒答或者提醒用户。
|
||||
生成的 prompt 示例:
|
||||
1.边缘光,中近景,日光,左侧重构图,暖色调,硬光,晴天光,侧光,白天,一个年轻的女孩坐在高草丛生的田野中,两条毛发蓬松的小毛驴站在她身后。女孩大约十一二岁,穿着简单的碎花裙子,头发扎成两条麻花辫,脸上带着纯真的笑容。她双腿交叉坐下,双手轻轻抚弄身旁的野花。小毛驴体型健壮,耳朵竖起,好奇地望着镜头方向。阳光洒在田野上,营造出温暖自然的画面感。
|
||||
2.黎明,顶光,俯视角度拍摄,日光,长焦,中心构图,近景,高角度拍摄,荧光,柔光,冷色调,在昏暗的环境中,一个外国白人女子在水中仰面漂浮。俯拍近景镜头中,她有着棕色的短发,脸上有几颗雀斑。随着镜头下摇,她转过头来,面向右侧,水面上泛起一圈涟漪。虚化的背景一片漆黑,只有微弱的光线照亮了女子的脸庞和水面的一部分区域,水面呈现蓝色。女子穿着一件蓝色的吊带,肩膀裸露在外。
|
||||
3.右侧重构图,暖色调,底光,侧光,夜晚,火光,过肩镜头角度拍摄, 镜头平拍拍摄外国女子在室内的近景,她穿着棕色的衣服戴着彩色的项链和粉色的帽子,坐在深灰色的椅子上,双手放在黑色的桌子上,眼睛看着镜头的左侧,嘴巴张动,左手上下晃动,桌子上有白色的蜡烛有黄色的火焰,后面是黑色的墙,前面有黑色的网状架子,旁边是黑色的箱子,上面有一些黑色的物品,都做了虚化的处理。
|
||||
4. 二次元厚涂动漫插画,一个猫耳兽耳白人少女手持文件夹摇晃,神情略带不满。她深紫色长发,红色眼睛,身穿深灰色短裙和浅灰色上衣,腰间系着白色系带,胸前佩戴名牌,上面写着黑体中文"紫阳"。淡黄色调室内背景,隐约可见一些家具轮廓。少女头顶有一个粉色光圈。线条流畅的日系赛璐璐风格。近景半身略俯视视角。
|
||||
'''
|
||||
|
||||
|
||||
T2V_A14B_EN_SYS_PROMPT = \
|
||||
'''你是一位电影导演,旨在为用户输入的原始prompt添加电影元素,改写为优质(英文)Prompt,使其完整、具有表现力注意,输出必须是英文!
|
||||
任务要求:
|
||||
1. 对于用户输入的prompt,在不改变prompt的原意(如主体、动作)前提下,从下列电影美学设定中选择不超过4种合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中,让画面变得更美,注意,可以任选,不必每项都有
|
||||
时间:["Day time", "Night time" "Dawn time","Sunrise time"], 如果prompt没有特别说明则选 Day time!!!
|
||||
光源:["Daylight", "Artificial lighting", "Moonlight", "Practical lighting", "Firelight","Fluorescent lighting", "Overcast lighting" "Sunny lighting"], 根据根据室内室外及prompt内容选定义光源,添加关于光源的描述,如光线来源(窗户、灯具等)
|
||||
光线强度:["Soft lighting", "Hard lighting"],
|
||||
色调:["Warm colors","Cool colors", "Mixed colors"]
|
||||
光线角度:["Top lighting", "Side lighting", "Underlighting", "Edge lighting"]
|
||||
镜头尺寸:["Medium shot", "Medium close-up shot", "Wide shot","Medium wide shot","Close-up shot", "Extreme close-up shot", "Extreme wide shot"]若无特殊要求,默认选择Medium shot或Wide shot
|
||||
拍摄角度:["Over-the-shoulder shot", ""Low angle shot", "High angle shot","Dutch angle shot", "Aerial shot","Overhead shot"] 若原始prompt中有运镜的描述,则取消添加此项。
|
||||
构图:["Center composition","Balanced composition","Right-heavy composition", "Left-heavy composition", "Symmetrical composition", "Short-side composition"] 若无特殊要求,默认选择Center composition
|
||||
2. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)等内容,确保不要添加原始prompt中不存在的主体,(如prompt是对风景或物体的描述,但添加了人),增加背景元素的细节;
|
||||
3. 不要输出关于氛围、感觉等文学描写,如(画面充满运动感与生活张力,突出正式氛围)。;
|
||||
4. 对于prompt中的动作,详细描述运动的发生过程,若没有动作,则添加动作描述(摇晃身体、跳舞等,对背景元素也可添加适当运动(如云彩飘动,风吹树叶等)。
|
||||
5. 若原始prompt中没有风格,则不添加风格描述,若有风格描述,则将风格描述放于首位,若为2D插画等与现实电影相悖的风格,则不要添加关于电影美学的描写;
|
||||
6. 若prompt出现天空的描述,则改为湛蓝色的天空相关描述,避免曝光;
|
||||
7. 输出必须是全英文,改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出
|
||||
8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容,请将原prompt替换为内容完全不同的高美感prompt,而不要拒答或者提醒用户。
|
||||
9. 如果用户输入存在比基尼穿着内容,请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt,而不要拒答或者提醒用户。
|
||||
10. 如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近(两个男性亲吻或两个女性亲吻)/性交等同性恋内容,请将原prompt替换为不同动作(不能是亲吻和其他亲密动作)的高美感prompt,而不要拒答或者提醒用户。
|
||||
生成的 prompt 示例:
|
||||
1.Edge lighting, medium close-up shot, daylight, left-heavy composition. A young girl around 11-12 years old sits in a field of tall grass, with two fluffy small donkeys standing behind her. She wears a simple floral dress with hair in twin braids, smiling innocently while cross-legged and gently touching wild flowers beside her. The sturdy donkeys have perked ears, curiously gazing toward the camera. Sunlight bathes the field, creating a warm natural atmosphere.
|
||||
2.Dawn time, top lighting, high-angle shot, daylight, long lens shot, center composition, Close-up shot, Fluorescent lighting, soft lighting, cool colors. In dim surroundings, a Caucasian woman floats on her back in water. The俯拍close-up shows her brown short hair and freckled face. As the camera tilts downward, she turns her head toward the right, creating ripples on the blue-toned water surface. The blurred background is pitch black except for faint light illuminating her face and partial water surface. She wears a blue sleeveless top with bare shoulders.
|
||||
3.Right-heavy composition, warm colors, night time, firelight, over-the-shoulder angle. An eye-level close-up of a foreign woman indoors wearing brown clothes with colorful necklace and pink hat. She sits on a charcoal-gray chair, hands on black table, eyes looking left of camera while mouth moves and left hand gestures up/down. White candles with yellow flames sit on the table. Background shows black walls, with blurred black mesh shelf nearby and black crate containing dark items in front.
|
||||
4."Anime-style thick-painted style. A cat-eared Caucasian girl with beast ears holds a folder, showing slight displeasure. Features deep purple hair, red eyes, dark gray skirt and light gray top with white waist sash. A name tag labeled 'Ziyang' in bold Chinese characters hangs on her chest. Pale yellow indoor background with faint furniture outlines. A pink halo floats above her head. Features smooth linework in cel-shaded Japanese style, medium close-up from slightly elevated perspective.
|
||||
'''
|
||||
|
||||
|
||||
I2V_A14B_ZH_SYS_PROMPT = \
|
||||
'''你是一个视频描述提示词的改写专家,你的任务是根据用户给你输入的图像,对提供的视频描述提示词进行改写,你要强调潜在的动态内容。具体要求如下
|
||||
用户输入的语言可能含有多样化的描述,如markdown文档格式、指令格式,长度过长或者过短,你需要根据图片的内容和用户的输入的提示词,尽可能提取用户输入的提示词和图片关联信息。
|
||||
你改写的视频描述结果要尽可能保留提供给你的视频描述提示词中动态部分,保留主体的动作。
|
||||
你要根据图像,强调并简化视频描述提示词中的图像主体,如果用户只提供了动作,你要根据图像内容合理补充,如“跳舞”补充称“一个女孩在跳舞”
|
||||
如果用户输入的提示词过长,你需要提炼潜在的动作过程
|
||||
如果用户输入的提示词过短,综合用户输入的提示词以及画面内容,合理的增加潜在的运动信息
|
||||
你要根据图像,保留并强调视频描述提示词中关于运镜手段的描述,如“镜头上摇”,“镜头从左到右”,“镜头从右到左”等等,你要保留,如“镜头拍摄两个男人打斗,他们先是躺在地上,随后镜头向上移动,拍摄他们站起来,接着镜头向左移动,左边男人拿着一个蓝色的东西,右边男人上前抢夺,两人激烈地来回争抢。”。
|
||||
你需要给出对视频描述的动态内容,不要添加对于静态场景的描述,如果用户输入的描述已经在画面中出现,则移除这些描述
|
||||
改写后的prompt字数控制在100字以下
|
||||
无论用户输入那种语言,你都需要输出中文
|
||||
改写后 prompt 示例:
|
||||
1. 镜头后拉,拍摄两个外国男人,走在楼梯上,镜头左侧的男人右手搀扶着镜头右侧的男人。
|
||||
2. 一只黑色的小松鼠专注地吃着东西,偶尔抬头看看四周。
|
||||
3. 男子说着话,表情从微笑逐渐转变为闭眼,然后睁开眼睛,最后是闭眼微笑,他的手势活跃,在说话时做出一系列的手势。
|
||||
4. 一个人正在用尺子和笔进行测量的特写,右手用一支黑色水性笔在纸上画出一条直线。
|
||||
5. 一辆车模型在木板上形式,车辆从画面的右侧向左侧移动,经过一片草地和一些木制结构。
|
||||
6. 镜头左移后前推,拍摄一个人坐在防波堤上。
|
||||
7. 男子说着话,他的表情和手势随着对话内容的变化而变化,但整体场景保持不变。
|
||||
8. 镜头左移后前推,拍摄一个人坐在防波堤上。
|
||||
9. 带着珍珠项链的女子看向画面右侧并说着话。
|
||||
请直接输出改写后的文本,不要进行多余的回复。'''
|
||||
|
||||
|
||||
I2V_A14B_EN_SYS_PROMPT = \
|
||||
'''You are an expert in rewriting video description prompts. Your task is to rewrite the provided video description prompts based on the images given by users, emphasizing potential dynamic content. Specific requirements are as follows:
|
||||
The user's input language may include diverse descriptions, such as markdown format, instruction format, or be too long or too short. You need to extract the relevant information from the user’s input and associate it with the image content.
|
||||
Your rewritten video description should retain the dynamic parts of the provided prompts, focusing on the main subject's actions. Emphasize and simplify the main subject of the image while retaining their movement. If the user only provides an action (e.g., "dancing"), supplement it reasonably based on the image content (e.g., "a girl is dancing").
|
||||
If the user’s input prompt is too long, refine it to capture the essential action process. If the input is too short, add reasonable motion-related details based on the image content.
|
||||
Retain and emphasize descriptions of camera movements, such as "the camera pans up," "the camera moves from left to right," or "the camera moves from right to left." For example: "The camera captures two men fighting. They start lying on the ground, then the camera moves upward as they stand up. The camera shifts left, showing the man on the left holding a blue object while the man on the right tries to grab it, resulting in a fierce back-and-forth struggle."
|
||||
Focus on dynamic content in the video description and avoid adding static scene descriptions. If the user’s input already describes elements visible in the image, remove those static descriptions.
|
||||
Limit the rewritten prompt to 100 words or less. Regardless of the input language, your output must be in English.
|
||||
|
||||
Examples of rewritten prompts:
|
||||
The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.
|
||||
A black squirrel focuses on eating, occasionally looking around.
|
||||
A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.
|
||||
A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.
|
||||
A model car moves on a wooden board, traveling from right to left across grass and wooden structures.
|
||||
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
|
||||
A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.
|
||||
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
|
||||
A woman wearing a pearl necklace looks to the right and speaks.
|
||||
Output only the rewritten text without additional responses.'''
|
||||
|
||||
|
||||
I2V_A14B_EMPTY_ZH_SYS_PROMPT = \
|
||||
'''你是一个视频描述提示词的撰写专家,你的任务是根据用户给你输入的图像,发挥合理的想象,让这张图动起来,你要强调潜在的动态内容。具体要求如下
|
||||
你需要根据图片的内容想象出运动的主体
|
||||
你输出的结果应强调图片中的动态部分,保留主体的动作。
|
||||
你需要给出对视频描述的动态内容,不要有过多的对于静态场景的描述
|
||||
输出的prompt字数控制在100字以下
|
||||
你需要输出中文
|
||||
prompt 示例:
|
||||
1. 镜头后拉,拍摄两个外国男人,走在楼梯上,镜头左侧的男人右手搀扶着镜头右侧的男人。
|
||||
2. 一只黑色的小松鼠专注地吃着东西,偶尔抬头看看四周。
|
||||
3. 男子说着话,表情从微笑逐渐转变为闭眼,然后睁开眼睛,最后是闭眼微笑,他的手势活跃,在说话时做出一系列的手势。
|
||||
4. 一个人正在用尺子和笔进行测量的特写,右手用一支黑色水性笔在纸上画出一条直线。
|
||||
5. 一辆车模型在木板上形式,车辆从画面的右侧向左侧移动,经过一片草地和一些木制结构。
|
||||
6. 镜头左移后前推,拍摄一个人坐在防波堤上。
|
||||
7. 男子说着话,他的表情和手势随着对话内容的变化而变化,但整体场景保持不变。
|
||||
8. 镜头左移后前推,拍摄一个人坐在防波堤上。
|
||||
9. 带着珍珠项链的女子看向画面右侧并说着话。
|
||||
请直接输出文本,不要进行多余的回复。'''
|
||||
|
||||
|
||||
I2V_A14B_EMPTY_EN_SYS_PROMPT = \
'''You are an expert in writing video description prompts. Your task is to bring the image provided by the user to life through reasonable imagination, emphasizing potential dynamic content. Specific requirements are as follows:

You need to imagine the moving subject based on the content of the image.
Your output should emphasize the dynamic parts of the image and retain the main subject’s actions.
Focus only on describing dynamic content; avoid excessive descriptions of static scenes.
Limit the output prompt to 100 words or less.
The output must be in English.

Prompt examples:

The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.
A black squirrel focuses on eating, occasionally looking around.
A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.
A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.
A model car moves on a wooden board, traveling from right to left across grass and wooden structures.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A woman wearing a pearl necklace looks to the right and speaks.
Output only the text without additional responses.'''
233
video/Wan2.2/wan/utils/utils.py
Normal file
233
video/Wan2.2/wan/utils/utils.py
Normal file
@@ -0,0 +1,233 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# utils MLX version
import argparse
import binascii
import logging
import os
import os.path as osp

import imageio
import mlx.core as mx
import numpy as np

__all__ = ['save_video', 'save_image', 'str2bool', 'masks_like', 'best_output_size']

def rand_name(length=8, suffix=''):
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix:
        if not suffix.startswith('.'):
            suffix = '.' + suffix
        name += suffix
    return name

def make_grid(tensor, nrow=8, normalize=True, value_range=(-1, 1)):
    """MLX equivalent of torchvision.utils.make_grid"""
    # tensor shape: (batch, channels, height, width)
    batch_size, channels, height, width = tensor.shape

    # Calculate grid dimensions
    ncol = nrow
    nrow_actual = (batch_size + ncol - 1) // ncol

    # Create grid
    grid_height = height * nrow_actual + (nrow_actual - 1) * 2  # 2 pixel padding
    grid_width = width * ncol + (ncol - 1) * 2

    # Initialize grid with zeros
    grid = mx.zeros((channels, grid_height, grid_width))

    # Fill grid
    for idx in range(batch_size):
        row = idx // ncol
        col = idx % ncol

        y_start = row * (height + 2)
        y_end = y_start + height
        x_start = col * (width + 2)
        x_end = x_start + width

        img = tensor[idx]
        if normalize:
            # Normalize to [0, 1]
            img = (img - value_range[0]) / (value_range[1] - value_range[0])

        grid[:, y_start:y_end, x_start:x_end] = img

    return grid

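# Illustrative sketch (not part of the original file): with four images tiled in two
# columns, make_grid keeps a 2-pixel gutter between tiles, so a (4, 3, 64, 64) batch
# becomes a grid of shape (3, 64*2 + 2, 64*2 + 2).
# >>> imgs = mx.random.uniform(-1.0, 1.0, shape=(4, 3, 64, 64))
# >>> make_grid(imgs, nrow=2).shape   # -> (3, 130, 130)
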
def save_video(tensor,
               save_file=None,
               fps=30,
               suffix='.mp4',
               nrow=8,
               normalize=True,
               value_range=(-1, 1)):
    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    try:
        # preprocess
        tensor = mx.clip(tensor, value_range[0], value_range[1])

        # tensor shape: (batch, channels, frames, height, width)
        # Process each frame
        frames = []
        for frame_idx in range(tensor.shape[2]):
            frame = tensor[:, :, frame_idx, :, :]  # (batch, channels, height, width)
            grid = make_grid(frame, nrow=nrow, normalize=normalize, value_range=value_range)
            frames.append(grid)

        # Stack frames and convert to (frames, height, width, channels)
        tensor = mx.stack(frames, axis=0)  # (frames, channels, height, width)
        tensor = mx.transpose(tensor, [0, 2, 3, 1])  # (frames, height, width, channels)

        # Convert to uint8
        tensor = (tensor * 255).astype(mx.uint8)
        tensor_np = np.array(tensor)

        # write video
        writer = imageio.get_writer(
            cache_file, fps=fps, codec='libx264', quality=8)
        for frame in tensor_np:
            writer.append_data(frame)
        writer.close()
    except Exception as e:
        logging.info(f'save_video failed, error: {e}')

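# Illustrative usage sketch (not part of the original file): save_video expects a
# (batch, channels, frames, height, width) array in value_range; the path and sizes
# below are made up.
# >>> video = mx.random.uniform(-1.0, 1.0, shape=(1, 3, 16, 64, 64))
# >>> save_video(video, save_file='/tmp/demo.mp4', fps=8)
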
def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):
    # cache file
    suffix = osp.splitext(save_file)[1]
    if suffix.lower() not in [
            '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
    ]:
        suffix = '.png'

    # save to cache
    try:
        # Clip values
        tensor = mx.clip(tensor, value_range[0], value_range[1])

        # Make grid
        grid = make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range)

        # Convert to (height, width, channels) and uint8
        grid = mx.transpose(grid, [1, 2, 0])  # (height, width, channels)
        grid = (grid * 255).astype(mx.uint8)

        # Save using imageio
        imageio.imwrite(save_file, np.array(grid))
        return save_file
    except Exception as e:
        logging.info(f'save_image failed, error: {e}')

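# Illustrative usage sketch (not part of the original file): save_image expects a
# (batch, channels, height, width) array and writes a tiled grid image (path made up).
# >>> save_image(mx.random.uniform(-1.0, 1.0, shape=(2, 3, 64, 64)), '/tmp/demo.png')
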
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = v.lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')

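# Illustrative usage sketch (not part of the original file): str2bool is meant to be
# used as an argparse type; the --offload flag name below is made up.
# >>> parser = argparse.ArgumentParser()
# >>> _ = parser.add_argument('--offload', type=str2bool, default=False)
# >>> parser.parse_args(['--offload', 'yes']).offload   # -> True
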
def masks_like(tensor, zero=False, generator=None, p=0.2):
    """
    Generate masks similar to input tensors.

    Args:
        tensor: List of MLX arrays
        zero: Whether to apply zero masking
        generator: Random generator (for MLX, we use mx.random.seed instead)
        p: Probability for random masking

    Returns:
        Tuple of two lists of masks
    """
    assert isinstance(tensor, list)
    out1 = [mx.ones(u.shape, dtype=u.dtype) for u in tensor]
    out2 = [mx.ones(u.shape, dtype=u.dtype) for u in tensor]

    if zero:
        if generator is not None:
            # MLX doesn't have the same generator API as PyTorch
            # We'll use random state instead
            for u, v in zip(out1, out2):
                random_num = mx.random.uniform(0, 1, shape=(1,)).item()
                if random_num < p:
                    # Generate random values with normal distribution
                    normal_vals = mx.random.normal(shape=u[:, 0].shape, loc=-3.5, scale=0.5)
                    u[:, 0] = mx.exp(normal_vals)
                    v[:, 0] = mx.zeros_like(v[:, 0])
                else:
                    # Keep original values
                    u[:, 0] = u[:, 0]
                    v[:, 0] = v[:, 0]
        else:
            for u, v in zip(out1, out2):
                u[:, 0] = mx.zeros_like(u[:, 0])
                v[:, 0] = mx.zeros_like(v[:, 0])

    return out1, out2

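# Illustrative usage sketch (not part of the original file): for latents shaped
# (channels, frames, height, width), zero=True blanks the first frame of each mask
# (or re-noises it with probability p when a generator is passed).
# >>> latents = [mx.zeros((16, 4, 8, 8))]
# >>> m1, m2 = masks_like(latents, zero=True)
# >>> float(m2[0][:, 0].sum()), float(m2[0][:, 1].sum())   # -> (0.0, 1024.0)
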
def best_output_size(w, h, dw, dh, expected_area):
    """
    Calculate the best output size given constraints.

    Args:
        w: Width
        h: Height
        dw: Width divisor
        dh: Height divisor
        expected_area: Target area

    Returns:
        Tuple of (output_width, output_height)
    """
    # float output size
    ratio = w / h
    ow = (expected_area * ratio)**0.5
    oh = expected_area / ow

    # process width first
    ow1 = int(ow // dw * dw)
    oh1 = int(expected_area / ow1 // dh * dh)
    assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area
    ratio1 = ow1 / oh1

    # process height first
    oh2 = int(oh // dh * dh)
    ow2 = int(expected_area / oh2 // dw * dw)
    assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area
    ratio2 = ow2 / oh2

    # compare ratios
    if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2,
                                                 ratio2 / ratio):
        return ow1, oh1
    else:
        return ow2, oh2
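
# Illustrative worked example (not part of the original file): a 1280x720 input snapped
# to multiples of 16 with a 1280*720 pixel budget keeps its aspect ratio exactly, since
# the width-first and height-first candidates reduce to the same size.
# >>> best_output_size(1280, 720, 16, 16, 1280 * 720)   # -> (1280, 720)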
160
video/Wan2.2/wan/vae_model_io.py
Normal file
160
video/Wan2.2/wan/vae_model_io.py
Normal file
@@ -0,0 +1,160 @@
import re

import torch
import mlx.core as mx
import numpy as np
from typing import Dict, Tuple
from safetensors import safe_open

def convert_pytorch_to_mlx(pytorch_path: str, output_path: str, float16: bool = False):
    """
    Convert PyTorch VAE weights to MLX format with correct mapping.
    """
    print(f"Converting {pytorch_path} -> {output_path}")
    dtype = mx.float16 if float16 else mx.float32

    # Load PyTorch weights
    if pytorch_path.endswith('.safetensors'):
        weights = {}
        with safe_open(pytorch_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                if tensor.dtype == torch.bfloat16:
                    tensor = tensor.float()
                weights[key] = tensor.numpy()
    else:
        checkpoint = torch.load(pytorch_path, map_location='cpu')
        weights = {}
        state_dict = checkpoint if isinstance(checkpoint, dict) and 'state_dict' not in checkpoint else checkpoint.get('state_dict', checkpoint)
        for key, tensor in state_dict.items():
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()
            weights[key] = tensor.numpy()

    # Convert weights
    mlx_weights = {}

    for key, value in weights.items():
        # Skip these
        if any(skip in key for skip in ["num_batches_tracked", "running_mean", "running_var"]):
            continue

        # Convert weight formats
        if value.ndim == 5 and "weight" in key:  # Conv3d weights
            # PyTorch: (out_channels, in_channels, D, H, W)
            # MLX: (out_channels, D, H, W, in_channels)
            value = np.transpose(value, (0, 2, 3, 4, 1))
        elif value.ndim == 4 and "weight" in key:  # Conv2d weights
            # PyTorch: (out_channels, in_channels, H, W)
            # MLX Conv2d expects: (out_channels, H, W, in_channels)
            value = np.transpose(value, (0, 2, 3, 1))
        elif value.ndim == 1 and "bias" in key:  # Conv biases
            # Keep as is - MLX uses same format
            pass

        # Map the key
        new_key = key

        # Map residual block internals within Sequential
        # PyTorch: encoder.downsamples.0.residual.0.gamma
        # MLX: encoder.downsamples.layers.0.residual.layers.0.gamma

        # Add .layers to Sequential modules
        new_key = re.sub(r'\.downsamples\.(\d+)', r'.downsamples.layers.\1', new_key)
        new_key = re.sub(r'\.upsamples\.(\d+)', r'.upsamples.layers.\1', new_key)
        new_key = re.sub(r'\.middle\.(\d+)', r'.middle.layers.\1', new_key)
        new_key = re.sub(r'\.head\.(\d+)', r'.head.layers.\1', new_key)

        # Map residual Sequential internals
        if ".residual." in new_key:
            match = re.search(r'\.residual\.(\d+)\.', new_key)
            if match:
                idx = int(match.group(1))
                if idx == 0:  # First RMS_norm
                    new_key = re.sub(r'\.residual\.0\.', '.residual.layers.0.', new_key)
                elif idx == 1:  # SiLU - skip
                    continue
                elif idx == 2:  # First Conv3d
                    new_key = re.sub(r'\.residual\.2\.', '.residual.layers.2.', new_key)
                elif idx == 3:  # Second RMS_norm
                    new_key = re.sub(r'\.residual\.3\.', '.residual.layers.3.', new_key)
                elif idx == 4:  # Second SiLU - skip
                    continue
                elif idx == 5:  # Dropout - could be Identity in MLX
                    if "Dropout" in key:
                        continue
                    new_key = re.sub(r'\.residual\.5\.', '.residual.layers.5.', new_key)
                elif idx == 6:  # Second Conv3d
                    new_key = re.sub(r'\.residual\.6\.', '.residual.layers.6.', new_key)

        # Map resample internals
        if ".resample." in new_key:
            # In both Encoder and Decoder Resample blocks, the Conv2d is at index 1
            # in the nn.Sequential block, following either a padding or upsample layer.
            # We just need to map PyTorch's .1 to MLX's .layers.1
            if ".resample.1." in new_key:
                new_key = new_key.replace(".resample.1.", ".resample.layers.1.")

            # The layers at index 0 (ZeroPad2d, Upsample) have no weights, so we can
            # safely skip any keys associated with them.
            if ".resample.0." in key:
                continue

        # Map head internals (already using Sequential in MLX)
        # Just need to handle the layers index

        # Handle shortcut layers
        if ".shortcut." in new_key and "Identity" not in key:
            # Shortcut Conv3d layers - keep as is
            pass
        elif "Identity" in key:
            # Skip Identity modules
            continue

        # Handle time_conv in Resample
        if "time_conv" in new_key:
            # Keep as is - already correctly named
            pass

        # Handle attention layers
        if "to_qkv" in new_key or "proj" in new_key:
            # Keep as is - already correctly named
            pass

        # Squeeze gamma from (C, 1, 1) or (C, 1, 1, 1) to just (C,)
        if "gamma" in new_key:
            value = np.squeeze(value)  # This removes all dimensions of size 1
            # Result will always be a 1D array of shape (C,)

        # Add to MLX weights
        mlx_weights[new_key] = mx.array(value).astype(dtype)

    # Verify critical layers are present
    critical_prefixes = [
        "encoder.conv1", "decoder.conv1", "conv1", "conv2",
        "encoder.head.layers.2", "decoder.head.layers.2"  # Updated for Sequential
    ]

    for prefix in critical_prefixes:
        found = any(k.startswith(prefix) for k in mlx_weights.keys())
        if not found:
            print(f"WARNING: No weights found for {prefix}")

    print(f"Converted {len(mlx_weights)} parameters")

    # Print a few example keys for verification
    print("\nExample converted keys:")
    for key in sorted(mlx_weights.keys())[:10]:
        print(f"  {key}")

    # Save
    if output_path.endswith('.safetensors'):
        mx.save_safetensors(output_path, mlx_weights)
    else:
        mx.savez(output_path, **mlx_weights)

    print(f"\nSaved to {output_path}")
    print("\nAll converted keys:")
    for key in sorted(mlx_weights.keys()):
        print(f"  {key}: {mlx_weights[key].shape}")
    return mlx_weights
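
# Illustrative usage sketch (not part of the original file): converts the original VAE
# checkpoint to an MLX safetensors file; both file names are made up and depend on
# where the checkpoint was downloaded.
# >>> convert_pytorch_to_mlx('Wan2.1_VAE.pth', 'Wan2.1_VAE_mlx.safetensors', float16=False)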
294
video/Wan2.2/wan/wan_model_io.py
Normal file
294
video/Wan2.2/wan/wan_model_io.py
Normal file
@@ -0,0 +1,294 @@
# wan_model_io.py

from typing import List, Tuple, Set, Dict
import os
import mlx.core as mx
from mlx.utils import tree_unflatten, tree_flatten
from safetensors import safe_open
import torch
import numpy as np
import glob

def map_wan_2_2_weights(key: str, value: mx.array) -> List[Tuple[str, mx.array]]:
    """Map PyTorch WAN 2.2 weights to MLX format."""

    # Only add .layers to Sequential modules WITHIN components, not to blocks themselves:
    # blocks.N stays as blocks.N (not blocks.layers.N)

    # Handle Sequential layers - PyTorch uses .0, .1, .2, MLX uses .layers.0, .layers.1, .layers.2
    # Only for components INSIDE blocks and top-level modules
    if ".ffn." in key and ".layers." not in key:
        # Replace .ffn.0 with .ffn.layers.0, etc.
        key = key.replace(".ffn.0.", ".ffn.layers.0.")
        key = key.replace(".ffn.1.", ".ffn.layers.1.")
        key = key.replace(".ffn.2.", ".ffn.layers.2.")

    if "text_embedding." in key and ".layers." not in key:
        for i in range(10):
            key = key.replace(f"text_embedding.{i}.", f"text_embedding.layers.{i}.")

    if "time_embedding." in key and ".layers." not in key:
        for i in range(10):
            key = key.replace(f"time_embedding.{i}.", f"time_embedding.layers.{i}.")

    if "time_projection." in key and ".layers." not in key:
        for i in range(10):
            key = key.replace(f"time_projection.{i}.", f"time_projection.layers.{i}.")

    # Handle conv weight layout for patch_embedding
    if "patch_embedding.weight" in key:
        # PyTorch Conv3d: (out_channels, in_channels, D, H, W)
        # MLX Conv3d: (out_channels, D, H, W, in_channels)
        value = mx.transpose(value, (0, 2, 3, 4, 1))

    return [(key, value)]

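# Illustrative sketch (not part of the original file): the mapping only renames keys
# inside Sequential submodules; block indices themselves are unchanged.
# >>> map_wan_2_2_weights('blocks.0.ffn.0.weight', mx.zeros((13824, 5120)))[0][0]
# 'blocks.0.ffn.layers.0.weight'
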
def check_parameter_mismatch(model, weights: Dict[str, mx.array]) -> Tuple[Set[str], Set[str]]:
    """
    Check for parameter mismatches between model and weights.

    Returns:
        (model_only, weights_only): Sets of parameter names that exist only in model or weights
    """
    # Get all parameter names from model
    model_params = dict(tree_flatten(model.parameters()))
    model_keys = set(model_params.keys())

    # Remove computed buffers that aren't loaded from weights
    computed_buffers = {'freqs'}  # Add any other computed buffers here
    model_keys = model_keys - computed_buffers

    # Get all parameter names from weights
    weight_keys = set(weights.keys())

    # Find differences
    model_only = model_keys - weight_keys
    weights_only = weight_keys - model_keys

    return model_only, weights_only

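# Illustrative usage sketch (not part of the original file): 'model' is any MLX module
# and 'weights' a flat dict of arrays; two empty sets mean a perfect match.
# >>> missing, unexpected = check_parameter_mismatch(model, weights)
# >>> if not missing and not unexpected:
# ...     print('all parameters align')
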
def load_wan_2_2_from_safetensors(
    safetensors_path: str,
    model,
    float16: bool = False,
    check_mismatch: bool = True
):
    """
    Load WAN 2.2 Model weights from safetensors file(s) into MLX model.

    Args:
        safetensors_path: Path to safetensors file or directory
        model: MLX model instance
        float16: Whether to use float16 precision
        check_mismatch: Whether to check for parameter mismatches
    """
    if os.path.isdir(safetensors_path):
        # Multiple files (14B model) - only diffusion_mlx_model files
        pattern = os.path.join(safetensors_path, "diffusion_mlx_model*.safetensors")
        safetensor_files = sorted(glob.glob(pattern))
        print(f"Found {len(safetensor_files)} diffusion_mlx_model safetensors files")

        # Load all files and merge weights
        all_weights = {}
        for file_path in safetensor_files:
            print(f"Loading: {file_path}")
            weights = mx.load(file_path)
            all_weights.update(weights)

        if check_mismatch:
            model_only, weights_only = check_parameter_mismatch(model, all_weights)

            if model_only:
                print(f"\n⚠️ WARNING: {len(model_only)} parameters in model but NOT in weights:")
                for param in sorted(model_only)[:10]:  # Show first 10
                    print(f"  - {param}")
                if len(model_only) > 10:
                    print(f"  ... and {len(model_only) - 10} more")

            if weights_only:
                print(f"\n⚠️ WARNING: {len(weights_only)} parameters in weights but NOT in model:")
                for param in sorted(weights_only)[:10]:  # Show first 10
                    print(f"  - {param}")
                if len(weights_only) > 10:
                    print(f"  ... and {len(weights_only) - 10} more")

            if not model_only and not weights_only:
                print("\n✅ Perfect match: All parameters align between model and weights!")

        model.update(tree_unflatten(list(all_weights.items())))
    else:
        # Single file
        print(f"Loading single file: {safetensors_path}")
        weights = mx.load(safetensors_path)

        if check_mismatch:
            model_only, weights_only = check_parameter_mismatch(model, weights)

            if model_only:
                print(f"\n⚠️ WARNING: {len(model_only)} parameters in model but NOT in weights:")
                for param in sorted(model_only)[:10]:  # Show first 10
                    print(f"  - {param}")
                if len(model_only) > 10:
                    print(f"  ... and {len(model_only) - 10} more")

            if weights_only:
                print(f"\n⚠️ WARNING: {len(weights_only)} parameters in weights but NOT in model:")
                for param in sorted(weights_only)[:10]:  # Show first 10
                    print(f"  - {param}")
                if len(weights_only) > 10:
                    print(f"  ... and {len(weights_only) - 10} more")

            if not model_only and not weights_only:
                print("\n✅ Perfect match: All parameters align between model and weights!")

        model.update(tree_unflatten(list(weights.items())))

    print("\nWAN 2.2 Model weights loaded successfully!")
    return model

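# Illustrative usage sketch (not part of the original file): the checkpoint directory
# and WanModel constructor are assumptions; any MLX module with matching parameter
# names would work the same way.
# >>> model = WanModel()  # hypothetical constructor
# >>> model = load_wan_2_2_from_safetensors('Wan2.2-TI2V-5B', model, check_mismatch=True)
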
def convert_wan_2_2_safetensors_to_mlx(
    safetensors_path: str,
    output_path: str,
    float16: bool = False,
    model=None  # Optional: provide model instance to check parameter alignment
):
    """
    Convert WAN 2.2 PyTorch safetensors file to MLX weights file.

    Args:
        safetensors_path: Input safetensors file
        output_path: Output MLX weights file (.safetensors)
        float16: Whether to use float16 precision
        model: Optional MLX model instance to check parameter alignment
    """
    dtype = mx.float16 if float16 else mx.float32

    print("Converting WAN 2.2 safetensors to MLX format...")
    print(f"Input: {safetensors_path}")
    print(f"Output: {output_path}")
    print(f"Target dtype: {dtype}")

    # Load and convert weights
    weights = {}
    bfloat16_count = 0

    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        keys = list(f.keys())
        print(f"Processing {len(keys)} parameters...")

        for key in keys:
            tensor = f.get_tensor(key)

            # Handle BFloat16
            if tensor.dtype == torch.bfloat16:
                bfloat16_count += 1
                tensor = tensor.float()  # Convert to float32 first

            value = mx.array(tensor.numpy()).astype(dtype)

            # Apply mapping
            mapped = map_wan_2_2_weights(key, value)

            for new_key, new_value in mapped:
                weights[new_key] = new_value

    if bfloat16_count > 0:
        print(f"⚠️ Converted {bfloat16_count} BFloat16 tensors to {dtype}")

    # Check parameter alignment if model provided
    if model is not None:
        print("\nChecking parameter alignment with model...")
        model_only, weights_only = check_parameter_mismatch(model, weights)

        if model_only:
            print(f"\n⚠️ WARNING: {len(model_only)} parameters in model but NOT in converted weights:")
            for param in sorted(model_only)[:10]:
                print(f"  - {param}")
            if len(model_only) > 10:
                print(f"  ... and {len(model_only) - 10} more")

        if weights_only:
            print(f"\n⚠️ WARNING: {len(weights_only)} parameters in converted weights but NOT in model:")
            for param in sorted(weights_only)[:10]:
                print(f"  - {param}")
            if len(weights_only) > 10:
                print(f"  ... and {len(weights_only) - 10} more")

        if not model_only and not weights_only:
            print("\n✅ Perfect match: All parameters align between model and converted weights!")

    # Save as MLX format
    print(f"\nSaving {len(weights)} parameters to: {output_path}")
    mx.save_safetensors(output_path, weights)

    # Print a few example keys for verification
    print("\nExample converted keys:")
    for key in sorted(weights.keys())[:10]:
        print(f"  {key}: {weights[key].shape}")

    return weights

def convert_multiple_wan_2_2_safetensors_to_mlx(
    checkpoint_dir: str,
    float16: bool = False
):
    """Convert multiple WAN 2.2 PyTorch safetensors files to MLX format."""
    # Find all PyTorch model files
    pytorch_pattern = os.path.join(checkpoint_dir, "diffusion_pytorch_model-*.safetensors")
    pytorch_files = sorted(glob.glob(pytorch_pattern))

    if not pytorch_files:
        raise FileNotFoundError(f"No PyTorch model files found matching: {pytorch_pattern}")

    print(f"Converting {len(pytorch_files)} PyTorch files to MLX format...")

    for i, pytorch_file in enumerate(pytorch_files, 1):
        # Extract the suffix (e.g., "00001-of-00006")
        basename = os.path.basename(pytorch_file)
        suffix = basename.replace("diffusion_pytorch_model-", "").replace(".safetensors", "")

        # Create MLX filename
        mlx_file = os.path.join(checkpoint_dir, f"diffusion_mlx_model-{suffix}.safetensors")

        print(f"\nConverting {i}/{len(pytorch_files)}: {basename}")
        convert_wan_2_2_safetensors_to_mlx(pytorch_file, mlx_file, float16)

    print("\nAll files converted successfully!")

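# Illustrative usage sketch (not part of the original file): converts every
# diffusion_pytorch_model-*.safetensors shard in a checkpoint directory; the path
# below is made up.
# >>> convert_multiple_wan_2_2_safetensors_to_mlx('./Wan2.2-TI2V-5B', float16=True)
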
def debug_wan_2_2_weight_mapping(safetensors_path: str, float16: bool = False):
    """
    Debug function to see how WAN 2.2 weights are being mapped.
    """
    dtype = mx.float16 if float16 else mx.float32

    print("=== WAN 2.2 Weight Mapping Debug ===")

    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        # Check first 30 keys to see the mapping
        for i, key in enumerate(f.keys()):
            if i >= 30:
                break

            tensor = f.get_tensor(key)

            # Handle BFloat16
            original_dtype = tensor.dtype
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()

            value = mx.array(tensor.numpy()).astype(dtype)

            # Apply mapping
            mapped = map_wan_2_2_weights(key, value)

            new_key, new_value = mapped[0]
            if new_key == key:
                print(f"UNCHANGED: {key} [{tensor.shape}]")
            else:
                print(f"MAPPED: {key} -> {new_key} [{tensor.shape}]")