Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00
Implement Wan2.2
17
video/Wan2.2/wan/modules/__init__.py
Normal file
@@ -0,0 +1,17 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .model import WanModel, mlx_attention  # mlx_attention is re-exported via __all__ below
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_1 import Wan2_1_VAE

__all__ = [
    'Wan2_1_VAE',
    'Wan2_2_VAE',
    'WanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'mlx_attention',
]
660
video/Wan2.2/wan/modules/model.py
Normal file
@@ -0,0 +1,660 @@
# MLX implementation of model.py
import math
from typing import List, Tuple, Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np

__all__ = ['WanModel']


def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.astype(mx.float32)

    # calculation
    arange_vals = mx.arange(half).astype(mx.float32)
    div_term = mx.power(10000, -arange_vals / half)
    sinusoid = position[:, None] @ div_term[None, :]
    x = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
    return x
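
# A minimal usage sketch, assuming `position` is a 1-D mx.array of timesteps. The
# result has shape [len(position), dim], with the cosine terms in the first
# dim // 2 columns and the sine terms in the last dim // 2 columns, e.g.
#
#     emb = sinusoidal_embedding_1d(256, mx.array([0.0, 10.0, 500.0]))  # (3, 256)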


def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    positions = mx.arange(max_seq_len).astype(mx.float32)
    freqs = mx.arange(0, dim, 2).astype(mx.float32) / dim
    freqs = 1.0 / mx.power(theta, freqs)
    angles = positions[:, None] @ freqs[None, :]
    # Store as [max_seq_len, dim//2, 2] where last dimension is [real, imag]
    freqs_complex = mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1)
    return freqs_complex
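
# The returned table has shape [max_seq_len, dim // 2, 2]: freqs[p, j] holds the
# (cos, sin) pair for the angle p * theta ** (-2j / dim), i.e. the real/imaginary
# parts of the usual complex RoPE frequencies.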


def rope_apply(x, grid_sizes, freqs):
    n, c = x.shape[2], x.shape[3] // 2

    # split freqs based on dimension allocation
    split_sizes = [c - 2 * (c // 3), c // 3, c // 3]
    freqs_splits = []
    start = 0
    for size in split_sizes:
        freqs_splits.append(freqs[:, start:start + size, :])
        start += size

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # reshape x_i to complex representation
        x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)

        # precompute frequency multipliers for each dimension
        freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
        freqs_f = mx.tile(freqs_f, (1, h, w, 1, 1)).reshape(f * h * w, -1, 2)

        freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
        freqs_h = mx.tile(freqs_h, (f, 1, w, 1, 1)).reshape(f * h * w, -1, 2)

        freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
        freqs_w = mx.tile(freqs_w, (f, h, 1, 1, 1)).reshape(f * h * w, -1, 2)

        # Concatenate frequency components
        freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=1)
        freqs_i = freqs_i[:seq_len].reshape(seq_len, 1, c, 2)

        # apply rotary embedding (complex multiplication)
        x_real = x_i[..., 0]
        x_imag = x_i[..., 1]
        freqs_real = freqs_i[..., 0]
        freqs_imag = freqs_i[..., 1]

        out_real = x_real * freqs_real - x_imag * freqs_imag
        out_imag = x_real * freqs_imag + x_imag * freqs_real

        x_i = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, -1)

        # Handle remaining sequence
        if x.shape[1] > seq_len:
            x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)

        output.append(x_i)

    return mx.stack(output)
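
# rope_apply distributes each head's c = head_dim // 2 complex pairs across the
# three video axes -- (c - 2 * (c // 3)) pairs for frames and c // 3 each for
# height and width -- and rotates every token by the frequencies of its (f, h, w)
# grid position. For head_dim = 128 (c = 64) the split is 22 / 21 / 21 pairs.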


class WanRMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x):
        """
        Args:
            x(Array): Shape [B, L, C]
        """
        return self._norm(x) * self.weight

    def _norm(self, x):
        return x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)


class WanLayerNorm(nn.LayerNorm):

    def __init__(self, dim, eps=1e-6, affine=False):
        super().__init__(dims=dim, eps=eps, affine=affine)

    def __call__(self, x):
        """
        Args:
            x(Array): Shape [B, L, C]
        """
        return super().__call__(x)


def mlx_attention(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    q_lens: Optional[mx.array] = None,
    k_lens: Optional[mx.array] = None,
    dropout_p: float = 0.,
    softmax_scale: Optional[float] = None,
    q_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    deterministic: bool = False,
    dtype: Optional[type] = None,
) -> mx.array:
    # Get shapes
    b, lq, n, d = q.shape
    _, lk, _, _ = k.shape

    # Scale queries if needed
    if q_scale is not None:
        q = q * q_scale

    # Transpose to [b, n, l, d] layout for attention
    q = q.transpose(0, 2, 1, 3)  # [b, n, lq, d]
    k = k.transpose(0, 2, 1, 3)  # [b, n, lk, d]
    v = v.transpose(0, 2, 1, 3)  # [b, n, lk, d]

    # Compute attention scores
    scores = mx.matmul(q, k.transpose(0, 1, 3, 2))  # [b, n, lq, lk]

    # Apply softmax scale if provided
    if softmax_scale is not None:
        scores = scores * softmax_scale
    else:
        # Default scaling by sqrt(d)
        scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))

    # Create attention mask
    attn_mask = None

    # Apply window size masking if specified
    if window_size != (-1, -1):
        left_window, right_window = window_size
        window_mask = mx.zeros((lq, lk))
        for i in range(lq):
            start = max(0, i - left_window)
            end = min(lk, i + right_window + 1)
            window_mask[i, start:end] = 1
        attn_mask = window_mask

    # Apply causal masking if needed
    if causal:
        causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
        if attn_mask is None:
            attn_mask = causal_mask
        else:
            attn_mask = mx.logical_and(attn_mask, causal_mask)

    # Apply attention mask if present
    if attn_mask is not None:
        attn_mask = attn_mask.astype(scores.dtype)
        scores = scores * attn_mask + (1 - attn_mask) * -1e4

    # Apply length masks if provided
    if q_lens is not None or k_lens is not None:
        if q_lens is not None:
            mask = mx.arange(lq)[None, :] < q_lens[:, None]
            mask = mask.astype(scores.dtype)
            scores = scores * mask[:, None, :, None] + (1 - mask[:, None, :, None]) * -1e4
        if k_lens is not None:
            mask = mx.arange(lk)[None, :] < k_lens[:, None]
            mask = mask.astype(scores.dtype)
            scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4

    # Apply softmax (numerically stabilized)
    max_scores = mx.max(scores, axis=-1, keepdims=True)
    scores = scores - max_scores
    exp_scores = mx.exp(scores)
    sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
    attn = exp_scores / (sum_exp + 1e-6)

    # Apply dropout if needed
    if dropout_p > 0 and not deterministic:
        raise NotImplementedError("Dropout not implemented in MLX version")

    # Compute output
    out = mx.matmul(attn, v)  # [b, n, lq, d]
    out = out.transpose(0, 2, 1, 3)  # [b, lq, n, d]

    return out
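
# Masking above is additive in score space: disallowed positions are pushed toward
# -1e4 before the numerically stabilized softmax. When no window/length masks or
# dropout are needed, an equivalent fused sketch (assuming MLX's fast attention
# kernel) would be:
#
#     out = mx.fast.scaled_dot_product_attention(
#         q, k, v, scale=softmax_scale if softmax_scale is not None else d ** -0.5)
#     out = out.transpose(0, 2, 1, 3)
#
# with q, k, v already in [b, num_heads, L, head_dim] layout as in the code above.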
|
||||
class WanSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def __call__(self, x, seq_lens, grid_sizes, freqs):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L, C]
|
||||
seq_lens(Array): Shape [B]
|
||||
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
|
||||
"""
|
||||
b, s, n, d = x.shape[0], x.shape[1], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
q = self.norm_q(self.q(x)).reshape(b, s, n, d)
|
||||
k = self.norm_k(self.k(x)).reshape(b, s, n, d)
|
||||
v = self.v(x).reshape(b, s, n, d)
|
||||
|
||||
x = mlx_attention(
|
||||
q=rope_apply(q, grid_sizes, freqs),
|
||||
k=rope_apply(k, grid_sizes, freqs),
|
||||
v=v,
|
||||
k_lens=seq_lens,
|
||||
window_size=self.window_size)
|
||||
|
||||
# output
|
||||
x = x.reshape(b, s, -1)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanCrossAttention(WanSelfAttention):
|
||||
|
||||
def __call__(self, x, context, context_lens):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L1, C]
|
||||
context(Array): Shape [B, L2, C]
|
||||
context_lens(Array): Shape [B]
|
||||
"""
|
||||
b, n, d = x.shape[0], self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.norm_q(self.q(x)).reshape(b, -1, n, d)
|
||||
k = self.norm_k(self.k(context)).reshape(b, -1, n, d)
|
||||
v = self.v(context).reshape(b, -1, n, d)
|
||||
|
||||
# compute attention
|
||||
x = mlx_attention(q, k, v, k_lens=context_lens)
|
||||
|
||||
# output
|
||||
x = x.reshape(b, -1, self.dim)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
|
||||
eps)
|
||||
self.norm3 = WanLayerNorm(
|
||||
dim, eps,
|
||||
affine=True) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
|
||||
eps)
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(dim, ffn_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(ffn_dim, dim))
|
||||
|
||||
# modulation
|
||||
self.modulation = mx.random.normal((1, 6, dim)) / dim**0.5
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
context,
|
||||
context_lens,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L, C]
|
||||
e(Array): Shape [B, L1, 6, C]
|
||||
seq_lens(Array): Shape [B], length of each sequence in batch
|
||||
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
|
||||
"""
|
||||
e = mx.split(self.modulation + e, 6, axis=2)
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2),
|
||||
seq_lens, grid_sizes, freqs)
|
||||
x = x + y * mx.squeeze(e[2], axis=2)
|
||||
|
||||
# cross-attention & ffn function
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
||||
y = self.ffn(
|
||||
self.norm2(x) * (1 + mx.squeeze(e[4], axis=2)) + mx.squeeze(e[3], axis=2))
|
||||
x = x + y * mx.squeeze(e[5], axis=2)
|
||||
|
||||
return x
|
||||
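
# The six modulation tensors follow the DiT-style adaLN ordering used above:
# e[0]/e[1] shift and scale the self-attention input, e[2] gates its residual,
# e[3]/e[4] shift and scale the FFN input, and e[5] gates the FFN residual.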
|
||||
|
||||
class Head(nn.Module):
|
||||
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
out_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, out_dim)
|
||||
|
||||
# modulation
|
||||
self.modulation = mx.random.normal((1, 2, dim)) / dim**0.5
|
||||
|
||||
def __call__(self, x, e):
|
||||
"""
|
||||
Args:
|
||||
x(Array): Shape [B, L1, C]
|
||||
e(Array): Shape [B, L1, C]
|
||||
"""
|
||||
e = mx.split(self.modulation + mx.expand_dims(e, axis=2), 2, axis=2)
|
||||
x = self.head(
|
||||
self.norm(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2))
|
||||
return x
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6):
|
||||
"""
|
||||
Initialize the diffusion model backbone.
|
||||
|
||||
Args:
|
||||
model_type (`str`, *optional*, defaults to 't2v'):
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
||||
text_len (`int`, *optional*, defaults to 512):
|
||||
Fixed length for text embeddings
|
||||
in_dim (`int`, *optional*, defaults to 16):
|
||||
Input video channels (C_in)
|
||||
dim (`int`, *optional*, defaults to 2048):
|
||||
Hidden dimension of the transformer
|
||||
ffn_dim (`int`, *optional*, defaults to 8192):
|
||||
Intermediate dimension in feed-forward network
|
||||
freq_dim (`int`, *optional*, defaults to 256):
|
||||
Dimension for sinusoidal time embeddings
|
||||
text_dim (`int`, *optional*, defaults to 4096):
|
||||
Input dimension for text embeddings
|
||||
out_dim (`int`, *optional*, defaults to 16):
|
||||
Output video channels (C_out)
|
||||
num_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads
|
||||
num_layers (`int`, *optional*, defaults to 32):
|
||||
Number of transformer blocks
|
||||
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
||||
Window size for local attention (-1 indicates global attention)
|
||||
qk_norm (`bool`, *optional*, defaults to True):
|
||||
Enable query/key normalization
|
||||
cross_attn_norm (`bool`, *optional*, defaults to True):
|
||||
Enable cross-attention normalization
|
||||
eps (`float`, *optional*, defaults to 1e-6):
|
||||
Epsilon value for normalization layers
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
assert model_type in ['t2v', 'i2v', 'ti2v']
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.text_len = text_len
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.text_dim = text_dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, dim))
|
||||
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(freq_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim))
|
||||
self.time_projection = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6))
|
||||
|
||||
# blocks
|
||||
self.blocks = [
|
||||
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
|
||||
cross_attn_norm, eps) for _ in range(num_layers)
|
||||
]
|
||||
|
||||
# head
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
|
||||
# buffers
|
||||
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
||||
d = dim // num_heads
|
||||
self.freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6))
|
||||
], axis=1)
|
||||
|
||||
# initialize weights
|
||||
self.init_weights()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
y=None,
|
||||
):
|
||||
"""
|
||||
Forward pass through the diffusion model
|
||||
|
||||
Args:
|
||||
x (List[Array]):
|
||||
List of input video tensors, each with shape [C_in, F, H, W]
|
||||
t (Array):
|
||||
Diffusion timesteps tensor of shape [B]
|
||||
context (List[Array]):
|
||||
List of text embeddings each with shape [L, C]
|
||||
seq_len (`int`):
|
||||
Maximum sequence length for positional encoding
|
||||
y (List[Array], *optional*):
|
||||
Conditional video inputs for image-to-video mode, same shape as x
|
||||
|
||||
Returns:
|
||||
List[Array]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
if self.model_type == 'i2v':
|
||||
assert y is not None
|
||||
|
||||
if y is not None:
|
||||
x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]
|
||||
|
||||
# embeddings
|
||||
x = [self.patch_embedding(mx.expand_dims(mx.transpose(u, (1, 2, 3, 0)), axis=0)) for u in x]
|
||||
|
||||
grid_sizes = mx.stack(
|
||||
[mx.array(u.shape[1:4], dtype=mx.int32) for u in x])
|
||||
|
||||
|
||||
x = [u.reshape(u.shape[0], -1, u.shape[-1]) for u in x]
|
||||
|
||||
seq_lens = mx.array([u.shape[1] for u in x], dtype=mx.int32)
|
||||
assert seq_lens.max() <= seq_len
|
||||
|
||||
# Pad sequences
|
||||
x_padded = []
|
||||
for u in x:
|
||||
pad_len = seq_len - u.shape[1]
|
||||
if pad_len > 0:
|
||||
padding = mx.zeros((u.shape[0], pad_len, u.shape[2]))
|
||||
u = mx.concatenate([u, padding], axis=1)
|
||||
x_padded.append(u)
|
||||
x = mx.concatenate(x_padded, axis=0)
|
||||
|
||||
# time embeddings
|
||||
if t.ndim == 1:
|
||||
t = mx.broadcast_to(t[:, None], (t.shape[0], seq_len))
|
||||
bt = t.shape[0]
|
||||
t = t.flatten()
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).reshape(bt, seq_len, -1))
|
||||
e0 = self.time_projection(e).reshape(bt, seq_len, 6, self.dim)
|
||||
|
||||
# context
|
||||
context_lens = None
|
||||
context_padded = []
|
||||
for u in context:
|
||||
pad_len = self.text_len - u.shape[0]
|
||||
if pad_len > 0:
|
||||
padding = mx.zeros((pad_len, u.shape[1]))
|
||||
u = mx.concatenate([u, padding], axis=0)
|
||||
context_padded.append(u)
|
||||
context = self.text_embedding(mx.stack(context_padded))
|
||||
|
||||
# arguments
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=self.freqs,
|
||||
context=context,
|
||||
context_lens=context_lens)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
"""
|
||||
Reconstruct video tensors from patch embeddings.
|
||||
|
||||
Args:
|
||||
x (List[Array]):
|
||||
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
||||
grid_sizes (Array):
|
||||
Original spatial-temporal grid dimensions before patching,
|
||||
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
||||
|
||||
Returns:
|
||||
List[Array]:
|
||||
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
|
||||
c = self.out_dim
|
||||
out = []
|
||||
for i, v in enumerate(grid_sizes):
|
||||
v = v.tolist()
|
||||
seq_len = math.prod(v)
|
||||
u = x[i, :seq_len].reshape(*v, *self.patch_size, c)
|
||||
# Rearrange dimensions: (f, h, w, p, q, r, c) -> (c, f*p, h*q, w*r)
|
||||
u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))
|
||||
u = u.reshape(c, v[0] * self.patch_size[0],
|
||||
v[1] * self.patch_size[1],
|
||||
v[2] * self.patch_size[2])
|
||||
out.append(u)
|
||||
return out
|
||||
|
||||
def init_weights(self):
|
||||
"""
|
||||
Initialize model parameters using Xavier initialization.
|
||||
"""
|
||||
# Initialize patch embedding
|
||||
fan_in = self.in_dim * math.prod(self.patch_size)
|
||||
fan_out = self.dim
|
||||
std = math.sqrt(2.0 / (fan_in + fan_out))
|
||||
self.patch_embedding.weight = mx.random.uniform(
|
||||
low=-std, high=std, shape=self.patch_embedding.weight.shape)
|
||||
|
||||
# Initialize text embedding layers with normal distribution
|
||||
text_layers = list(self.text_embedding.layers)
|
||||
for i in [0, 2]: # First and third layers
|
||||
layer = text_layers[i]
|
||||
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
|
||||
if hasattr(layer, 'bias') and layer.bias is not None:
|
||||
layer.bias = mx.zeros(layer.bias.shape)
|
||||
|
||||
# Initialize time embedding layers
|
||||
time_layers = list(self.time_embedding.layers)
|
||||
for i in [0, 2]: # First and third layers
|
||||
layer = time_layers[i]
|
||||
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
|
||||
if hasattr(layer, 'bias') and layer.bias is not None:
|
||||
layer.bias = mx.zeros(layer.bias.shape)
|
||||
|
||||
# Initialize output head to zeros
|
||||
self.head.head.weight = mx.zeros(self.head.head.weight.shape)
|
||||
if hasattr(self.head.head, 'bias') and self.head.head.bias is not None:
|
||||
self.head.head.bias = mx.zeros(self.head.head.bias.shape)
|
||||
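
A minimal forward-pass sketch for the WanModel defined above, to make the expected shapes concrete. The tiny configuration (dim=64, four heads, two layers) and the import path rooted at video/Wan2.2 are illustrative assumptions, not values from the commit; real checkpoints are far larger.

import mlx.core as mx
from wan.modules.model import WanModel

# Tiny, arbitrary configuration chosen only so the example runs quickly.
model = WanModel(model_type='t2v', in_dim=16, dim=64, ffn_dim=256, freq_dim=64,
                 text_dim=128, out_dim=16, num_heads=4, num_layers=2)

x = [mx.random.normal((16, 4, 8, 8))]       # one latent video, [C_in, F, H, W]
context = [mx.random.normal((10, 128))]     # one text embedding, [L, text_dim]
t = mx.array([500.0])                       # diffusion timestep per sample

# With the default patch_size (1, 2, 2) this sample yields 4 * 4 * 4 = 64 tokens.
out = model(x, t, context, seq_len=64)
print(out[0].shape)                         # expected: (16, 4, 8, 8)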
616
video/Wan2.2/wan/modules/t5.py
Normal file
@@ -0,0 +1,616 @@
|
||||
# MLX implementation for t5.py
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
|
||||
__all__ = [
|
||||
'T5Model',
|
||||
'T5Encoder',
|
||||
'T5Decoder',
|
||||
'T5EncoderModel',
|
||||
]
|
||||
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == mx.float16:
|
||||
# Use same clamping as PyTorch for consistency
|
||||
clamp = 65504.0 # max value for float16
|
||||
return mx.clip(x, -clamp, clamp)
|
||||
return x
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __call__(self, x):
|
||||
return 0.5 * x * (1.0 + mx.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x):
|
||||
# Match PyTorch's approach: convert to float32 for stability
|
||||
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
|
||||
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
|
||||
x_norm = x_float * mx.rsqrt(variance + self.eps)
|
||||
# Convert back to original dtype
|
||||
if x.dtype == mx.float16:
|
||||
x_norm = x_norm.astype(mx.float16)
|
||||
return self.weight * x_norm
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
|
||||
assert dim_attn % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def __call__(self, x, context=None, mask=None, pos_bias=None):
|
||||
"""
|
||||
x: [B, L1, C].
|
||||
context: [B, L2, C] or None.
|
||||
mask: [B, L2] or [B, L1, L2] or None.
|
||||
"""
|
||||
# check inputs
|
||||
context = x if context is None else context
|
||||
b, l1, _ = x.shape
|
||||
_, l2, _ = context.shape
|
||||
n, c = self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).reshape(b, l1, n, c)
|
||||
k = self.k(context).reshape(b, l2, n, c)
|
||||
v = self.v(context).reshape(b, l2, n, c)
|
||||
|
||||
# transpose for attention: [B, N, L, C]
|
||||
q = mx.transpose(q, (0, 2, 1, 3))
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
# compute attention (T5 does not use scaling)
|
||||
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
|
||||
|
||||
# add position bias if provided
|
||||
if pos_bias is not None:
|
||||
attn = attn + pos_bias
|
||||
|
||||
# apply mask
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
# [B, L2] -> [B, 1, 1, L2]
|
||||
mask = mask[:, None, None, :]
|
||||
elif mask.ndim == 3:
|
||||
# [B, L1, L2] -> [B, 1, L1, L2]
|
||||
mask = mask[:, None, :, :]
|
||||
# Use very negative value that works well with float16
|
||||
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
|
||||
attn = mx.where(mask == 0, min_value, attn)
|
||||
|
||||
# softmax and apply attention
|
||||
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# apply attention to values
|
||||
x = mx.matmul(attn, v) # [B, N, L1, C]
|
||||
|
||||
# transpose back and reshape
|
||||
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
|
||||
x = x.reshape(b, l1, -1)
|
||||
|
||||
# output projection
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_ffn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
|
||||
# layers
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = GELU()
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def __call__(self, x):
|
||||
gate = self.gate_act(self.gate_proj(x))
|
||||
x = self.fc1(x) * gate
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True)
|
||||
|
||||
def __call__(self, x, mask=None, pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5CrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm3 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False)
|
||||
|
||||
def __call__(self,
|
||||
x,
|
||||
mask=None,
|
||||
encoder_states=None,
|
||||
encoder_mask=None,
|
||||
pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.cross_attn(
|
||||
self.norm2(x), context=encoder_states, mask=encoder_mask))
|
||||
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
|
||||
# layers
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def __call__(self, lq, lk):
|
||||
# Create relative position matrix
|
||||
positions_q = mx.arange(lq)[:, None]
|
||||
positions_k = mx.arange(lk)[None, :]
|
||||
rel_pos = positions_k - positions_q
|
||||
|
||||
# Apply bucketing
|
||||
rel_pos = self._relative_position_bucket(rel_pos)
|
||||
|
||||
# Get embeddings
|
||||
rel_pos_embeds = self.embedding(rel_pos)
|
||||
|
||||
# Reshape to [1, N, Lq, Lk]
|
||||
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
|
||||
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
|
||||
|
||||
return rel_pos_embeds
|
||||
|
||||
def _relative_position_bucket(self, rel_pos):
|
||||
# preprocess
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
|
||||
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
# embeddings for small and large positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
# For large positions, use log scale
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
|
||||
math.log(self.max_dist / max_exact) *
|
||||
(num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
|
||||
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
|
||||
|
||||
# Combine small and large position buckets
|
||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
|
||||
|
||||
return rel_buckets
|
||||
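
# Worked example (bidirectional=True, num_buckets=32, max_dist=128): half of the
# buckets (16) encode the sign, the first 8 of each half hold exact offsets 0-7,
# and larger offsets are binned logarithmically, so a relative offset of +40 lands
# in bucket 16 + 8 + int(log(40 / 8) / log(128 / 8) * 8) = 28.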
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = [
|
||||
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids, mask=None):
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5Decoder(nn.Module):
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = [
|
||||
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
||||
b, s = ids.shape
|
||||
|
||||
# causal mask
|
||||
if mask is None:
|
||||
mask = mx.tril(mx.ones((1, s, s)))
|
||||
elif mask.ndim == 2:
|
||||
# Expand mask properly
|
||||
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
|
||||
|
||||
# layers
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5Model(nn.Module):
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
encoder_layers,
|
||||
decoder_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.encoder_layers = encoder_layers
|
||||
self.decoder_layers = decoder_layers
|
||||
self.num_buckets = num_buckets
|
||||
|
||||
# layers
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
||||
num_heads, encoder_layers, num_buckets,
|
||||
shared_pos, dropout)
|
||||
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
||||
num_heads, decoder_layers, num_buckets,
|
||||
shared_pos, dropout)
|
||||
self.head = nn.Linear(dim, vocab_size, bias=False)
|
||||
|
||||
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
||||
x = self.encoder(encoder_ids, encoder_mask)
|
||||
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def init_mlx_weights(module, key):
|
||||
"""Initialize weights for T5 model components to match PyTorch initialization"""
|
||||
|
||||
    def normal(key, shape, std=1.0):
        # mx.random.normal takes the shape first; the PRNG key is a keyword argument
        return mx.random.normal(shape=shape, key=key) * std
|
||||
|
||||
if isinstance(module, T5LayerNorm):
|
||||
module.weight = mx.ones_like(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.weight = normal(key, module.weight.shape, std=1.0)
|
||||
elif isinstance(module, T5FeedForward):
|
||||
# Match PyTorch initialization
|
||||
key1, key2, key3 = mx.random.split(key, 3)
|
||||
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc1.weight = normal(key2, module.fc1.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc2.weight = normal(key3, module.fc2.weight.shape,
|
||||
std=module.dim_ffn**-0.5)
|
||||
elif isinstance(module, T5Attention):
|
||||
# Match PyTorch initialization
|
||||
        key1, key2, key3, key4 = mx.random.split(key, 4)
|
||||
module.q.weight = normal(key1, module.q.weight.shape,
|
||||
std=(module.dim * module.dim_attn)**-0.5)
|
||||
module.k.weight = normal(key2, module.k.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.v.weight = normal(key3, module.v.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.o.weight = normal(key4, module.o.weight.shape,
|
||||
std=(module.num_heads * module.dim_attn)**-0.5)
|
||||
elif isinstance(module, T5RelativeEmbedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.embedding.weight = normal(key, module.embedding.weight.shape,
|
||||
std=(2 * module.num_buckets * module.num_heads)**-0.5)
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Generic linear layer initialization
|
||||
key = mx.random.split(key, 1)[0]
|
||||
fan_in = module.weight.shape[1]
|
||||
bound = 1.0 / math.sqrt(fan_in)
|
||||
        module.weight = mx.random.uniform(low=-bound, high=bound, shape=module.weight.shape, key=key)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _t5(name,
|
||||
encoder_only=False,
|
||||
decoder_only=False,
|
||||
return_tokenizer=False,
|
||||
tokenizer_kwargs={},
|
||||
**kwargs):
|
||||
# sanity check
|
||||
assert not (encoder_only and decoder_only)
|
||||
|
||||
# params
|
||||
if encoder_only:
|
||||
model_cls = T5Encoder
|
||||
kwargs['vocab'] = kwargs.pop('vocab_size')
|
||||
kwargs['num_layers'] = kwargs.pop('encoder_layers')
|
||||
_ = kwargs.pop('decoder_layers')
|
||||
elif decoder_only:
|
||||
model_cls = T5Decoder
|
||||
kwargs['vocab'] = kwargs.pop('vocab_size')
|
||||
kwargs['num_layers'] = kwargs.pop('decoder_layers')
|
||||
_ = kwargs.pop('encoder_layers')
|
||||
else:
|
||||
model_cls = T5Model
|
||||
|
||||
# init model
|
||||
model = model_cls(**kwargs)
|
||||
|
||||
# Initialize weights properly
|
||||
key = mx.random.key(0)
|
||||
model = init_mlx_weights(model, key)
|
||||
|
||||
# init tokenizer
|
||||
if return_tokenizer:
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
|
||||
return model, tokenizer
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def umt5_xxl(**kwargs):
|
||||
cfg = dict(
|
||||
vocab_size=256384,
|
||||
dim=4096,
|
||||
dim_attn=4096,
|
||||
dim_ffn=10240,
|
||||
num_heads=64,
|
||||
encoder_layers=24,
|
||||
decoder_layers=24,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
return _t5('umt5-xxl', **cfg)
|
||||
|
||||
|
||||
class T5EncoderModel:
|
||||
def __init__(
|
||||
self,
|
||||
text_len,
|
||||
checkpoint_path=None,
|
||||
tokenizer_path=None,
|
||||
):
|
||||
self.text_len = text_len
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
# init model
|
||||
model = umt5_xxl(
|
||||
encoder_only=True,
|
||||
return_tokenizer=False)
|
||||
|
||||
if checkpoint_path:
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
# Load weights - assuming MLX format checkpoint
|
||||
weights = mx.load(checkpoint_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
self.model = model
|
||||
|
||||
# init tokenizer
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
self.tokenizer = HuggingfaceTokenizer(
|
||||
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
|
||||
seq_len=text_len,
|
||||
clean='whitespace')
|
||||
|
||||
def __call__(self, texts):
|
||||
# Handle single string input
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
# Tokenize texts
|
||||
tokenizer_output = self.tokenizer(
|
||||
texts, return_mask=True, add_special_tokens=True)
|
||||
|
||||
# Handle different tokenizer output formats
|
||||
if isinstance(tokenizer_output, tuple):
|
||||
ids, mask = tokenizer_output
|
||||
else:
|
||||
# Assuming dict output with 'input_ids' and 'attention_mask'
|
||||
ids = tokenizer_output['input_ids']
|
||||
mask = tokenizer_output['attention_mask']
|
||||
|
||||
# Convert to MLX arrays if not already
|
||||
if not isinstance(ids, mx.array):
|
||||
ids = mx.array(ids)
|
||||
if not isinstance(mask, mx.array):
|
||||
mask = mx.array(mask)
|
||||
|
||||
# Get sequence lengths
|
||||
seq_lens = mx.sum(mask > 0, axis=1)
|
||||
|
||||
# Run encoder
|
||||
context = self.model(ids, mask)
|
||||
|
||||
# Return variable length outputs
|
||||
# Convert seq_lens to Python list for indexing
|
||||
if seq_lens.ndim == 0: # Single value
|
||||
seq_lens_list = [seq_lens.item()]
|
||||
else:
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
|
||||
|
||||
|
||||
# Utility function to convert PyTorch checkpoint to MLX
|
||||
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
|
||||
"""Convert PyTorch checkpoint to MLX format"""
|
||||
import torch
|
||||
|
||||
# Load PyTorch checkpoint
|
||||
pytorch_state = torch.load(pytorch_path, map_location='cpu')
|
||||
|
||||
# Convert to numpy then to MLX
|
||||
mlx_state = {}
|
||||
for key, value in pytorch_state.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Handle the key mapping if needed
|
||||
mlx_key = key
|
||||
# Convert tensor to MLX array
|
||||
mlx_state[mlx_key] = mx.array(value.numpy())
|
||||
|
||||
# Save MLX checkpoint
|
||||
    # mx.save stores a single array; save_safetensors writes the whole weight dict
    mx.save_safetensors(mlx_path, mlx_state)
|
||||
|
||||
return mlx_state
|
||||
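
A hedged usage sketch for the T5EncoderModel wrapper above. The checkpoint path is a placeholder (weights must already be in an MLX-loadable format, e.g. produced by convert_pytorch_checkpoint), the tokenizer download requires network access, and instantiating the umt5-xxl encoder is memory-heavy.

from wan.modules.t5 import T5EncoderModel

encoder = T5EncoderModel(
    text_len=512,
    checkpoint_path='path/to/umt5-xxl-encoder.safetensors',  # placeholder path
)
embeddings = encoder(['a cat surfing a wave at sunset'])
print(embeddings[0].shape)  # (num_tokens, 4096), one array per prompt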
82
video/Wan2.2/wan/modules/tokenizers.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import html
|
||||
import string
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
__all__ = ['HuggingfaceTokenizer']
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def canonicalize(text, keep_punctuation_exact_string=None):
|
||||
text = text.replace('_', ' ')
|
||||
if keep_punctuation_exact_string:
|
||||
text = keep_punctuation_exact_string.join(
|
||||
part.translate(str.maketrans('', '', string.punctuation))
|
||||
for part in text.split(keep_punctuation_exact_string))
|
||||
else:
|
||||
text = text.translate(str.maketrans('', '', string.punctuation))
|
||||
text = text.lower()
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
class HuggingfaceTokenizer:
|
||||
|
||||
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
||||
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
|
||||
self.name = name
|
||||
self.seq_len = seq_len
|
||||
self.clean = clean
|
||||
|
||||
# init tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
||||
self.vocab_size = self.tokenizer.vocab_size
|
||||
|
||||
def __call__(self, sequence, **kwargs):
|
||||
return_mask = kwargs.pop('return_mask', False)
|
||||
|
||||
# arguments
|
||||
_kwargs = {'return_tensors': 'pt'}
|
||||
if self.seq_len is not None:
|
||||
_kwargs.update({
|
||||
'padding': 'max_length',
|
||||
'truncation': True,
|
||||
'max_length': self.seq_len
|
||||
})
|
||||
_kwargs.update(**kwargs)
|
||||
|
||||
# tokenization
|
||||
if isinstance(sequence, str):
|
||||
sequence = [sequence]
|
||||
if self.clean:
|
||||
sequence = [self._clean(u) for u in sequence]
|
||||
ids = self.tokenizer(sequence, **_kwargs)
|
||||
|
||||
# output
|
||||
if return_mask:
|
||||
return ids.input_ids, ids.attention_mask
|
||||
else:
|
||||
return ids.input_ids
|
||||
|
||||
def _clean(self, text):
|
||||
if self.clean == 'whitespace':
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
elif self.clean == 'lower':
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
elif self.clean == 'canonicalize':
|
||||
text = canonicalize(basic_clean(text))
|
||||
return text
|
||||
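
A short usage sketch for HuggingfaceTokenizer, mirroring how T5EncoderModel in t5.py calls it; it assumes the transformers library can fetch google/umt5-xxl.

from wan.modules.tokenizers import HuggingfaceTokenizer

tokenizer = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
ids, mask = tokenizer(['a cat surfing a wave'],
                      return_mask=True, add_special_tokens=True)
print(ids.shape, mask.shape)  # both (1, 512); returned as PyTorch tensors ('return_tensors': 'pt')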
703
video/Wan2.2/wan/modules/vae2_1.py
Normal file
@@ -0,0 +1,703 @@
|
||||
# MLX implementation of vae2_1.py
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
__all__ = [
|
||||
'Wan2_1_VAE',
|
||||
]
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
debug_line = 0
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3d convolution for MLX.
|
||||
Expects input in BTHWC format (batch, time, height, width, channels).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Padding order: (W, W, H, H, T, 0)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def __call__(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
|
||||
padding[4] -= cache_x.shape[1]
|
||||
|
||||
# Pad in BTHWC format
|
||||
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
|
||||
(padding[0], padding[1]), (0, 0)]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
result = super().__call__(x)
|
||||
return result
|
||||
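
# All temporal padding is applied on the left (2 * padding[0] frames, reduced by
# however many past frames `cache_x` already supplies), so each output frame only
# depends on current and earlier frames -- the "causal" part of the convolution.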
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=False, images=True, bias=False):
|
||||
super().__init__()
|
||||
self.channel_first = channel_first
|
||||
self.images = images
|
||||
self.scale = dim**0.5
|
||||
|
||||
# Just keep as 1D - let broadcasting do its magic
|
||||
self.gamma = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,)) if bias else 0.
|
||||
|
||||
def __call__(self, x):
|
||||
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
|
||||
x = x / norm
|
||||
return x * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
Upsampling layer that matches PyTorch's behavior.
|
||||
"""
|
||||
def __init__(self, scale_factor, mode='nearest-exact'):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode # mode is now unused, but kept for signature consistency
|
||||
|
||||
def __call__(self, x):
|
||||
scale_h, scale_w = self.scale_factor
|
||||
|
||||
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
|
||||
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
|
||||
|
||||
return out
|
||||
|
||||
class AsymmetricPad(nn.Module):
|
||||
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
|
||||
def __init__(self, pad_width: tuple):
|
||||
super().__init__()
|
||||
self.pad_width = pad_width
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.pad(x, self.pad_width)
|
||||
|
||||
# Update your Resample class to use 'nearest-exact'
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
# --- CORRECTED PADDING LOGIC ---
|
||||
elif mode == 'downsample2d':
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
elif mode == 'downsample3d':
|
||||
# The spatial downsampling part uses the same logic
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# The __call__ method logic remains unchanged from your original code
|
||||
b, t, h, w, c = x.shape
|
||||
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
mx.zeros_like(cache_x), cache_x
|
||||
], axis=1)
|
||||
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, t, h, w, 2, c)
|
||||
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
|
||||
x = x.reshape(b, t * 2, h, w, c)
|
||||
|
||||
t = x.shape[1]
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
x = self.resample(x)
|
||||
|
||||
_, h_new, w_new, c_new = x.shape
|
||||
x = x.reshape(b, t, h_new, w_new, c_new)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -1:, :, :, :]
|
||||
x = self.time_conv(
|
||||
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
|
||||
for i, layer in enumerate(self.residual.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
self.proj.weight = mx.zeros_like(self.proj.weight)
|
||||
|
||||
def __call__(self, x):
|
||||
# x is in BTHWC format
|
||||
identity = x
|
||||
b, t, h, w, c = x.shape
|
||||
x = x.reshape(b * t, h, w, c) # Combine batch and time
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
|
||||
qkv = qkv.reshape(b * t, h * w, 3 * c)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Reshape for attention
|
||||
q = q.reshape(b * t, h * w, c)
|
||||
k = k.reshape(b * t, h * w, c)
|
||||
v = v.reshape(b * t, h * w, c)
|
||||
|
||||
# Scaled dot product attention
|
||||
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
|
||||
scores = (q @ k.transpose(0, 2, 1)) * scale
|
||||
weights = mx.softmax(scores, axis=-1)
|
||||
x = weights @ v
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = x.reshape(b, t, h, w, c)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[-1], dims[-1], dropout),
            AttentionBlock(dims[-1]),
            ResidualBlock(dims[-1], dims[-1], dropout)
        )

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(dims[-1], images=False),
            nn.SiLU(),
            CausalConv3d(dims[-1], z_dim, 3, padding=1)
        )

    def __call__(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, -CACHE_T:, :, :, :]
            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                cache_x = mx.concatenate([
                    feat_cache[idx][:, -1:, :, :, :], cache_x
                ], axis=1)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for i, layer in enumerate(self.downsamples.layers):
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle.layers:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for i, layer in enumerate(self.head.layers):
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, -CACHE_T:, :, :, :]
                if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                    cache_x = mx.concatenate([
                        feat_cache[idx][:, -1:, :, :, :], cache_x
                    ], axis=1)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


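# Decoder3d mirrors Encoder3d: z_dim latent channels back to RGB, with the
# stage widths reversed and num_res_blocks + 1 residual blocks per stage.
# The in_dim // 2 adjustment at stages 1-3 assumes the 'upsample2d' /
# 'upsample3d' Resample modes halve the channel count (as in the reference
# Wan VAE), so each stage that follows an upsample starts from half the width.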
class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout)
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(dims[-1], images=False),
            nn.SiLU(),
            CausalConv3d(dims[-1], 3, 3, padding=1)
        )

    def __call__(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, -CACHE_T:, :, :, :]
            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                cache_x = mx.concatenate([
                    feat_cache[idx][:, -1:, :, :, :], cache_x
                ], axis=1)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle.layers:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples.layers:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for i, layer in enumerate(self.head.layers):
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, -CACHE_T:, :, :, :]
                if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                    cache_x = mx.concatenate([
                        feat_cache[idx][:, -1:, :, :, :], cache_x
                    ], axis=1)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


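# Counts the CausalConv3d layers so WanVAE_.clear_cache() can size the
# per-layer feature caches (one cache slot per causal conv, indexed by feat_idx).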
def count_conv3d(model):
    count = 0
    for name, module in model.named_modules():
        if isinstance(module, CausalConv3d):
            count += 1
    return count


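# WanVAE_ ties the pieces together. encode() walks the input causally in
# chunks of 1, 4, 4, ... frames, so inputs are expected to have T = 1 + 4k
# frames and, with the default temporal downsampling, produce latents with
# 1 + (T - 1) // 4 time steps; decode() reverses this one latent frame at a
# time. clear_cache() resets the per-conv feature caches between calls.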
class WanVAE_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def encode(self, x, scale):
        # x is in BTHWC format
        self.clear_cache()
        ## cache
        t = x.shape[1]
        iter_ = 1 + (t - 1) // 4
        ## Split encode input x by time into 1, 4, 4, 4....
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :1, :, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = mx.concatenate([out, out_], axis=1)

        z = self.conv1(out)
        mu, log_var = mx.split(z, 2, axis=-1)  # Split along channel dimension

        if isinstance(scale[0], mx.array):
            # Reshape scale for broadcasting in BTHWC format
            scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
            scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
            mu = (mu - scale_mean) * scale_std
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()

        return mu, log_var

    def decode(self, z, scale):
        # z is in BTHWC format
        self.clear_cache()
        if isinstance(scale[0], mx.array):
            scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
            scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
            z = z / scale_std + scale_mean
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[1]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, i:i + 1, :, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, i:i + 1, :, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = mx.concatenate([out, out_], axis=1)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = mx.exp(0.5 * log_var)
        eps = mx.random.normal(std.shape)
        return eps * std + mu

    def __call__(self, x):
        mu, log_var = self.encode(x, self.scale)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z, self.scale)
        return x_recon, mu, log_var

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs, self.scale)
        if deterministic:
            return mu
        std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
        return mu + std * mx.random.normal(std.shape)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num


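# Builds the Wan 2.1 video VAE with its default configuration (dim=96).
# Note that mx.load reads MLX-native formats (.npz, .safetensors, .gguf),
# so a PyTorch .pth checkpoint would need to be converted before being
# passed as pretrained_path.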
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    model = WanVAE_(**cfg)

    # load checkpoint
    if pretrained_path:
        logging.info(f'loading {pretrained_path}')
        weights = mx.load(pretrained_path)
        model.update(tree_unflatten(list(weights.items())))

    return model


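# Thin wrapper exposing a list-of-[C, T, H, W] interface around WanVAE_.
# Latents are normalized channel-wise on the way out of encode() as
# (mu - mean) / std (scale stores [mean, 1 / std]) and denormalized again
# in decode().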
class Wan2_1_VAE:

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=mx.float32):
        self.dtype = dtype

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = mx.array(mean, dtype=dtype)
        self.std = mx.array(std, dtype=dtype)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        )

    def encode(self, videos):
        """
        videos: A list of videos, each with shape [C, T, H, W].
        Returns: A list of encoded latents, each in [C, T, H, W] format.
        """
        encoded = []
        for video in videos:
            # Convert CTHW -> BTHWC
            x = mx.expand_dims(video, axis=0)  # Add batch dimension -> BCTHW
            x = x.transpose(0, 2, 3, 4, 1)     # BCTHW -> BTHWC

            # Encode
            z = self.model.encode(x, self.scale)[0]  # Keep mu only

            # Convert back BTHWC -> CTHW and remove batch dimension
            z = z.transpose(0, 4, 1, 2, 3)  # BTHWC -> BCTHW
            z = z.squeeze(0)                # Remove batch dimension -> CTHW

            encoded.append(z.astype(mx.float32))

        return encoded

    def decode(self, zs):
        """
        zs: A list of latent codes, each with shape [C, T, H, W].
        Returns: A list of decoded videos, each in [C, T, H, W] format.
        """
        decoded = []
        for z in zs:
            # Convert CTHW -> BTHWC
            x = mx.expand_dims(z, axis=0)   # Add batch dimension -> BCTHW
            x = x.transpose(0, 2, 3, 4, 1)  # BCTHW -> BTHWC

            # Decode
            x = self.model.decode(x, self.scale)

            # Convert back BTHWC -> CTHW and remove batch dimension
            x = x.transpose(0, 4, 1, 2, 3)  # BTHWC -> BCTHW
            x = x.squeeze(0)                # Remove batch dimension -> CTHW

            # Clamp values
            x = mx.clip(x, -1, 1)

            decoded.append(x.astype(mx.float32))

        return decoded


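# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original API surface; the
# checkpoint path below is hypothetical and must point to an MLX-format
# weight file):
#
#   vae = Wan2_1_VAE(z_dim=16, vae_pth='path/to/wan2.1_vae.safetensors')
#   video = mx.random.normal((3, 17, 256, 256))   # [C, T, H, W], T = 1 + 4k
#   latent = vae.encode([video])[0]               # expected: [16, 5, 32, 32]
#   recon = vae.decode([latent])[0]               # expected: [3, 17, 256, 256]
# ---------------------------------------------------------------------------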