diff --git a/video/Wan2.1/README.md b/video/Wan2.1/README.md index 5ecf125a..39568918 100644 --- a/video/Wan2.1/README.md +++ b/video/Wan2.1/README.md @@ -66,7 +66,7 @@ This repository currently supports two Text-to-Video models (1.3B and 14B) and t ##### (1) Example: ``` -python generate_mlx.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4 +python generate.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4 ``` diff --git a/video/Wan2.1/assets/comp_effic.png b/video/Wan2.1/assets/comp_effic.png deleted file mode 100644 index ea0e3b23..00000000 Binary files a/video/Wan2.1/assets/comp_effic.png and /dev/null differ diff --git a/video/Wan2.1/assets/data_for_diff_stage.jpg b/video/Wan2.1/assets/data_for_diff_stage.jpg deleted file mode 100644 index af98046e..00000000 Binary files a/video/Wan2.1/assets/data_for_diff_stage.jpg and /dev/null differ diff --git a/video/Wan2.1/assets/i2v_res.png b/video/Wan2.1/assets/i2v_res.png deleted file mode 100644 index fb13d613..00000000 Binary files a/video/Wan2.1/assets/i2v_res.png and /dev/null differ diff --git a/video/Wan2.1/assets/t2v_res.jpg b/video/Wan2.1/assets/t2v_res.jpg deleted file mode 100644 index 6a583889..00000000 Binary files a/video/Wan2.1/assets/t2v_res.jpg and /dev/null differ diff --git a/video/Wan2.1/assets/vben_vs_sota.png b/video/Wan2.1/assets/vben_vs_sota.png deleted file mode 100644 index 4f09de66..00000000 Binary files a/video/Wan2.1/assets/vben_vs_sota.png and /dev/null differ diff --git a/video/Wan2.1/assets/video_dit_arch.jpg b/video/Wan2.1/assets/video_dit_arch.jpg deleted file mode 100644 index a13e499a..00000000 Binary files a/video/Wan2.1/assets/video_dit_arch.jpg and /dev/null differ diff --git a/video/Wan2.1/assets/video_vae_res.jpg b/video/Wan2.1/assets/video_vae_res.jpg deleted file mode 100644 index e1bfb111..00000000 Binary files a/video/Wan2.1/assets/video_vae_res.jpg and /dev/null differ diff --git a/video/Wan2.1/generate_mlx.py b/video/Wan2.1/generate.py similarity index 99% rename from video/Wan2.1/generate_mlx.py rename to video/Wan2.1/generate.py index ead1e846..ebad359f 100644 --- a/video/Wan2.1/generate_mlx.py +++ b/video/Wan2.1/generate.py @@ -10,7 +10,7 @@ from PIL import Image import wan from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES -from wan.utils.utils_mlx_own import cache_video, cache_image, str2bool +from wan.utils.utils import cache_video, cache_image, str2bool EXAMPLE_PROMPT = { "t2v-1.3B": { diff --git a/video/Wan2.1/tests/README.md b/video/Wan2.1/tests/README.md deleted file mode 100644 index addeda72..00000000 --- a/video/Wan2.1/tests/README.md +++ /dev/null @@ -1,6 +0,0 @@ - -Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify the max GPU number you want to use. 
- -```bash -bash ./test.sh -``` diff --git a/video/Wan2.1/tests/test.sh b/video/Wan2.1/tests/test.sh deleted file mode 100644 index bf40cd7c..00000000 --- a/video/Wan2.1/tests/test.sh +++ /dev/null @@ -1,113 +0,0 @@ -#!/bin/bash - - -if [ "$#" -eq 2 ]; then - MODEL_DIR=$(realpath "$1") - GPUS=$2 -else - echo "Usage: $0 " - exit 1 -fi - -SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" -REPO_ROOT="$(dirname "$SCRIPT_DIR")" -cd "$REPO_ROOT" || exit 1 - -PY_FILE=./generate.py - - -function t2v_1_3B() { - T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B" - - # 1-GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: " - python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR - - # Multiple GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS - - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" - - if [ -n "${DASH_API_KEY+x}" ]; then - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope" - else - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test." 
- fi -} - -function t2v_14B() { - T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B" - - # 1-GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: " - python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR - - # Multiple GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS - - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" -} - - - -function t2i_14B() { - T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B" - - # 1-GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: " - python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR - - # Multiple GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS - - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" -} - - -function i2v_14B_480p() { - I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P" - - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: " - python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR - - # Multiple GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS - - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en" - - if [ -n "${DASH_API_KEY+x}" ]; then - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope" - else - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test." 
- fi -} - - -function i2v_14B_720p() { - I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P" - - # 1-GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: " - python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR - - # Multiple GPU Test - echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: " - torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS -} - - -t2i_14B -t2v_1_3B -t2v_14B -i2v_14B_480p -i2v_14B_720p diff --git a/video/Wan2.1/wan/__init__.py b/video/Wan2.1/wan/__init__.py index 901ccdd9..3296267d 100644 --- a/video/Wan2.1/wan/__init__.py +++ b/video/Wan2.1/wan/__init__.py @@ -1,2 +1,2 @@ from . import configs, modules -from .text2video_mlx import WanT2V +from .text2video import WanT2V diff --git a/video/Wan2.1/wan/modules/__init__.py b/video/Wan2.1/wan/modules/__init__.py index 0d8e71e6..1ba3b7b3 100644 --- a/video/Wan2.1/wan/modules/__init__.py +++ b/video/Wan2.1/wan/modules/__init__.py @@ -1,7 +1,7 @@ -from .model_mlx import WanModel -from .t5_mlx import T5Decoder, T5Encoder, T5EncoderModel, T5Model +from .model import WanModel +from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model from .tokenizers import HuggingfaceTokenizer -from .vae_mlx import WanVAE +from .vae import WanVAE __all__ = [ 'WanVAE', diff --git a/video/Wan2.1/wan/modules/model_mlx.py b/video/Wan2.1/wan/modules/model.py similarity index 100% rename from video/Wan2.1/wan/modules/model_mlx.py rename to video/Wan2.1/wan/modules/model.py diff --git a/video/Wan2.1/wan/modules/t5.py b/video/Wan2.1/wan/modules/t5.py index 3a26ed18..735e2c71 100644 --- a/video/Wan2.1/wan/modules/t5.py +++ b/video/Wan2.1/wan/modules/t5.py @@ -1,11 +1,13 @@ -# Modified from transformers.models.t5.modeling_t5 +# Modified from transformers.models.t5.modeling_t5 for MLX # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import logging import math +from typing import Optional, Tuple, List -import torch -import torch.nn as nn -import torch.nn.functional as F +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_unflatten from .tokenizers import HuggingfaceTokenizer @@ -18,59 +20,41 @@ __all__ = [ def fp16_clamp(x): - if x.dtype == torch.float16 and torch.isinf(x).any(): - clamp = torch.finfo(x.dtype).max - 1000 - x = torch.clamp(x, min=-clamp, max=clamp) + if x.dtype == mx.float16: + # Use same clamping as PyTorch for consistency + clamp = 65504.0 # max value for float16 + return mx.clip(x, -clamp, clamp) return x -def init_weights(m): - if isinstance(m, T5LayerNorm): - nn.init.ones_(m.weight) - elif isinstance(m, T5Model): - nn.init.normal_(m.token_embedding.weight, std=1.0) - elif isinstance(m, T5FeedForward): - nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) - nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) - nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) - elif isinstance(m, T5Attention): - nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) - nn.init.normal_(m.k.weight, std=m.dim**-0.5) - nn.init.normal_(m.v.weight, std=m.dim**-0.5) - nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) - elif isinstance(m, T5RelativeEmbedding): - nn.init.normal_( - m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) - - class GELU(nn.Module): - - def forward(self, x): - return 0.5 * x * (1.0 + torch.tanh( - math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + def __call__(self, x): + return 0.5 * x * (1.0 + mx.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0)))) class T5LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-6): - super(T5LayerNorm, self).__init__() + super().__init__() self.dim = dim self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) + self.weight = mx.ones((dim,)) - def forward(self, x): - x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + - self.eps) - if self.weight.dtype in [torch.float16, torch.float32]: - x = x.type_as(self.weight) - return self.weight * x + def __call__(self, x): + # Match PyTorch's approach: convert to float32 for stability + x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x + variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True) + x_norm = x_float * mx.rsqrt(variance + self.eps) + # Convert back to original dtype + if x.dtype == mx.float16: + x_norm = x_norm.astype(mx.float16) + return self.weight * x_norm class T5Attention(nn.Module): - - def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + def __init__(self, dim, dim_attn, num_heads, dropout=0.0): assert dim_attn % num_heads == 0 - super(T5Attention, self).__init__() + super().__init__() self.dim = dim self.dim_attn = dim_attn self.num_heads = num_heads @@ -83,7 +67,7 @@ class T5Attention(nn.Module): self.o = nn.Linear(dim_attn, dim, bias=False) self.dropout = nn.Dropout(dropout) - def forward(self, x, context=None, mask=None, pos_bias=None): + def __call__(self, x, context=None, mask=None, pos_bias=None): """ x: [B, L1, C]. context: [B, L2, C] or None. 
@@ -91,50 +75,72 @@ class T5Attention(nn.Module): """ # check inputs context = x if context is None else context - b, n, c = x.size(0), self.num_heads, self.head_dim + b, l1, _ = x.shape + _, l2, _ = context.shape + n, c = self.num_heads, self.head_dim # compute query, key, value - q = self.q(x).view(b, -1, n, c) - k = self.k(context).view(b, -1, n, c) - v = self.v(context).view(b, -1, n, c) + q = self.q(x).reshape(b, l1, n, c) + k = self.k(context).reshape(b, l2, n, c) + v = self.v(context).reshape(b, l2, n, c) - # attention bias - attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) - if pos_bias is not None: - attn_bias += pos_bias - if mask is not None: - assert mask.ndim in [2, 3] - mask = mask.view(b, 1, 1, - -1) if mask.ndim == 2 else mask.unsqueeze(1) - attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + # transpose for attention: [B, N, L, C] + q = mx.transpose(q, (0, 2, 1, 3)) + k = mx.transpose(k, (0, 2, 1, 3)) + v = mx.transpose(v, (0, 2, 1, 3)) # compute attention (T5 does not use scaling) - attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias - attn = F.softmax(attn.float(), dim=-1).type_as(attn) - x = torch.einsum('bnij,bjnc->binc', attn, v) + attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2] + + # add position bias if provided + if pos_bias is not None: + attn = attn + pos_bias + + # apply mask + if mask is not None: + if mask.ndim == 2: + # [B, L2] -> [B, 1, 1, L2] + mask = mask[:, None, None, :] + elif mask.ndim == 3: + # [B, L1, L2] -> [B, 1, L1, L2] + mask = mask[:, None, :, :] + # Use very negative value that works well with float16 + min_value = -65504.0 if attn.dtype == mx.float16 else -1e9 + attn = mx.where(mask == 0, min_value, attn) - # output - x = x.reshape(b, -1, n * c) + # softmax and apply attention + attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype) + attn = self.dropout(attn) + + # apply attention to values + x = mx.matmul(attn, v) # [B, N, L1, C] + + # transpose back and reshape + x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C] + x = x.reshape(b, l1, -1) + + # output projection x = self.o(x) x = self.dropout(x) return x class T5FeedForward(nn.Module): - - def __init__(self, dim, dim_ffn, dropout=0.1): - super(T5FeedForward, self).__init__() + def __init__(self, dim, dim_ffn, dropout=0.0): + super().__init__() self.dim = dim self.dim_ffn = dim_ffn # layers - self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.gate_proj = nn.Linear(dim, dim_ffn, bias=False) + self.gate_act = GELU() self.fc1 = nn.Linear(dim, dim_ffn, bias=False) self.fc2 = nn.Linear(dim_ffn, dim, bias=False) self.dropout = nn.Dropout(dropout) - def forward(self, x): - x = self.fc1(x) * self.gate(x) + def __call__(self, x): + gate = self.gate_act(self.gate_proj(x)) + x = self.fc1(x) * gate x = self.dropout(x) x = self.fc2(x) x = self.dropout(x) @@ -142,7 +148,6 @@ class T5FeedForward(nn.Module): class T5SelfAttention(nn.Module): - def __init__(self, dim, dim_attn, @@ -150,8 +155,8 @@ class T5SelfAttention(nn.Module): num_heads, num_buckets, shared_pos=True, - dropout=0.1): - super(T5SelfAttention, self).__init__() + dropout=0.0): + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -167,16 +172,15 @@ class T5SelfAttention(nn.Module): self.pos_embedding = None if shared_pos else T5RelativeEmbedding( num_buckets, num_heads, bidirectional=True) - def forward(self, x, mask=None, pos_bias=None): + def __call__(self, x, mask=None, pos_bias=None): e = pos_bias if self.shared_pos else 
self.pos_embedding( - x.size(1), x.size(1)) + x.shape[1], x.shape[1]) x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) x = fp16_clamp(x + self.ffn(self.norm2(x))) return x class T5CrossAttention(nn.Module): - def __init__(self, dim, dim_attn, @@ -184,8 +188,8 @@ class T5CrossAttention(nn.Module): num_heads, num_buckets, shared_pos=True, - dropout=0.1): - super(T5CrossAttention, self).__init__() + dropout=0.0): + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -203,14 +207,14 @@ class T5CrossAttention(nn.Module): self.pos_embedding = None if shared_pos else T5RelativeEmbedding( num_buckets, num_heads, bidirectional=False) - def forward(self, + def __call__(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): e = pos_bias if self.shared_pos else self.pos_embedding( - x.size(1), x.size(1)) + x.shape[1], x.shape[1]) x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) x = fp16_clamp(x + self.cross_attn( self.norm2(x), context=encoder_states, mask=encoder_mask)) @@ -219,9 +223,8 @@ class T5CrossAttention(nn.Module): class T5RelativeEmbedding(nn.Module): - def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): - super(T5RelativeEmbedding, self).__init__() + super().__init__() self.num_buckets = num_buckets self.num_heads = num_heads self.bidirectional = bidirectional @@ -230,42 +233,55 @@ class T5RelativeEmbedding(nn.Module): # layers self.embedding = nn.Embedding(num_buckets, num_heads) - def forward(self, lq, lk): - device = self.embedding.weight.device - # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ - # torch.arange(lq).unsqueeze(1).to(device) - rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ - torch.arange(lq, device=device).unsqueeze(1) + def __call__(self, lq, lk): + # Create relative position matrix + positions_q = mx.arange(lq)[:, None] + positions_k = mx.arange(lk)[None, :] + rel_pos = positions_k - positions_q + + # Apply bucketing rel_pos = self._relative_position_bucket(rel_pos) + + # Get embeddings rel_pos_embeds = self.embedding(rel_pos) - rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( - 0) # [1, N, Lq, Lk] - return rel_pos_embeds.contiguous() + + # Reshape to [1, N, Lq, Lk] + rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1)) + rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0) + + return rel_pos_embeds def _relative_position_bucket(self, rel_pos): # preprocess if self.bidirectional: num_buckets = self.num_buckets // 2 - rel_buckets = (rel_pos > 0).long() * num_buckets - rel_pos = torch.abs(rel_pos) + rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets + rel_pos = mx.abs(rel_pos) else: num_buckets = self.num_buckets - rel_buckets = 0 - rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32) + rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos)) # embeddings for small and large positions max_exact = num_buckets // 2 - rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / - math.log(self.max_dist / max_exact) * - (num_buckets - max_exact)).long() - rel_pos_large = torch.min( - rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) - rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + is_small = rel_pos < max_exact + + # For large positions, use log scale + rel_pos_large = max_exact + ( + mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)
+ ).astype(mx.int32) + + rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1) + + # Combine small and large position buckets + rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large) + return rel_buckets class T5Encoder(nn.Module): - def __init__(self, vocab, dim, @@ -275,8 +291,8 @@ class T5Encoder(nn.Module): num_layers, num_buckets, shared_pos=True, - dropout=0.1): - super(T5Encoder, self).__init__() + dropout=0.0): + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -286,25 +302,25 @@ class T5Encoder(nn.Module): self.shared_pos = shared_pos # layers - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ - else nn.Embedding(vocab, dim) + if isinstance(vocab, nn.Embedding): + self.token_embedding = vocab + else: + self.token_embedding = nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( num_buckets, num_heads, bidirectional=True) if shared_pos else None self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList([ + self.blocks = [ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers) - ]) + ] self.norm = T5LayerNorm(dim) - # initialize weights - self.apply(init_weights) - - def forward(self, ids, mask=None): + def __call__(self, ids, mask=None): x = self.token_embedding(ids) x = self.dropout(x) - e = self.pos_embedding(x.size(1), - x.size(1)) if self.shared_pos else None + e = self.pos_embedding(x.shape[1], + x.shape[1]) if self.shared_pos else None for block in self.blocks: x = block(x, mask, pos_bias=e) x = self.norm(x) @@ -313,7 +329,6 @@ class T5Encoder(nn.Module): class T5Decoder(nn.Module): - def __init__(self, vocab, dim, @@ -323,8 +338,8 @@ class T5Decoder(nn.Module): num_layers, num_buckets, shared_pos=True, - dropout=0.1): - super(T5Decoder, self).__init__() + dropout=0.0): + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -334,34 +349,35 @@ class T5Decoder(nn.Module): self.shared_pos = shared_pos # layers - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ - else nn.Embedding(vocab, dim) + if isinstance(vocab, nn.Embedding): + self.token_embedding = vocab + else: + self.token_embedding = nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( num_buckets, num_heads, bidirectional=False) if shared_pos else None self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList([ + self.blocks = [ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers) - ]) + ] self.norm = T5LayerNorm(dim) - # initialize weights - self.apply(init_weights) - - def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): - b, s = ids.size() + def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.shape # causal mask if mask is None: - mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + mask = mx.tril(mx.ones((1, s, s))) elif mask.ndim == 2: - mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + # Expand mask properly + mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s))) # layers x = self.token_embedding(ids) x = self.dropout(x) - e = self.pos_embedding(x.size(1), - x.size(1)) if self.shared_pos else None + e = self.pos_embedding(x.shape[1], + x.shape[1]) if self.shared_pos else None for block in self.blocks: x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) x = self.norm(x) @@ -370,7 +386,6 @@ class T5Decoder(nn.Module): class T5Model(nn.Module): - def 
__init__(self, vocab_size, dim, @@ -381,8 +396,8 @@ class T5Model(nn.Module): decoder_layers, num_buckets, shared_pos=True, - dropout=0.1): - super(T5Model, self).__init__() + dropout=0.0): + super().__init__() self.vocab_size = vocab_size self.dim = dim self.dim_attn = dim_attn @@ -402,23 +417,63 @@ class T5Model(nn.Module): shared_pos, dropout) self.head = nn.Linear(dim, vocab_size, bias=False) - # initialize weights - self.apply(init_weights) - - def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): x = self.encoder(encoder_ids, encoder_mask) x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) x = self.head(x) return x +def init_mlx_weights(module, key): + """Initialize weights for T5 model components to match PyTorch initialization""" + + def normal(key, shape, std=1.0): + return mx.random.normal(shape, key=key) * std + + if isinstance(module, T5LayerNorm): + module.weight = mx.ones_like(module.weight) + elif isinstance(module, nn.Embedding): + key = mx.random.split(key, 1)[0] + module.weight = normal(key, module.weight.shape, std=1.0) + elif isinstance(module, T5FeedForward): + # Match PyTorch initialization + key1, key2, key3 = mx.random.split(key, 3) + module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape, + std=module.dim**-0.5) + module.fc1.weight = normal(key2, module.fc1.weight.shape, + std=module.dim**-0.5) + module.fc2.weight = normal(key3, module.fc2.weight.shape, + std=module.dim_ffn**-0.5) + elif isinstance(module, T5Attention): + # Match PyTorch initialization + key1, key2, key3, key4 = mx.random.split(key, 4) + module.q.weight = normal(key1, module.q.weight.shape, + std=(module.dim * module.dim_attn)**-0.5) + module.k.weight = normal(key2, module.k.weight.shape, + std=module.dim**-0.5) + module.v.weight = normal(key3, module.v.weight.shape, + std=module.dim**-0.5) + module.o.weight = normal(key4, module.o.weight.shape, + std=(module.num_heads * module.dim_attn)**-0.5) + elif isinstance(module, T5RelativeEmbedding): + key = mx.random.split(key, 1)[0] + module.embedding.weight = normal(key, module.embedding.weight.shape, + std=(2 * module.num_buckets * module.num_heads)**-0.5) + elif isinstance(module, nn.Linear): + # Generic linear layer initialization + key = mx.random.split(key, 1)[0] + fan_in = module.weight.shape[1] + bound = 1.0 / math.sqrt(fan_in) + module.weight = mx.random.uniform(-bound, bound, module.weight.shape, key=key) + + return module + + def _t5(name, encoder_only=False, decoder_only=False, return_tokenizer=False, tokenizer_kwargs={}, - dtype=torch.float32, - device='cpu', **kwargs): # sanity check assert not (encoder_only and decoder_only) @@ -438,11 +493,11 @@ def _t5(name, model_cls = T5Model # init model - with torch.device(device): - model = model_cls(**kwargs) - - # set device - model = model.to(dtype=dtype, device=device) + model = model_cls(**kwargs) + + # Initialize weights properly + key = mx.random.key(0) + model = init_mlx_weights(model, key) # init tokenizer if return_tokenizer: @@ -464,50 +519,99 @@ def umt5_xxl(**kwargs): decoder_layers=24, num_buckets=32, shared_pos=False, - dropout=0.1) + dropout=0.0) cfg.update(**kwargs) return _t5('umt5-xxl', **cfg) class T5EncoderModel: - def __init__( self, text_len, - dtype=torch.float32, - device='mps' if torch.backends.mps.is_available() else 'cpu', checkpoint_path=None, tokenizer_path=None, - shard_fn=None, ): self.text_len = text_len - self.dtype = dtype - self.device = device
self.checkpoint_path = checkpoint_path self.tokenizer_path = tokenizer_path # init model model = umt5_xxl( encoder_only=True, - return_tokenizer=False, - dtype=dtype, - device=device).eval().requires_grad_(False) - logging.info(f'loading {checkpoint_path}') - model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + return_tokenizer=False) + + if checkpoint_path: + logging.info(f'loading {checkpoint_path}') + # Load weights - assuming MLX format checkpoint + weights = mx.load(checkpoint_path) + model.update(tree_unflatten(list(weights.items()))) + self.model = model - if shard_fn is not None: - self.model = shard_fn(self.model, sync_module_states=False) - else: - self.model.to(self.device) + # init tokenizer + from .tokenizers import HuggingfaceTokenizer self.tokenizer = HuggingfaceTokenizer( - name=tokenizer_path, seq_len=text_len, clean='whitespace') + name=tokenizer_path if tokenizer_path else 'google/umt5-xxl', + seq_len=text_len, + clean='whitespace') - def __call__(self, texts, device): - ids, mask = self.tokenizer( + def __call__(self, texts): + # Handle single string input + if isinstance(texts, str): + texts = [texts] + + # Tokenize texts + tokenizer_output = self.tokenizer( texts, return_mask=True, add_special_tokens=True) - ids = ids.to(device) - mask = mask.to(device) - seq_lens = mask.gt(0).sum(dim=1).long() + + # Handle different tokenizer output formats + if isinstance(tokenizer_output, tuple): + ids, mask = tokenizer_output + else: + # Assuming dict output with 'input_ids' and 'attention_mask' + ids = tokenizer_output['input_ids'] + mask = tokenizer_output['attention_mask'] + + # Convert to MLX arrays if not already + if not isinstance(ids, mx.array): + ids = mx.array(ids) + if not isinstance(mask, mx.array): + mask = mx.array(mask) + + # Get sequence lengths + seq_lens = mx.sum(mask > 0, axis=1) + + # Run encoder context = self.model(ids, mask) - return [u[:v] for u, v in zip(context, seq_lens)] + + # Return variable length outputs + # Convert seq_lens to Python list for indexing + if seq_lens.ndim == 0: # Single value + seq_lens_list = [seq_lens.item()] + else: + seq_lens_list = seq_lens.tolist() + + return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))] + + +# Utility function to convert PyTorch checkpoint to MLX +def convert_pytorch_checkpoint(pytorch_path, mlx_path): + """Convert PyTorch checkpoint to MLX format""" + import torch + + # Load PyTorch checkpoint + pytorch_state = torch.load(pytorch_path, map_location='cpu') + + # Convert to numpy then to MLX + mlx_state = {} + for key, value in pytorch_state.items(): + if isinstance(value, torch.Tensor): + # Handle the key mapping if needed + mlx_key = key + # Convert tensor to MLX array + mlx_state[mlx_key] = mx.array(value.numpy()) + + # Save MLX checkpoint as an .npz archive (mx.load returns it as a dict) + mx.savez(mlx_path, **mlx_state) + + return mlx_state \ No newline at end of file diff --git a/video/Wan2.1/wan/modules/t5_mlx.py b/video/Wan2.1/wan/modules/t5_mlx.py deleted file mode 100644 index 735e2c71..00000000 --- a/video/Wan2.1/wan/modules/t5_mlx.py +++ /dev/null @@ -1,617 +0,0 @@ -# Modified from transformers.models.t5.modeling_t5 for MLX -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import logging -import math -from typing import Optional, Tuple, List - -import mlx.core as mx -import mlx.nn as nn -import numpy as np -from mlx.utils import tree_unflatten - -from .tokenizers import HuggingfaceTokenizer - -__all__ = [ - 'T5Model', - 'T5Encoder', - 'T5Decoder', - 'T5EncoderModel', -] - - -def fp16_clamp(x): - if x.dtype == mx.float16: - # Use same clamping as PyTorch for consistency - clamp = 65504.0 # max value for float16 - return mx.clip(x, -clamp, clamp) - return x - - -class GELU(nn.Module): - def __call__(self, x): - return 0.5 * x * (1.0 + mx.tanh( - math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0)))) - - -class T5LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-6): - super().__init__() - self.dim = dim - self.eps = eps - self.weight = mx.ones((dim,)) - - def __call__(self, x): - # Match PyTorch's approach: convert to float32 for stability - x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x - variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True) - x_norm = x_float * mx.rsqrt(variance + self.eps) - # Convert back to original dtype - if x.dtype == mx.float16: - x_norm = x_norm.astype(mx.float16) - return self.weight * x_norm - - -class T5Attention(nn.Module): - def __init__(self, dim, dim_attn, num_heads, dropout=0.0): - assert dim_attn % num_heads == 0 - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.num_heads = num_heads - self.head_dim = dim_attn // num_heads - - # layers - self.q = nn.Linear(dim, dim_attn, bias=False) - self.k = nn.Linear(dim, dim_attn, bias=False) - self.v = nn.Linear(dim, dim_attn, bias=False) - self.o = nn.Linear(dim_attn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def __call__(self, x, context=None, mask=None, pos_bias=None): - """ - x: [B, L1, C]. - context: [B, L2, C] or None. - mask: [B, L2] or [B, L1, L2] or None. 
- """ - # check inputs - context = x if context is None else context - b, l1, _ = x.shape - _, l2, _ = context.shape - n, c = self.num_heads, self.head_dim - - # compute query, key, value - q = self.q(x).reshape(b, l1, n, c) - k = self.k(context).reshape(b, l2, n, c) - v = self.v(context).reshape(b, l2, n, c) - - # transpose for attention: [B, N, L, C] - q = mx.transpose(q, (0, 2, 1, 3)) - k = mx.transpose(k, (0, 2, 1, 3)) - v = mx.transpose(v, (0, 2, 1, 3)) - - # compute attention (T5 does not use scaling) - attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2] - - # add position bias if provided - if pos_bias is not None: - attn = attn + pos_bias - - # apply mask - if mask is not None: - if mask.ndim == 2: - # [B, L2] -> [B, 1, 1, L2] - mask = mask[:, None, None, :] - elif mask.ndim == 3: - # [B, L1, L2] -> [B, 1, L1, L2] - mask = mask[:, None, :, :] - # Use very negative value that works well with float16 - min_value = -65504.0 if attn.dtype == mx.float16 else -1e9 - attn = mx.where(mask == 0, min_value, attn) - - # softmax and apply attention - attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype) - attn = self.dropout(attn) - - # apply attention to values - x = mx.matmul(attn, v) # [B, N, L1, C] - - # transpose back and reshape - x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C] - x = x.reshape(b, l1, -1) - - # output projection - x = self.o(x) - x = self.dropout(x) - return x - - -class T5FeedForward(nn.Module): - def __init__(self, dim, dim_ffn, dropout=0.0): - super().__init__() - self.dim = dim - self.dim_ffn = dim_ffn - - # layers - self.gate_proj = nn.Linear(dim, dim_ffn, bias=False) - self.gate_act = GELU() - self.fc1 = nn.Linear(dim, dim_ffn, bias=False) - self.fc2 = nn.Linear(dim_ffn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def __call__(self, x): - gate = self.gate_act(self.gate_proj(x)) - x = self.fc1(x) * gate - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - - -class T5SelfAttention(nn.Module): - def __init__(self, - dim, - dim_attn, - dim_ffn, - num_heads, - num_buckets, - shared_pos=True, - dropout=0.0): - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.norm1 = T5LayerNorm(dim) - self.attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = None if shared_pos else T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=True) - - def __call__(self, x, mask=None, pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding( - x.shape[1], x.shape[1]) - x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.ffn(self.norm2(x))) - return x - - -class T5CrossAttention(nn.Module): - def __init__(self, - dim, - dim_attn, - dim_ffn, - num_heads, - num_buckets, - shared_pos=True, - dropout=0.0): - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.norm1 = T5LayerNorm(dim) - self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm3 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = None if shared_pos 
else T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=False) - - def __call__(self, - x, - mask=None, - encoder_states=None, - encoder_mask=None, - pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding( - x.shape[1], x.shape[1]) - x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.cross_attn( - self.norm2(x), context=encoder_states, mask=encoder_mask)) - x = fp16_clamp(x + self.ffn(self.norm3(x))) - return x - - -class T5RelativeEmbedding(nn.Module): - def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): - super().__init__() - self.num_buckets = num_buckets - self.num_heads = num_heads - self.bidirectional = bidirectional - self.max_dist = max_dist - - # layers - self.embedding = nn.Embedding(num_buckets, num_heads) - - def __call__(self, lq, lk): - # Create relative position matrix - positions_q = mx.arange(lq)[:, None] - positions_k = mx.arange(lk)[None, :] - rel_pos = positions_k - positions_q - - # Apply bucketing - rel_pos = self._relative_position_bucket(rel_pos) - - # Get embeddings - rel_pos_embeds = self.embedding(rel_pos) - - # Reshape to [1, N, Lq, Lk] - rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1)) - rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0) - - return rel_pos_embeds - - def _relative_position_bucket(self, rel_pos): - # preprocess - if self.bidirectional: - num_buckets = self.num_buckets // 2 - rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets - rel_pos = mx.abs(rel_pos) - else: - num_buckets = self.num_buckets - rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32) - rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos)) - - # embeddings for small and large positions - max_exact = num_buckets // 2 - is_small = rel_pos < max_exact - - # For large positions, use log scale - rel_pos_large = max_exact + ( - mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) / - math.log(self.max_dist / max_exact) * - (num_buckets - max_exact) - ).astype(mx.int32) - - rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1) - - # Combine small and large position buckets - rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large) - - return rel_buckets - - -class T5Encoder(nn.Module): - def __init__(self, - vocab, - dim, - dim_attn, - dim_ffn, - num_heads, - num_layers, - num_buckets, - shared_pos=True, - dropout=0.0): - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - if isinstance(vocab, nn.Embedding): - self.token_embedding = vocab - else: - self.token_embedding = nn.Embedding(vocab, dim) - - self.pos_embedding = T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=True) if shared_pos else None - self.dropout = nn.Dropout(dropout) - self.blocks = [ - T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, - shared_pos, dropout) for _ in range(num_layers) - ] - self.norm = T5LayerNorm(dim) - - def __call__(self, ids, mask=None): - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.shape[1], - x.shape[1]) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Decoder(nn.Module): - def __init__(self, - vocab, - dim, - dim_attn, - dim_ffn, - num_heads, - num_layers, - num_buckets, - shared_pos=True, - dropout=0.0): - 
super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - if isinstance(vocab, nn.Embedding): - self.token_embedding = vocab - else: - self.token_embedding = nn.Embedding(vocab, dim) - - self.pos_embedding = T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=False) if shared_pos else None - self.dropout = nn.Dropout(dropout) - self.blocks = [ - T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, - shared_pos, dropout) for _ in range(num_layers) - ] - self.norm = T5LayerNorm(dim) - - def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None): - b, s = ids.shape - - # causal mask - if mask is None: - mask = mx.tril(mx.ones((1, s, s))) - elif mask.ndim == 2: - # Expand mask properly - mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s))) - - # layers - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.shape[1], - x.shape[1]) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Model(nn.Module): - def __init__(self, - vocab_size, - dim, - dim_attn, - dim_ffn, - num_heads, - encoder_layers, - decoder_layers, - num_buckets, - shared_pos=True, - dropout=0.0): - super().__init__() - self.vocab_size = vocab_size - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.encoder_layers = encoder_layers - self.decoder_layers = decoder_layers - self.num_buckets = num_buckets - - # layers - self.token_embedding = nn.Embedding(vocab_size, dim) - self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, - num_heads, encoder_layers, num_buckets, - shared_pos, dropout) - self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, - num_heads, decoder_layers, num_buckets, - shared_pos, dropout) - self.head = nn.Linear(dim, vocab_size, bias=False) - - def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): - x = self.encoder(encoder_ids, encoder_mask) - x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) - x = self.head(x) - return x - - -def init_mlx_weights(module, key): - """Initialize weights for T5 model components to match PyTorch initialization""" - - def normal(key, shape, std=1.0): - return mx.random.normal(key, shape) * std - - if isinstance(module, T5LayerNorm): - module.weight = mx.ones_like(module.weight) - elif isinstance(module, nn.Embedding): - key = mx.random.split(key, 1)[0] - module.weight = normal(key, module.weight.shape, std=1.0) - elif isinstance(module, T5FeedForward): - # Match PyTorch initialization - key1, key2, key3 = mx.random.split(key, 3) - module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape, - std=module.dim**-0.5) - module.fc1.weight = normal(key2, module.fc1.weight.shape, - std=module.dim**-0.5) - module.fc2.weight = normal(key3, module.fc2.weight.shape, - std=module.dim_ffn**-0.5) - elif isinstance(module, T5Attention): - # Match PyTorch initialization - key1, key2, key3, key4 = random.split(key, 4) - module.q.weight = normal(key1, module.q.weight.shape, - std=(module.dim * module.dim_attn)**-0.5) - module.k.weight = normal(key2, module.k.weight.shape, - std=module.dim**-0.5) - module.v.weight = normal(key3, module.v.weight.shape, - std=module.dim**-0.5) - module.o.weight = normal(key4, 
module.o.weight.shape, - std=(module.num_heads * module.dim_attn)**-0.5) - elif isinstance(module, T5RelativeEmbedding): - key = mx.random.split(key, 1)[0] - module.embedding.weight = normal(key, module.embedding.weight.shape, - std=(2 * module.num_buckets * module.num_heads)**-0.5) - elif isinstance(module, nn.Linear): - # Generic linear layer initialization - key = mx.random.split(key, 1)[0] - fan_in = module.weight.shape[1] - bound = 1.0 / math.sqrt(fan_in) - module.weight = mx.random.uniform(key, module.weight.shape, -bound, bound) - - return module - - -def _t5(name, - encoder_only=False, - decoder_only=False, - return_tokenizer=False, - tokenizer_kwargs={}, - **kwargs): - # sanity check - assert not (encoder_only and decoder_only) - - # params - if encoder_only: - model_cls = T5Encoder - kwargs['vocab'] = kwargs.pop('vocab_size') - kwargs['num_layers'] = kwargs.pop('encoder_layers') - _ = kwargs.pop('decoder_layers') - elif decoder_only: - model_cls = T5Decoder - kwargs['vocab'] = kwargs.pop('vocab_size') - kwargs['num_layers'] = kwargs.pop('decoder_layers') - _ = kwargs.pop('encoder_layers') - else: - model_cls = T5Model - - # init model - model = model_cls(**kwargs) - - # Initialize weights properly - key = mx.random.key(0) - model = init_mlx_weights(model, key) - - # init tokenizer - if return_tokenizer: - from .tokenizers import HuggingfaceTokenizer - tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) - return model, tokenizer - else: - return model - - -def umt5_xxl(**kwargs): - cfg = dict( - vocab_size=256384, - dim=4096, - dim_attn=4096, - dim_ffn=10240, - num_heads=64, - encoder_layers=24, - decoder_layers=24, - num_buckets=32, - shared_pos=False, - dropout=0.0) - cfg.update(**kwargs) - return _t5('umt5-xxl', **cfg) - - -class T5EncoderModel: - def __init__( - self, - text_len, - checkpoint_path=None, - tokenizer_path=None, - ): - self.text_len = text_len - self.checkpoint_path = checkpoint_path - self.tokenizer_path = tokenizer_path - - # init model - model = umt5_xxl( - encoder_only=True, - return_tokenizer=False) - - if checkpoint_path: - logging.info(f'loading {checkpoint_path}') - # Load weights - assuming MLX format checkpoint - weights = mx.load(checkpoint_path) - model.update(tree_unflatten(list(weights.items()))) - - self.model = model - - # init tokenizer - from .tokenizers import HuggingfaceTokenizer - self.tokenizer = HuggingfaceTokenizer( - name=tokenizer_path if tokenizer_path else 'google/umt5-xxl', - seq_len=text_len, - clean='whitespace') - - def __call__(self, texts): - # Handle single string input - if isinstance(texts, str): - texts = [texts] - - # Tokenize texts - tokenizer_output = self.tokenizer( - texts, return_mask=True, add_special_tokens=True) - - # Handle different tokenizer output formats - if isinstance(tokenizer_output, tuple): - ids, mask = tokenizer_output - else: - # Assuming dict output with 'input_ids' and 'attention_mask' - ids = tokenizer_output['input_ids'] - mask = tokenizer_output['attention_mask'] - - # Convert to MLX arrays if not already - if not isinstance(ids, mx.array): - ids = mx.array(ids) - if not isinstance(mask, mx.array): - mask = mx.array(mask) - - # Get sequence lengths - seq_lens = mx.sum(mask > 0, axis=1) - - # Run encoder - context = self.model(ids, mask) - - # Return variable length outputs - # Convert seq_lens to Python list for indexing - if seq_lens.ndim == 0: # Single value - seq_lens_list = [seq_lens.item()] - else: - seq_lens_list = seq_lens.tolist() - - return [context[i, 
:int(seq_lens_list[i])] for i in range(len(texts))] - - -# Utility function to convert PyTorch checkpoint to MLX -def convert_pytorch_checkpoint(pytorch_path, mlx_path): - """Convert PyTorch checkpoint to MLX format""" - import torch - - # Load PyTorch checkpoint - pytorch_state = torch.load(pytorch_path, map_location='cpu') - - # Convert to numpy then to MLX - mlx_state = {} - for key, value in pytorch_state.items(): - if isinstance(value, torch.Tensor): - # Handle the key mapping if needed - mlx_key = key - # Convert tensor to MLX array - mlx_state[mlx_key] = mx.array(value.numpy()) - - # Save MLX checkpoint - mx.save(mlx_path, mlx_state) - - return mlx_state \ No newline at end of file diff --git a/video/Wan2.1/wan/modules/vae.py b/video/Wan2.1/wan/modules/vae.py index 83ab676b..557619f7 100644 --- a/video/Wan2.1/wan/modules/vae.py +++ b/video/Wan2.1/wan/modules/vae.py @@ -1,13 +1,11 @@ -# Original PyTorch implementation of Wan VAE - -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import logging +from typing import Optional, List, Tuple -import torch -import torch.cuda.amp as amp -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from mlx.utils import tree_unflatten __all__ = [ 'WanVAE', @@ -17,66 +15,89 @@ CACHE_T = 2 debug_line = 0 -def debug(name, x): - global debug_line - print(f"LINE {debug_line}: {name}: shape = {tuple(x.shape)}, mean = {x.mean().item():.4f}, std = {x.std().item():.4f}") - debug_line += 1 - return x - class CausalConv3d(nn.Conv3d): """ - Causal 3d convolusion. + Causal 3d convolution for MLX. + Expects input in BTHWC format (batch, time, height, width, channels). """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Padding order: (W, W, H, H, T, 0) self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) - def forward(self, x, cache_x=None): + def __call__(self, x, cache_x=None): padding = list(self._padding) if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) - - result = super().forward(x) - debug("TORCH x after conv3d", result) + x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis + padding[4] -= cache_x.shape[1] + + # Pad in BTHWC format + pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]), + (padding[0], padding[1]), (0, 0)] + x = mx.pad(x, pad_width) + + result = super().__call__(x) return result class RMS_norm(nn.Module): - def __init__(self, dim, channel_first=True, images=True, bias=False): + def __init__(self, dim, channel_first=False, images=True, bias=False): super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - self.channel_first = channel_first + self.images = images self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. - def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias + # Just keep as 1D - let broadcasting do its magic + self.gamma = mx.ones((dim,)) + self.bias = mx.zeros((dim,)) if bias else 0. + + def __call__(self, x): + # F.normalize in PyTorch does L2 normalization, not RMS! 
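+ # Note: PyTorch's F.normalize divides by max(||x||, eps) with eps=1e-12; adding 1e-6 inside the sqrt here only differs when the norm is vanishingly small.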
+ # For NHWC/BTHWC format, normalize along the last axis + # L2 norm: sqrt(sum(x^2)) + norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6) + x = x / norm + return x * self.scale * self.gamma + self.bias -class Upsample(nn.Upsample): +class Upsample(nn.Module): + """ + Upsampling layer that matches PyTorch's behavior. + """ + def __init__(self, scale_factor, mode='nearest-exact'): + super().__init__() + self.scale_factor = scale_factor + self.mode = mode # mode is unused, but kept for signature consistency - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - result = super().forward(x.float()).type_as(x) - debug("TORCH x after upsample", result) - return result + def __call__(self, x): + # For NHWC format (n, h, w, c) + + # NOTE: For an integer scale_factor like 2.0, PyTorch's 'nearest-exact' + # is equivalent to a simple repeat operation along each spatial axis. + + scale_h, scale_w = self.scale_factor + + out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension + out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension + + return out +class AsymmetricPad(nn.Module): + """A module to apply asymmetric padding, compatible with nn.Sequential.""" + def __init__(self, pad_width: tuple): + super().__init__() + self.pad_width = pad_width + def __call__(self, x): + return mx.pad(x, self.pad_width) + + class Resample(nn.Module): def __init__(self, dim, mode): @@ -90,30 +111,42 @@ class Resample(nn.Module): if mode == 'upsample2d': self.resample = nn.Sequential( Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) + nn.Conv2d(dim, dim // 2, 3, padding=1) + ) elif mode == 'upsample3d': self.resample = nn.Sequential( Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) + nn.Conv2d(dim, dim // 2, 3, padding=1) + ) self.time_conv = CausalConv3d( dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) elif mode == 'downsample2d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) + # Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2) + # using the AsymmetricPad module defined above.
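+ # e.g., H = W = 64: pad bottom/right by one -> 65 x 65, then a 3x3 stride-2 conv gives floor((65 - 3) / 2) + 1 = 32 per side, matching PyTorch's output shape.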
+ # Pad width for NHWC format is ((N), (H), (W), (C)) + # Pad H with (top, bottom) and W with (left, right) + pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0))) + conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0) + self.resample = nn.Sequential(pad_layer, conv_layer) + elif mode == 'downsample3d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) + # The spatial downsampling part uses the same logic + pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0))) + conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0) + self.resample = nn.Sequential(pad_layer, conv_layer) + self.time_conv = CausalConv3d( dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - + else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() + def __call__(self, x, feat_cache=None, feat_idx=[0]): + # Input is in BTHWC format + b, t, h, w, c = x.shape + if self.mode == 'upsample3d': if feat_cache is not None: idx = feat_idx[0] @@ -121,84 +154,50 @@ class Resample(nn.Module): feat_cache[idx] = 'Rep' feat_idx[0] += 1 else: - - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) + cache_x = x[:, -CACHE_T:, :, :, :] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep': + cache_x = mx.concatenate([ + feat_cache[idx][:, -1:, :, :, :], cache_x + ], axis=1) + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep': + cache_x = mx.concatenate([ + mx.zeros_like(cache_x), cache_x + ], axis=1) + if feat_cache[idx] == 'Rep': x = self.time_conv(x) - debug("TORCH x after time_conv", x) else: x = self.time_conv(x, feat_cache[idx]) - debug("TORCH x after time_conv with cache", x) feat_cache[idx] = cache_x feat_idx[0] += 1 - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), - 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = x.reshape(b, t, h, w, 2, c) + x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2) + x = x.reshape(b, t * 2, h, w, c) + + t = x.shape[1] + x = x.reshape(b * t, h, w, c) + x = self.resample(x) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) - debug("TORCH x after resample", x) + + _, h_new, w_new, c_new = x.shape + x = x.reshape(b, t, h_new, w_new, c_new) if self.mode == 'downsample3d': if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() + feat_cache[idx] = x feat_idx[0] += 1 else: - - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - + cache_x = x[:, -1:, :, :, :] x = self.time_conv( - torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) feat_cache[idx] = cache_x feat_idx[0] += 1
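+ # Note: only the last frame of each chunk is cached above; the strided time_conv then sees [cached frame, new chunk], which keeps temporal downsampling causal across chunk boundaries.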
+ return x - def init_weight(self, conv): - conv_weight = conv.weight - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix - nn.init.zeros_(conv_weight) - #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 - conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - def init_weight2(self, conv): - conv_weight = conv.weight.data - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - init_matrix = torch.eye(c1 // 2, c2) - #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) - conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - class ResidualBlock(nn.Module): @@ -209,34 +208,34 @@ class ResidualBlock(nn.Module): # layers self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), nn.SiLU(), + RMS_norm(in_dim, images=False), + nn.SiLU(), CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1)) + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout) if dropout > 0 else nn.Identity(), + CausalConv3d(out_dim, out_dim, 3, padding=1) + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ if in_dim != out_dim else nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def __call__(self, x, feat_cache=None, feat_idx=[0]): h = self.shortcut(x) - debug("TORCH x after shortcut", h) - for layer in self.residual: + + for i, layer in enumerate(self.residual.layers): if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, -CACHE_T:, :, :, :] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate([ + feat_cache[idx][:, -1:, :, :, :], cache_x + ], axis=1) x = layer(x, feat_cache[idx]) - debug("TORCH x after residual block", x) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = layer(x) - debug("TORCH x after residual block", x) return x + h @@ -255,32 +254,34 @@ class AttentionBlock(nn.Module): self.proj = nn.Conv2d(dim, dim, 1) # zero out the last layer params - nn.init.zeros_(self.proj.weight) + self.proj.weight = mx.zeros_like(self.proj.weight) - def forward(self, x): + def __call__(self, x): + # x is in BTHWC format identity = x - b, c, t, h, w = x.size() - x = rearrange(x, 'b c t h w -> (b t) c h w') + b, t, h, w, c = x.shape + x = x.reshape(b * t, h, w, c) # Combine batch and time x = self.norm(x) - debug("TORCH x after norm", x) # compute query, key, value - q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, - -1).permute(0, 1, 3, - 2).contiguous().chunk( - 3, dim=-1) + qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c) + qkv = qkv.reshape(b * t, h * w, 3 * c) + q, k, v = mx.split(qkv, 3, axis=-1) + + # Reshape for attention + q = q.reshape(b * t, h * w, c) + k = k.reshape(b * t, h * w, c) + v = v.reshape(b * t, h * w, c) - # apply attention - x = F.scaled_dot_product_attention( - q, - k, - v, - ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + # Scaled dot product attention + scale = 1.0 / 
mx.sqrt(mx.array(c, dtype=q.dtype)) + scores = (q @ k.transpose(0, 2, 1)) * scale + weights = mx.softmax(scores, axis=-1) + x = weights @ v + x = x.reshape(b * t, h, w, c) # output x = self.proj(x) - debug("TORCH x after proj", x) - x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + x = x.reshape(b, t, h, w, c) return x + identity @@ -321,78 +322,68 @@ class Encoder3d(nn.Module): # downsample block if i != len(dim_mult) - 1: - mode = 'downsample3d' if temperal_downsample[ - i] else 'downsample2d' + mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d' downsamples.append(Resample(out_dim, mode=mode)) scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) # middle blocks self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), - ResidualBlock(out_dim, out_dim, dropout)) + ResidualBlock(dims[-1], dims[-1], dropout), + AttentionBlock(dims[-1]), + ResidualBlock(dims[-1], dims[-1], dropout) + ) # output blocks self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, z_dim, 3, padding=1)) + RMS_norm(dims[-1], images=False), + nn.SiLU(), + CausalConv3d(dims[-1], z_dim, 3, padding=1) + ) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def __call__(self, x, feat_cache=None, feat_idx=[0]): if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, -CACHE_T:, :, :, :] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate([ + feat_cache[idx][:, -1:, :, :, :], cache_x + ], axis=1) x = self.conv1(x, feat_cache[idx]) - debug("TORCH x after conv1 with cache", x) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.conv1(x) - debug("TORCH x after conv1", x) ## downsamples - for i, layer in enumerate(self.downsamples): + for i, layer in enumerate(self.downsamples.layers): if feat_cache is not None: x = layer(x, feat_cache, feat_idx) - debug("TORCH x after downsample layer", x) else: x = layer(x) - debug("TORCH x after downsample layer", x) ## middle - for layer in self.middle: + for layer in self.middle.layers: if isinstance(layer, ResidualBlock) and feat_cache is not None: x = layer(x, feat_cache, feat_idx) - debug("TORCH x after downsample layer", x) else: x = layer(x) - debug("TORCH x after downsample layer", x) ## head - for layer in self.head: + for i, layer in enumerate(self.head.layers): if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, -CACHE_T:, :, :, :] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate([ + feat_cache[idx][:, -1:, :, :, :], cache_x + ], axis=1) x = layer(x, feat_cache[idx]) - debug("TORCH x after downsample layer", x) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = layer(x) - debug("TORCH x after downsample layer", x) return x @@ -423,8 +414,10 @@ class Decoder3d(nn.Module): # middle blocks self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout)) + 
ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout) + ) # upsample blocks upsamples = [] @@ -443,77 +436,66 @@ class Decoder3d(nn.Module): mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' upsamples.append(Resample(out_dim, mode=mode)) scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) # output blocks self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + RMS_norm(dims[-1], images=False), + nn.SiLU(), + CausalConv3d(dims[-1], 3, 3, padding=1) + ) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def __call__(self, x, feat_cache=None, feat_idx=[0]): ## conv1 if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, -CACHE_T:, :, :, :] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate([ + feat_cache[idx][:, -1:, :, :, :], cache_x + ], axis=1) x = self.conv1(x, feat_cache[idx]) - debug("TORCH x after conv1 with cache", x) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.conv1(x) - debug("TORCH x after conv1", x) ## middle - for layer in self.middle: + for layer in self.middle.layers: if isinstance(layer, ResidualBlock) and feat_cache is not None: x = layer(x, feat_cache, feat_idx) - debug("TORCH x after middle layer", x) else: x = layer(x) - debug("TORCH x after middle layer", x) ## upsamples - for layer in self.upsamples: + for layer in self.upsamples.layers: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) - debug("TORCH x after upsample layer", x) else: x = layer(x) - debug("TORCH x after upsample layer", x) ## head - for layer in self.head: + for i, layer in enumerate(self.head.layers): if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, -CACHE_T:, :, :, :] + if cache_x.shape[1] < 2 and feat_cache[idx] is not None: + cache_x = mx.concatenate([ + feat_cache[idx][:, -1:, :, :, :], cache_x + ], axis=1) x = layer(x, feat_cache[idx]) - debug("TORCH x after head layer", x) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = layer(x) - debug("TORCH x after head layer", x) return x def count_conv3d(model): count = 0 - for m in model.modules(): - if isinstance(m, CausalConv3d): + for name, module in model.named_modules(): + if isinstance(module, CausalConv3d): count += 1 return count @@ -545,78 +527,85 @@ class WanVAE_(nn.Module): self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) - def forward(self, x): - mu, log_var = self.encode(x) - z = self.reparameterize(mu, log_var) - x_recon = self.decode(z) - return x_recon, mu, log_var - def encode(self, x, scale): + # x is in BTHWC format self.clear_cache() ## cache - t = x.shape[2] + t = x.shape[1] iter_ = 1 + (t - 1) // 4 - ## 对encode输入的x,按时间拆分为1、4、4、4.... + ## Split encode input x by time into 1, 4, 4, 4.... 
for i in range(iter_): self._enc_conv_idx = [0] if i == 0: out = self.encoder( - x[:, :, :1, :, :], + x[:, :1, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) - out = torch.cat([out, out_], 2) - mu, log_var = self.conv1(out).chunk(2, dim=1) - if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) + out = mx.concatenate([out, out_], axis=1) + + z = self.conv1(out) + mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension + + if isinstance(scale[0], mx.array): + # Reshape scale for broadcasting in BTHWC format + scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim) + scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim) + mu = (mu - scale_mean) * scale_std else: mu = (mu - scale[0]) * scale[1] self.clear_cache() - return mu + return mu, log_var def decode(self, z, scale): + # z is in BTHWC format self.clear_cache() - # z: [b,c,t,h,w] - if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( - 1, self.z_dim, 1, 1, 1) + if isinstance(scale[0], mx.array): + scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim) + scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim) + z = z / scale_std + scale_mean else: z = z / scale[1] + scale[0] - iter_ = z.shape[2] + iter_ = z.shape[1] x = self.conv2(z) for i in range(iter_): self._conv_idx = [0] if i == 0: out = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, i:i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: out_ = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, i:i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - out = torch.cat([out, out_], 2) + out = mx.concatenate([out, out_], axis=1) self.clear_cache() return out def reparameterize(self, mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) + std = mx.exp(0.5 * log_var) + eps = mx.random.normal(std.shape) return eps * std + mu + def __call__(self, x, scale): + mu, log_var = self.encode(x, scale) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z, scale) + return x_recon, mu, log_var + - def sample(self, imgs, deterministic=False): + def sample(self, imgs, scale, deterministic=False): - mu, log_var = self.encode(imgs) + mu, log_var = self.encode(imgs, scale) if deterministic: return mu - std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) - return mu + std * torch.randn_like(std) + std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0)) + return mu + std * mx.random.normal(std.shape) def clear_cache(self): self._conv_num = count_conv3d(self.decoder) @@ -628,7 +617,7 @@ class WanVAE_(nn.Module): self._enc_feat_map = [None] * self._enc_conv_num -def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): +def _video_vae(pretrained_path=None, z_dim=None, **kwargs): """ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
""" @@ -644,13 +633,13 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): cfg.update(**kwargs) # init model - with torch.device('meta'): - model = WanVAE_(**cfg) + model = WanVAE_(**cfg) # load checkpoint - logging.info(f'loading {pretrained_path}') - model.load_state_dict( - torch.load(pretrained_path, map_location=device), assign=True) + if pretrained_path: + logging.info(f'loading {pretrained_path}') + weights = mx.load(pretrained_path) + model.update(tree_unflatten(list(weights.items()))) return model @@ -660,10 +649,8 @@ class WanVAE: def __init__(self, z_dim=16, vae_pth='cache/vae_step_411000.pth', - dtype=torch.float, - device="cpu"): + dtype=mx.float32): self.dtype = dtype - self.device = device mean = [ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, @@ -673,30 +660,59 @@ class WanVAE: 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 ] - self.mean = torch.tensor(mean, dtype=dtype, device=device) - self.std = torch.tensor(std, dtype=dtype, device=device) + self.mean = mx.array(mean, dtype=dtype) + self.std = mx.array(std, dtype=dtype) self.scale = [self.mean, 1.0 / self.std] # init model self.model = _video_vae( pretrained_path=vae_pth, z_dim=z_dim, - ).eval().requires_grad_(False).to(device) + ) def encode(self, videos): """ videos: A list of videos each with shape [C, T, H, W]. + Returns: List of encoded videos in [C, T, H, W] format. """ - with amp.autocast(dtype=self.dtype): - return [ - self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) - for u in videos - ] + encoded = [] + for video in videos: + # Convert CTHW -> BTHWC + x = mx.expand_dims(video, axis=0) # Add batch dimension + x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC + + # Encode + z = self.model.encode(x, self.scale)[0] # Get mu only + + # Convert back BTHWC -> CTHW and remove batch dimension + z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW + z = z.squeeze(0) # Remove batch dimension -> CTHW + + encoded.append(z.astype(mx.float32)) + + return encoded def decode(self, zs): - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), - self.scale).float().clamp_(-1, 1).squeeze(0) - for u in zs - ] + """ + zs: A list of latent codes each with shape [C, T, H, W]. + Returns: List of decoded videos in [C, T, H, W] format. + """ + decoded = [] + for z in zs: + # Convert CTHW -> BTHWC + x = mx.expand_dims(z, axis=0) # Add batch dimension + x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC + + # Decode + x = self.model.decode(x, self.scale) + + # Convert back BTHWC -> CTHW and remove batch dimension + x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW + x = x.squeeze(0) # Remove batch dimension -> CTHW + + # Clamp values + x = mx.clip(x, -1, 1) + + decoded.append(x.astype(mx.float32)) + + return decoded \ No newline at end of file diff --git a/video/Wan2.1/wan/modules/vae_mlx.py b/video/Wan2.1/wan/modules/vae_mlx.py deleted file mode 100644 index a6a60688..00000000 --- a/video/Wan2.1/wan/modules/vae_mlx.py +++ /dev/null @@ -1,719 +0,0 @@ -# vae_mlx_final.py -import logging -from typing import Optional, List, Tuple - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from mlx.utils import tree_unflatten - -__all__ = [ - 'WanVAE', -] - -CACHE_T = 2 - -debug_line = 0 - - -class CausalConv3d(nn.Conv3d): - """ - Causal 3d convolution for MLX. - Expects input in BTHWC format (batch, time, height, width, channels). 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Padding order: (W, W, H, H, T, 0) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) - - def __call__(self, x, cache_x=None): - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis - padding[4] -= cache_x.shape[1] - - # Pad in BTHWC format - pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]), - (padding[0], padding[1]), (0, 0)] - x = mx.pad(x, pad_width) - - result = super().__call__(x) - return result - - -class RMS_norm(nn.Module): - - def __init__(self, dim, channel_first=False, images=True, bias=False): - super().__init__() - self.channel_first = channel_first - self.images = images - self.scale = dim**0.5 - - # Just keep as 1D - let broadcasting do its magic - self.gamma = mx.ones((dim,)) - self.bias = mx.zeros((dim,)) if bias else 0. - - def __call__(self, x): - # F.normalize in PyTorch does L2 normalization, not RMS! - # For NHWC/BTHWC format, normalize along the last axis - # L2 norm: sqrt(sum(x^2)) - norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6) - x = x / norm - return x * self.scale * self.gamma + self.bias - - -class Upsample(nn.Module): - """ - Upsampling layer that matches PyTorch's behavior. - """ - def __init__(self, scale_factor, mode='nearest-exact'): - super().__init__() - self.scale_factor = scale_factor - self.mode = mode # mode is now unused, but kept for signature consistency - - def __call__(self, x): - # For NHWC format (n, h, w, c) - - # NOTE: For an integer scale_factor like 2.0, PyTorch's 'nearest-exact' - # is equivalent to a simple repeat operation. The previous coordinate-based - # sampling was not correct for this model and caused the divergence. - - scale_h, scale_w = self.scale_factor - - out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension - out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension - - return out - -class AsymmetricPad(nn.Module): - """A module to apply asymmetric padding, compatible with nn.Sequential.""" - def __init__(self, pad_width: tuple): - super().__init__() - self.pad_width = pad_width - - def __call__(self, x): - return mx.pad(x, self.pad_width) - -# Update your Resample class to use 'nearest-exact' -class Resample(nn.Module): - - def __init__(self, dim, mode): - assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', - 'downsample3d') - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == 'upsample2d': - self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - elif mode == 'upsample3d': - self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - self.time_conv = CausalConv3d( - dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - - # --- CORRECTED PADDING LOGIC --- - elif mode == 'downsample2d': - # Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2) - # Use the new AsymmetricPad module. 
- # Pad width for NHWC format is ((N), (H), (W), (C)) - # Pad H with (top, bottom) and W with (left, right) - pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0))) - conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0) - self.resample = nn.Sequential(pad_layer, conv_layer) - - elif mode == 'downsample3d': - # The spatial downsampling part uses the same logic - pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0))) - conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0) - self.resample = nn.Sequential(pad_layer, conv_layer) - - self.time_conv = CausalConv3d( - dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - - else: - self.resample = nn.Identity() - - def __call__(self, x, feat_cache=None, feat_idx=[0]): - # The __call__ method logic remains unchanged from your original code - b, t, h, w, c = x.shape - - if self.mode == 'upsample3d': - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = 'Rep' - feat_idx[0] += 1 - else: - cache_x = x[:, -CACHE_T:, :, :, :] - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep': - cache_x = mx.concatenate([ - feat_cache[idx][:, -1:, :, :, :], cache_x - ], axis=1) - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep': - cache_x = mx.concatenate([ - mx.zeros_like(cache_x), cache_x - ], axis=1) - - if feat_cache[idx] == 'Rep': - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - x = x.reshape(b, t, h, w, 2, c) - x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2) - x = x.reshape(b, t * 2, h, w, c) - - t = x.shape[1] - x = x.reshape(b * t, h, w, c) - - x = self.resample(x) - - _, h_new, w_new, c_new = x.shape - x = x.reshape(b, t, h_new, w_new, c_new) - - if self.mode == 'downsample3d': - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x - feat_idx[0] += 1 - else: - cache_x = x[:, -1:, :, :, :] - x = self.time_conv( - mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - return x - - -class ResidualBlock(nn.Module): - - def __init__(self, in_dim, out_dim, dropout=0.0): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - - # layers - self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), - nn.SiLU(), - CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), - nn.SiLU(), - nn.Dropout(dropout) if dropout > 0 else nn.Identity(), - CausalConv3d(out_dim, out_dim, 3, padding=1) - ) - - self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ - if in_dim != out_dim else nn.Identity() - - def __call__(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) - - for i, layer in enumerate(self.residual.layers): - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, -CACHE_T:, :, :, :] - if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate([ - feat_cache[idx][:, -1:, :, :, :], cache_x - ], axis=1) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + h - - -class AttentionBlock(nn.Module): - """ - Causal self-attention with a single head. 
- """ - - def __init__(self, dim): - super().__init__() - self.dim = dim - - # layers - self.norm = RMS_norm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) - - # zero out the last layer params - self.proj.weight = mx.zeros_like(self.proj.weight) - - def __call__(self, x): - # x is in BTHWC format - identity = x - b, t, h, w, c = x.shape - x = x.reshape(b * t, h, w, c) # Combine batch and time - x = self.norm(x) - # compute query, key, value - qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c) - qkv = qkv.reshape(b * t, h * w, 3 * c) - q, k, v = mx.split(qkv, 3, axis=-1) - - # Reshape for attention - q = q.reshape(b * t, h * w, c) - k = k.reshape(b * t, h * w, c) - v = v.reshape(b * t, h * w, c) - - # Scaled dot product attention - scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype)) - scores = (q @ k.transpose(0, 2, 1)) * scale - weights = mx.softmax(scores, axis=-1) - x = weights @ v - x = x.reshape(b * t, h, w, c) - - # output - x = self.proj(x) - x = x.reshape(b, t, h, w, c) - return x + identity - - -class Encoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - - # dimensions - dims = [dim * u for u in [1] + dim_mult] - scale = 1.0 - - # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) - - # downsample blocks - downsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - for _ in range(num_res_blocks): - downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - downsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # downsample block - if i != len(dim_mult) - 1: - mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d' - downsamples.append(Resample(out_dim, mode=mode)) - scale /= 2.0 - - self.downsamples = nn.Sequential(*downsamples) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(dims[-1], dims[-1], dropout), - AttentionBlock(dims[-1]), - ResidualBlock(dims[-1], dims[-1], dropout) - ) - - # output blocks - self.head = nn.Sequential( - RMS_norm(dims[-1], images=False), - nn.SiLU(), - CausalConv3d(dims[-1], z_dim, 3, padding=1) - ) - - def __call__(self, x, feat_cache=None, feat_idx=[0]): - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, -CACHE_T:, :, :, :] - if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate([ - feat_cache[idx][:, -1:, :, :, :], cache_x - ], axis=1) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - ## downsamples - for i, layer in enumerate(self.downsamples.layers): - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## middle - for layer in self.middle.layers: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## head - for i, layer in enumerate(self.head.layers): - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, -CACHE_T:, :, :, :] - if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate([ - feat_cache[idx][:, -1:, :, :, :], cache_x - ], axis=1) 
- x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -class Decoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_upsample = temperal_upsample - - # dimensions - dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2**(len(dim_mult) - 2) - - # init block - self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), - AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout) - ) - - # upsample blocks - upsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - if i == 1 or i == 2 or i == 3: - in_dim = in_dim // 2 - for _ in range(num_res_blocks + 1): - upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - upsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # upsample block - if i != len(dim_mult) - 1: - mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' - upsamples.append(Resample(out_dim, mode=mode)) - scale *= 2.0 - - self.upsamples = nn.Sequential(*upsamples) - - # output blocks - self.head = nn.Sequential( - RMS_norm(dims[-1], images=False), - nn.SiLU(), - CausalConv3d(dims[-1], 3, 3, padding=1) - ) - - def __call__(self, x, feat_cache=None, feat_idx=[0]): - ## conv1 - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, -CACHE_T:, :, :, :] - if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate([ - feat_cache[idx][:, -1:, :, :, :], cache_x - ], axis=1) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - ## middle - for layer in self.middle.layers: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples.layers: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## head - for i, layer in enumerate(self.head.layers): - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, -CACHE_T:, :, :, :] - if cache_x.shape[1] < 2 and feat_cache[idx] is not None: - cache_x = mx.concatenate([ - feat_cache[idx][:, -1:, :, :, :], cache_x - ], axis=1) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -def count_conv3d(model): - count = 0 - for name, module in model.named_modules(): - if isinstance(module, CausalConv3d): - count += 1 - return count - - -class WanVAE_(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - self.temperal_upsample = temperal_downsample[::-1] - - # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, - attn_scales, self.temperal_downsample, dropout) - self.conv1 
= CausalConv3d(z_dim * 2, z_dim * 2, 1) - self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, - attn_scales, self.temperal_upsample, dropout) - - def encode(self, x, scale): - # x is in BTHWC format - self.clear_cache() - ## cache - t = x.shape[1] - iter_ = 1 + (t - 1) // 4 - ## Split encode input x by time into 1, 4, 4, 4.... - for i in range(iter_): - self._enc_conv_idx = [0] - if i == 0: - out = self.encoder( - x[:, :1, :, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) - else: - out_ = self.encoder( - x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) - out = mx.concatenate([out, out_], axis=1) - - z = self.conv1(out) - mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension - - if isinstance(scale[0], mx.array): - # Reshape scale for broadcasting in BTHWC format - scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim) - scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim) - mu = (mu - scale_mean) * scale_std - else: - mu = (mu - scale[0]) * scale[1] - self.clear_cache() - - return mu, log_var - - def decode(self, z, scale): - # z is in BTHWC format - self.clear_cache() - if isinstance(scale[0], mx.array): - scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim) - scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim) - z = z / scale_std + scale_mean - else: - z = z / scale[1] + scale[0] - iter_ = z.shape[1] - x = self.conv2(z) - for i in range(iter_): - self._conv_idx = [0] - if i == 0: - out = self.decoder( - x[:, i:i + 1, :, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) - else: - out_ = self.decoder( - x[:, i:i + 1, :, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) - out = mx.concatenate([out, out_], axis=1) - self.clear_cache() - return out - - def reparameterize(self, mu, log_var): - std = mx.exp(0.5 * log_var) - eps = mx.random.normal(std.shape) - return eps * std + mu - - def __call__(self, x): - mu, log_var = self.encode(x, self.scale) - z = self.reparameterize(mu, log_var) - x_recon = self.decode(z, self.scale) - return x_recon, mu, log_var - - def sample(self, imgs, deterministic=False): - mu, log_var = self.encode(imgs, self.scale) - if deterministic: - return mu - std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0)) - return mu + std * mx.random.normal(std.shape) - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - #cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num - - -def _video_vae(pretrained_path=None, z_dim=None, **kwargs): - """ - Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. 
- """ - # params - cfg = dict( - dim=96, - z_dim=z_dim, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[False, True, True], - dropout=0.0) - cfg.update(**kwargs) - - # init model - model = WanVAE_(**cfg) - - # load checkpoint - if pretrained_path: - logging.info(f'loading {pretrained_path}') - weights = mx.load(pretrained_path) - model.update(tree_unflatten(list(weights.items()))) - - return model - - -class WanVAE: - - def __init__(self, - z_dim=16, - vae_pth='cache/vae_step_411000.pth', - dtype=mx.float32): - self.dtype = dtype - - mean = [ - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 - ] - std = [ - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 - ] - self.mean = mx.array(mean, dtype=dtype) - self.std = mx.array(std, dtype=dtype) - self.scale = [self.mean, 1.0 / self.std] - - # init model - self.model = _video_vae( - pretrained_path=vae_pth, - z_dim=z_dim, - ) - - def encode(self, videos): - """ - videos: A list of videos each with shape [C, T, H, W]. - Returns: List of encoded videos in [C, T, H, W] format. - """ - encoded = [] - for video in videos: - # Convert CTHW -> BTHWC - x = mx.expand_dims(video, axis=0) # Add batch dimension - x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC - - # Encode - z = self.model.encode(x, self.scale)[0] # Get mu only - - # Convert back BTHWC -> CTHW and remove batch dimension - z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW - z = z.squeeze(0) # Remove batch dimension -> CTHW - - encoded.append(z.astype(mx.float32)) - - return encoded - - def decode(self, zs): - """ - zs: A list of latent codes each with shape [C, T, H, W]. - Returns: List of decoded videos in [C, T, H, W] format. 
- """ - decoded = [] - for z in zs: - # Convert CTHW -> BTHWC - x = mx.expand_dims(z, axis=0) # Add batch dimension - x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC - - # Decode - x = self.model.decode(x, self.scale) - - # Convert back BTHWC -> CTHW and remove batch dimension - x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW - x = x.squeeze(0) # Remove batch dimension -> CTHW - - # Clamp values - x = mx.clip(x, -1, 1) - - decoded.append(x.astype(mx.float32)) - - return decoded \ No newline at end of file diff --git a/video/Wan2.1/wan/t5_torch_to_sf.py b/video/Wan2.1/wan/t5_torch_to_sf.py index 87bf4175..8e0615ef 100644 --- a/video/Wan2.1/wan/t5_torch_to_sf.py +++ b/video/Wan2.1/wan/t5_torch_to_sf.py @@ -4,7 +4,7 @@ from safetensors.torch import save_file from pathlib import Path import json -from wan.modules.t5_mlx import T5Model +from wan.modules.t5 import T5Model def convert_pickle_to_safetensors( diff --git a/video/Wan2.1/wan/text2video_mlx.py b/video/Wan2.1/wan/text2video.py similarity index 97% rename from video/Wan2.1/wan/text2video_mlx.py rename to video/Wan2.1/wan/text2video.py index 110c19cc..a37d97e6 100644 --- a/video/Wan2.1/wan/text2video_mlx.py +++ b/video/Wan2.1/wan/text2video.py @@ -11,11 +11,11 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .modules.model_mlx import WanModel -from .modules.t5_mlx import T5EncoderModel -from .modules.vae_mlx import WanVAE -from .utils.fm_solvers_mlx import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps -from .utils.fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .wan_model_io import load_wan_from_safetensors diff --git a/video/Wan2.1/wan/utils/__init__.py b/video/Wan2.1/wan/utils/__init__.py index d66111f1..6e9a339e 100644 --- a/video/Wan2.1/wan/utils/__init__.py +++ b/video/Wan2.1/wan/utils/__init__.py @@ -1,6 +1,6 @@ -from .fm_solvers_mlx import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, +from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) -from .fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler +from .fm_solvers_unipc import FlowUniPCMultistepScheduler __all__ = [ 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', diff --git a/video/Wan2.1/wan/utils/fm_solvers_mlx.py b/video/Wan2.1/wan/utils/fm_solvers.py similarity index 100% rename from video/Wan2.1/wan/utils/fm_solvers_mlx.py rename to video/Wan2.1/wan/utils/fm_solvers.py diff --git a/video/Wan2.1/wan/utils/fm_solvers_unipc_mlx.py b/video/Wan2.1/wan/utils/fm_solvers_unipc.py similarity index 100% rename from video/Wan2.1/wan/utils/fm_solvers_unipc_mlx.py rename to video/Wan2.1/wan/utils/fm_solvers_unipc.py diff --git a/video/Wan2.1/wan/utils/utils_mlx_own.py b/video/Wan2.1/wan/utils/utils.py similarity index 100% rename from video/Wan2.1/wan/utils/utils_mlx_own.py rename to video/Wan2.1/wan/utils/utils.py