Implement Wan2.2

This commit is contained in:
N
2025-07-31 02:30:20 -07:00
parent 4b2a0df237
commit 3b25af07d3
30 changed files with 6217 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_1 import Wan2_1_VAE
__all__ = [
'Wan2_1_VAE',
'Wan2_2_VAE',
'WanModel',
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
'HuggingfaceTokenizer',
'mlx_attention',
]

View File

@@ -0,0 +1,660 @@
# MLX implementation of model.py
import math
from typing import List, Tuple, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
__all__ = ['WanModel']
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.astype(mx.float32)
# calculation
arange_vals = mx.arange(half).astype(mx.float32)
div_term = mx.power(10000, -arange_vals / half)
sinusoid = position[:, None] @ div_term[None, :]
x = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
return x
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
positions = mx.arange(max_seq_len).astype(mx.float32)
freqs = mx.arange(0, dim, 2).astype(mx.float32) / dim
freqs = 1.0 / mx.power(theta, freqs)
angles = positions[:, None] @ freqs[None, :]
# Store as [max_seq_len, dim//2, 2] where last dimension is [real, imag]
freqs_complex = mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1)
return freqs_complex
def rope_apply(x, grid_sizes, freqs):
n, c = x.shape[2], x.shape[3] // 2
# split freqs based on dimension allocation
split_sizes = [c - 2 * (c // 3), c // 3, c // 3]
freqs_splits = []
start = 0
for size in split_sizes:
freqs_splits.append(freqs[:, start:start+size, :])
start += size
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# reshape x_i to complex representation
x_i = x[i, :seq_len].reshape(seq_len, n, c, 2)
# precompute frequency multipliers for each dimension
freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
freqs_f = mx.tile(freqs_f, (1, h, w, 1, 1)).reshape(f * h * w, -1, 2)
freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
freqs_h = mx.tile(freqs_h, (f, 1, w, 1, 1)).reshape(f * h * w, -1, 2)
freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
freqs_w = mx.tile(freqs_w, (f, h, 1, 1, 1)).reshape(f * h * w, -1, 2)
# Concatenate frequency components
freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=1)
freqs_i = freqs_i[:seq_len].reshape(seq_len, 1, c, 2)
# apply rotary embedding (complex multiplication)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
freqs_real = freqs_i[..., 0]
freqs_imag = freqs_i[..., 1]
out_real = x_real * freqs_real - x_imag * freqs_imag
out_imag = x_real * freqs_imag + x_imag * freqs_real
x_i = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, -1)
# Handle remaining sequence
if x.shape[1] > seq_len:
x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)
output.append(x_i)
return mx.stack(output)
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
"""
Args:
x(Array): Shape [B, L, C]
"""
return self._norm(x) * self.weight
def _norm(self, x):
return x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, affine=False):
super().__init__(dims=dim, eps=eps, affine=affine)
def __call__(self, x):
"""
Args:
x(Array): Shape [B, L, C]
"""
return super().__call__(x)
def mlx_attention(
q: mx.array,
k: mx.array,
v: mx.array,
q_lens: Optional[mx.array] = None,
k_lens: Optional[mx.array] = None,
dropout_p: float = 0.,
softmax_scale: Optional[float] = None,
q_scale: Optional[float] = None,
causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
dtype: Optional[type] = None,
) -> mx.array:
# Get shapes
b, lq, n, d = q.shape
_, lk, _, _ = k.shape
# Scale queries if needed
if q_scale is not None:
q = q * q_scale
# Compute attention scores
q = q.transpose(0, 2, 1, 3) # [b, n, lq, d]
k = k.transpose(0, 2, 1, 3) # [b, n, lk, d]
v = v.transpose(0, 2, 1, 3) # [b, n, lk, d]
# Compute attention scores
scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) # [b, n, lq, lk]
# Apply softmax scale if provided
if softmax_scale is not None:
scores = scores * softmax_scale
else:
# Default scaling by sqrt(d)
scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))
# Create attention mask
attn_mask = None
# Apply window size masking if specified
if window_size != (-1, -1):
left_window, right_window = window_size
window_mask = mx.zeros((lq, lk))
for i in range(lq):
start = max(0, i - left_window)
end = min(lk, i + right_window + 1)
window_mask[i, start:end] = 1
attn_mask = window_mask
# Apply causal masking if needed
if causal:
causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
if attn_mask is None:
attn_mask = causal_mask
else:
attn_mask = mx.logical_and(attn_mask, causal_mask)
# Apply attention mask if present
if attn_mask is not None:
attn_mask = attn_mask.astype(scores.dtype)
scores = scores * attn_mask + (1 - attn_mask) * -1e4
# Apply attention mask if lengths are provided
if q_lens is not None or k_lens is not None:
if q_lens is not None:
mask = mx.arange(lq)[None, :] < q_lens[:, None]
mask = mask.astype(scores.dtype)
scores = scores * mask[:, None, :, None] + (1 - mask[:, None, :, None]) * -1e4
if k_lens is not None:
mask = mx.arange(lk)[None, :] < k_lens[:, None]
mask = mask.astype(scores.dtype)
scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4
# Apply softmax
max_scores = mx.max(scores, axis=-1, keepdims=True)
scores = scores - max_scores
exp_scores = mx.exp(scores)
sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
attn = exp_scores / (sum_exp + 1e-6)
# Apply dropout if needed
if dropout_p > 0 and not deterministic:
raise NotImplementedError("Dropout not implemented in MLX version")
# Compute output
out = mx.matmul(attn, v) # [b, n, lq, d]
out = out.transpose(0, 2, 1, 3) # [b, lq, n, d]
return out
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def __call__(self, x, seq_lens, grid_sizes, freqs):
"""
Args:
x(Array): Shape [B, L, C]
seq_lens(Array): Shape [B]
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
"""
b, s, n, d = x.shape[0], x.shape[1], self.num_heads, self.head_dim
# query, key, value function
q = self.norm_q(self.q(x)).reshape(b, s, n, d)
k = self.norm_k(self.k(x)).reshape(b, s, n, d)
v = self.v(x).reshape(b, s, n, d)
x = mlx_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.reshape(b, s, -1)
x = self.o(x)
return x
class WanCrossAttention(WanSelfAttention):
def __call__(self, x, context, context_lens):
"""
Args:
x(Array): Shape [B, L1, C]
context(Array): Shape [B, L2, C]
context_lens(Array): Shape [B]
"""
b, n, d = x.shape[0], self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).reshape(b, -1, n, d)
k = self.norm_k(self.k(context)).reshape(b, -1, n, d)
v = self.v(context).reshape(b, -1, n, d)
# compute attention
x = mlx_attention(q, k, v, k_lens=context_lens)
# output
x = x.reshape(b, -1, self.dim)
x = self.o(x)
return x
class WanAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps)
self.norm3 = WanLayerNorm(
dim, eps,
affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(),
nn.Linear(ffn_dim, dim))
# modulation
self.modulation = mx.random.normal((1, 6, dim)) / dim**0.5
def __call__(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
"""
Args:
x(Array): Shape [B, L, C]
e(Array): Shape [B, L1, 6, C]
seq_lens(Array): Shape [B], length of each sequence in batch
grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
"""
e = mx.split(self.modulation + e, 6, axis=2)
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2),
seq_lens, grid_sizes, freqs)
x = x + y * mx.squeeze(e[2], axis=2)
# cross-attention & ffn function
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(
self.norm2(x) * (1 + mx.squeeze(e[4], axis=2)) + mx.squeeze(e[3], axis=2))
x = x + y * mx.squeeze(e[5], axis=2)
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = mx.random.normal((1, 2, dim)) / dim**0.5
def __call__(self, x, e):
"""
Args:
x(Array): Shape [B, L1, C]
e(Array): Shape [B, L1, C]
"""
e = mx.split(self.modulation + mx.expand_dims(e, axis=2), 2, axis=2)
x = self.head(
self.norm(x) * (1 + mx.squeeze(e[1], axis=2)) + mx.squeeze(e[0], axis=2))
return x
class WanModel(nn.Module):
"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6):
"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
cross_attn_norm (`bool`, *optional*, defaults to False):
Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
assert model_type in ['t2v', 'i2v', 'ti2v']
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim))
self.time_projection = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, dim * 6))
# blocks
self.blocks = [
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps) for _ in range(num_layers)
]
# head
self.head = Head(dim, out_dim, patch_size, eps)
# buffers
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], axis=1)
# initialize weights
self.init_weights()
def __call__(
self,
x,
t,
context,
seq_len,
y=None,
):
"""
Forward pass through the diffusion model
Args:
x (List[Array]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Array):
Diffusion timesteps tensor of shape [B]
context (List[Array]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
y (List[Array], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Array]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == 'i2v':
assert y is not None
if y is not None:
x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(mx.expand_dims(mx.transpose(u, (1, 2, 3, 0)), axis=0)) for u in x]
grid_sizes = mx.stack(
[mx.array(u.shape[1:4], dtype=mx.int32) for u in x])
x = [u.reshape(u.shape[0], -1, u.shape[-1]) for u in x]
seq_lens = mx.array([u.shape[1] for u in x], dtype=mx.int32)
assert seq_lens.max() <= seq_len
# Pad sequences
x_padded = []
for u in x:
pad_len = seq_len - u.shape[1]
if pad_len > 0:
padding = mx.zeros((u.shape[0], pad_len, u.shape[2]))
u = mx.concatenate([u, padding], axis=1)
x_padded.append(u)
x = mx.concatenate(x_padded, axis=0)
# time embeddings
if t.ndim == 1:
t = mx.broadcast_to(t[:, None], (t.shape[0], seq_len))
bt = t.shape[0]
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).reshape(bt, seq_len, -1))
e0 = self.time_projection(e).reshape(bt, seq_len, 6, self.dim)
# context
context_lens = None
context_padded = []
for u in context:
pad_len = self.text_len - u.shape[0]
if pad_len > 0:
padding = mx.zeros((pad_len, u.shape[1]))
u = mx.concatenate([u, padding], axis=0)
context_padded.append(u)
context = self.text_embedding(mx.stack(context_padded))
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def unpatchify(self, x, grid_sizes):
"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Array]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Array):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Array]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for i, v in enumerate(grid_sizes):
v = v.tolist()
seq_len = math.prod(v)
u = x[i, :seq_len].reshape(*v, *self.patch_size, c)
# Rearrange dimensions: (f, h, w, p, q, r, c) -> (c, f*p, h*q, w*r)
u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))
u = u.reshape(c, v[0] * self.patch_size[0],
v[1] * self.patch_size[1],
v[2] * self.patch_size[2])
out.append(u)
return out
def init_weights(self):
"""
Initialize model parameters using Xavier initialization.
"""
# Initialize patch embedding
fan_in = self.in_dim * math.prod(self.patch_size)
fan_out = self.dim
std = math.sqrt(2.0 / (fan_in + fan_out))
self.patch_embedding.weight = mx.random.uniform(
low=-std, high=std, shape=self.patch_embedding.weight.shape)
# Initialize text embedding layers with normal distribution
text_layers = list(self.text_embedding.layers)
for i in [0, 2]: # First and third layers
layer = text_layers[i]
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
if hasattr(layer, 'bias') and layer.bias is not None:
layer.bias = mx.zeros(layer.bias.shape)
# Initialize time embedding layers
time_layers = list(self.time_embedding.layers)
for i in [0, 2]: # First and third layers
layer = time_layers[i]
layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
if hasattr(layer, 'bias') and layer.bias is not None:
layer.bias = mx.zeros(layer.bias.shape)
# Initialize output head to zeros
self.head.head.weight = mx.zeros(self.head.head.weight.shape)
if hasattr(self.head.head, 'bias') and self.head.head.bias is not None:
self.head.head.bias = mx.zeros(self.head.head.bias.shape)

View File

@@ -0,0 +1,616 @@
# MLX implementation for t5.py
import logging
import math
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from .tokenizers import HuggingfaceTokenizer
__all__ = [
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
]
def fp16_clamp(x):
if x.dtype == mx.float16:
# Use same clamping as PyTorch for consistency
clamp = 65504.0 # max value for float16
return mx.clip(x, -clamp, clamp)
return x
class GELU(nn.Module):
def __call__(self, x):
return 0.5 * x * (1.0 + mx.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
# Match PyTorch's approach: convert to float32 for stability
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
x_norm = x_float * mx.rsqrt(variance + self.eps)
# Convert back to original dtype
if x.dtype == mx.float16:
x_norm = x_norm.astype(mx.float16)
return self.weight * x_norm
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
assert dim_attn % num_heads == 0
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, l1, _ = x.shape
_, l2, _ = context.shape
n, c = self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, l1, n, c)
k = self.k(context).reshape(b, l2, n, c)
v = self.v(context).reshape(b, l2, n, c)
# transpose for attention: [B, N, L, C]
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# compute attention (T5 does not use scaling)
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
# add position bias if provided
if pos_bias is not None:
attn = attn + pos_bias
# apply mask
if mask is not None:
if mask.ndim == 2:
# [B, L2] -> [B, 1, 1, L2]
mask = mask[:, None, None, :]
elif mask.ndim == 3:
# [B, L1, L2] -> [B, 1, L1, L2]
mask = mask[:, None, :, :]
# Use very negative value that works well with float16
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
attn = mx.where(mask == 0, min_value, attn)
# softmax and apply attention
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
attn = self.dropout(attn)
# apply attention to values
x = mx.matmul(attn, v) # [B, N, L1, C]
# transpose back and reshape
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
x = x.reshape(b, l1, -1)
# output projection
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.0):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = GELU()
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x):
gate = self.gate_act(self.gate_proj(x))
x = self.fc1(x) * gate
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def __call__(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def __call__(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def __call__(self, lq, lk):
# Create relative position matrix
positions_q = mx.arange(lq)[:, None]
positions_k = mx.arange(lk)[None, :]
rel_pos = positions_k - positions_q
# Apply bucketing
rel_pos = self._relative_position_bucket(rel_pos)
# Get embeddings
rel_pos_embeds = self.embedding(rel_pos)
# Reshape to [1, N, Lq, Lk]
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
return rel_pos_embeds
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
# For large positions, use log scale
rel_pos_large = max_exact + (
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
# Combine small and large position buckets
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.shape
# causal mask
if mask is None:
mask = mx.tril(mx.ones((1, s, s)))
elif mask.ndim == 2:
# Expand mask properly
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, encoder_layers, num_buckets,
shared_pos, dropout)
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, decoder_layers, num_buckets,
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def init_mlx_weights(module, key):
"""Initialize weights for T5 model components to match PyTorch initialization"""
def normal(key, shape, std=1.0):
return mx.random.normal(key, shape) * std
if isinstance(module, T5LayerNorm):
module.weight = mx.ones_like(module.weight)
elif isinstance(module, nn.Embedding):
key = mx.random.split(key, 1)[0]
module.weight = normal(key, module.weight.shape, std=1.0)
elif isinstance(module, T5FeedForward):
# Match PyTorch initialization
key1, key2, key3 = mx.random.split(key, 3)
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
std=module.dim**-0.5)
module.fc1.weight = normal(key2, module.fc1.weight.shape,
std=module.dim**-0.5)
module.fc2.weight = normal(key3, module.fc2.weight.shape,
std=module.dim_ffn**-0.5)
elif isinstance(module, T5Attention):
# Match PyTorch initialization
key1, key2, key3, key4 = random.split(key, 4)
module.q.weight = normal(key1, module.q.weight.shape,
std=(module.dim * module.dim_attn)**-0.5)
module.k.weight = normal(key2, module.k.weight.shape,
std=module.dim**-0.5)
module.v.weight = normal(key3, module.v.weight.shape,
std=module.dim**-0.5)
module.o.weight = normal(key4, module.o.weight.shape,
std=(module.num_heads * module.dim_attn)**-0.5)
elif isinstance(module, T5RelativeEmbedding):
key = mx.random.split(key, 1)[0]
module.embedding.weight = normal(key, module.embedding.weight.shape,
std=(2 * module.num_buckets * module.num_heads)**-0.5)
elif isinstance(module, nn.Linear):
# Generic linear layer initialization
key = mx.random.split(key, 1)[0]
fan_in = module.weight.shape[1]
bound = 1.0 / math.sqrt(fan_in)
module.weight = mx.random.uniform(key, module.weight.shape, -bound, bound)
return module
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('encoder_layers')
_ = kwargs.pop('decoder_layers')
elif decoder_only:
model_cls = T5Decoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('decoder_layers')
_ = kwargs.pop('encoder_layers')
else:
model_cls = T5Model
# init model
model = model_cls(**kwargs)
# Initialize weights properly
key = mx.random.key(0)
model = init_mlx_weights(model, key)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.0)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
checkpoint_path=None,
tokenizer_path=None,
):
self.text_len = text_len
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False)
if checkpoint_path:
logging.info(f'loading {checkpoint_path}')
# Load weights - assuming MLX format checkpoint
weights = mx.load(checkpoint_path)
model.update(tree_unflatten(list(weights.items())))
self.model = model
# init tokenizer
from .tokenizers import HuggingfaceTokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
seq_len=text_len,
clean='whitespace')
def __call__(self, texts):
# Handle single string input
if isinstance(texts, str):
texts = [texts]
# Tokenize texts
tokenizer_output = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
# Handle different tokenizer output formats
if isinstance(tokenizer_output, tuple):
ids, mask = tokenizer_output
else:
# Assuming dict output with 'input_ids' and 'attention_mask'
ids = tokenizer_output['input_ids']
mask = tokenizer_output['attention_mask']
# Convert to MLX arrays if not already
if not isinstance(ids, mx.array):
ids = mx.array(ids)
if not isinstance(mask, mx.array):
mask = mx.array(mask)
# Get sequence lengths
seq_lens = mx.sum(mask > 0, axis=1)
# Run encoder
context = self.model(ids, mask)
# Return variable length outputs
# Convert seq_lens to Python list for indexing
if seq_lens.ndim == 0: # Single value
seq_lens_list = [seq_lens.item()]
else:
seq_lens_list = seq_lens.tolist()
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
# Utility function to convert PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
"""Convert PyTorch checkpoint to MLX format"""
import torch
# Load PyTorch checkpoint
pytorch_state = torch.load(pytorch_path, map_location='cpu')
# Convert to numpy then to MLX
mlx_state = {}
for key, value in pytorch_state.items():
if isinstance(value, torch.Tensor):
# Handle the key mapping if needed
mlx_key = key
# Convert tensor to MLX array
mlx_state[mlx_key] = mx.array(value.numpy())
# Save MLX checkpoint
mx.save(mlx_path, mlx_state)
return mlx_state

View File

@@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ['HuggingfaceTokenizer']
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace('_', ' ')
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans('', '', string.punctuation))
for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans('', '', string.punctuation))
text = text.lower()
text = re.sub(r'\s+', ' ', text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop('return_mask', False)
# arguments
_kwargs = {'return_tensors': 'pt'}
if self.seq_len is not None:
_kwargs.update({
'padding': 'max_length',
'truncation': True,
'max_length': self.seq_len
})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text

View File

@@ -0,0 +1,703 @@
# MLX implementation of vae2_1.py
import logging
from typing import Optional, List, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
__all__ = [
'Wan2_1_VAE',
]
CACHE_T = 2
debug_line = 0
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolution for MLX.
Expects input in BTHWC format (batch, time, height, width, channels).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Padding order: (W, W, H, H, T, 0)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def __call__(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
padding[4] -= cache_x.shape[1]
# Pad in BTHWC format
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
(padding[0], padding[1]), (0, 0)]
x = mx.pad(x, pad_width)
result = super().__call__(x)
return result
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=False, images=True, bias=False):
super().__init__()
self.channel_first = channel_first
self.images = images
self.scale = dim**0.5
# Just keep as 1D - let broadcasting do its magic
self.gamma = mx.ones((dim,))
self.bias = mx.zeros((dim,)) if bias else 0.
def __call__(self, x):
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
x = x / norm
return x * self.scale * self.gamma + self.bias
class Upsample(nn.Module):
"""
Upsampling layer that matches PyTorch's behavior.
"""
def __init__(self, scale_factor, mode='nearest-exact'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode # mode is now unused, but kept for signature consistency
def __call__(self, x):
scale_h, scale_w = self.scale_factor
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
return out
class AsymmetricPad(nn.Module):
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
def __init__(self, pad_width: tuple):
super().__init__()
self.pad_width = pad_width
def __call__(self, x):
return mx.pad(x, self.pad_width)
# Update your Resample class to use 'nearest-exact'
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
# --- CORRECTED PADDING LOGIC ---
elif mode == 'downsample2d':
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
elif mode == 'downsample3d':
# The spatial downsampling part uses the same logic
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# The __call__ method logic remains unchanged from your original code
b, t, h, w, c = x.shape
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
cache_x = mx.concatenate([
mx.zeros_like(cache_x), cache_x
], axis=1)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, t, h, w, 2, c)
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
x = x.reshape(b, t * 2, h, w, c)
t = x.shape[1]
x = x.reshape(b * t, h, w, c)
x = self.resample(x)
_, h_new, w_new, c_new = x.shape
x = x.reshape(b, t, h_new, w_new, c_new)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x
feat_idx[0] += 1
else:
cache_x = x[:, -1:, :, :, :]
x = self.time_conv(
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
CausalConv3d(out_dim, out_dim, 3, padding=1)
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for i, layer in enumerate(self.residual.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
self.proj.weight = mx.zeros_like(self.proj.weight)
def __call__(self, x):
# x is in BTHWC format
identity = x
b, t, h, w, c = x.shape
x = x.reshape(b * t, h, w, c) # Combine batch and time
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
qkv = qkv.reshape(b * t, h * w, 3 * c)
q, k, v = mx.split(qkv, 3, axis=-1)
# Reshape for attention
q = q.reshape(b * t, h * w, c)
k = k.reshape(b * t, h * w, c)
v = v.reshape(b * t, h * w, c)
# Scaled dot product attention
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
scores = (q @ k.transpose(0, 2, 1)) * scale
weights = mx.softmax(scores, axis=-1)
x = weights @ v
x = x.reshape(b * t, h, w, c)
# output
x = self.proj(x)
x = x.reshape(b, t, h, w, c)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[-1], dims[-1], dropout),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1], dropout)
)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], z_dim, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for i, layer in enumerate(self.downsamples.layers):
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], 3, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples.layers:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for name, module in model.named_modules():
if isinstance(module, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x, scale):
# x is in BTHWC format
self.clear_cache()
## cache
t = x.shape[1]
iter_ = 1 + (t - 1) // 4
## Split encode input x by time into 1, 4, 4, 4....
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :1, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = mx.concatenate([out, out_], axis=1)
z = self.conv1(out)
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
if isinstance(scale[0], mx.array):
# Reshape scale for broadcasting in BTHWC format
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
mu = (mu - scale_mean) * scale_std
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu, log_var
def decode(self, z, scale):
# z is in BTHWC format
self.clear_cache()
if isinstance(scale[0], mx.array):
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
z = z / scale_std + scale_mean
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[1]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = mx.concatenate([out, out_], axis=1)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = mx.exp(0.5 * log_var)
eps = mx.random.normal(std.shape)
return eps * std + mu
def __call__(self, x):
mu, log_var = self.encode(x, self.scale)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z, self.scale)
return x_recon, mu, log_var
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs, self.scale)
if deterministic:
return mu
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
return mu + std * mx.random.normal(std.shape)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
# params
cfg = dict(
dim=96,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0)
cfg.update(**kwargs)
# init model
model = WanVAE_(**cfg)
# load checkpoint
if pretrained_path:
logging.info(f'loading {pretrained_path}')
weights = mx.load(pretrained_path)
model.update(tree_unflatten(list(weights.items())))
return model
class Wan2_1_VAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=mx.float32):
self.dtype = dtype
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = mx.array(mean, dtype=dtype)
self.std = mx.array(std, dtype=dtype)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
)
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
Returns: List of encoded videos in [C, T, H, W] format.
"""
encoded = []
for video in videos:
# Convert CTHW -> BTHWC
x = mx.expand_dims(video, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Encode
z = self.model.encode(x, self.scale)[0] # Get mu only
# Convert back BTHWC -> CTHW and remove batch dimension
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
z = z.squeeze(0) # Remove batch dimension -> CTHW
encoded.append(z.astype(mx.float32))
return encoded
def decode(self, zs):
"""
zs: A list of latent codes each with shape [C, T, H, W].
Returns: List of decoded videos in [C, T, H, W] format.
"""
decoded = []
for z in zs:
# Convert CTHW -> BTHWC
x = mx.expand_dims(z, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Decode
x = self.model.decode(x, self.scale)
# Convert back BTHWC -> CTHW and remove batch dimension
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
x = x.squeeze(0) # Remove batch dimension -> CTHW
# Clamp values
x = mx.clip(x, -1, 1)
decoded.append(x.astype(mx.float32))
return decoded