N 2025-07-28 17:07:26 -07:00
parent a36538bf9f
commit a19ccd9004
24 changed files with 562 additions and 1897 deletions

View File

@ -66,7 +66,7 @@ This repository currently supports two Text-to-Video models (1.3B and 14B) and t
##### (1) Example:
```
python generate_mlx.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4
python generate.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4
```

Binary files not shown: 7 image files (before sizes 1.7 MiB, 516 KiB, 871 KiB, 294 KiB, 1.5 MiB, 628 KiB, 208 KiB).

View File

@ -10,7 +10,7 @@ from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils_mlx_own import cache_video, cache_image, str2bool
from wan.utils.utils import cache_video, cache_image, str2bool
EXAMPLE_PROMPT = {
"t2v-1.3B": {

View File

@ -1,6 +0,0 @@
Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a single folder and specify the maximum number of GPUs you want to use.
```bash
bash ./test.sh <local model dir> <gpu number>
```
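For example, assuming the checkpoints live under `/data/wan_models` and 4 GPUs are available (both values are placeholders):
```bash
bash ./test.sh /data/wan_models 4
```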

View File

@ -1,113 +0,0 @@
#!/bin/bash
if [ "$#" -eq 2 ]; then
MODEL_DIR=$(realpath "$1")
GPUS=$2
else
echo "Usage: $0 <local model dir> <gpu number>"
exit 1
fi
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1
PY_FILE=./generate.py
function t2v_1_3B() {
T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
if [ -n "${DASH_API_KEY+x}" ]; then
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
else
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
fi
}
function t2v_14B() {
T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function t2i_14B() {
T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function i2v_14B_480p() {
I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
if [ -n "${DASH_API_KEY+x}" ]; then
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
else
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
fi
}
function i2v_14B_720p() {
I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"
# 1-GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR
# Multiple GPU Test
echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p

View File

@ -1,2 +1,2 @@
from . import configs, modules
from .text2video_mlx import WanT2V
from .text2video import WanT2V

View File

@ -1,7 +1,7 @@
from .model_mlx import WanModel
from .t5_mlx import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae_mlx import WanVAE
from .vae import WanVAE
__all__ = [
'WanVAE',

View File

@ -1,11 +1,13 @@
# Modified from transformers.models.t5.modeling_t5
# Modified from transformers.models.t5.modeling_t5 for MLX
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from .tokenizers import HuggingfaceTokenizer
@ -18,59 +20,41 @@ __all__ = [
def fp16_clamp(x):
if x.dtype == torch.float16 and torch.isinf(x).any():
clamp = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp, max=clamp)
if x.dtype == mx.float16:
# Use same clamping as PyTorch for consistency
clamp = 65504.0 # max value for float16
return mx.clip(x, -clamp, clamp)
return x
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
elif isinstance(m, T5Model):
nn.init.normal_(m.token_embedding.weight, std=1.0)
elif isinstance(m, T5FeedForward):
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
elif isinstance(m, T5Attention):
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
def __call__(self, x):
return 0.5 * x * (1.0 + mx.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.weight = mx.ones((dim,))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
self.eps)
if self.weight.dtype in [torch.float16, torch.float32]:
x = x.type_as(self.weight)
return self.weight * x
def __call__(self, x):
# Match PyTorch's approach: convert to float32 for stability
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
x_norm = x_float * mx.rsqrt(variance + self.eps)
# Convert back to original dtype
if x.dtype == mx.float16:
x_norm = x_norm.astype(mx.float16)
return self.weight * x_norm
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
@ -83,7 +67,7 @@ class T5Attention(nn.Module):
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
def __call__(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
@ -91,50 +75,72 @@ class T5Attention(nn.Module):
"""
# check inputs
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.head_dim
b, l1, _ = x.shape
_, l2, _ = context.shape
n, c = self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).view(b, -1, n, c)
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)
q = self.q(x).reshape(b, l1, n, c)
k = self.k(context).reshape(b, l2, n, c)
v = self.v(context).reshape(b, l2, n, c)
# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
attn_bias += pos_bias
if mask is not None:
assert mask.ndim in [2, 3]
mask = mask.view(b, 1, 1,
-1) if mask.ndim == 2 else mask.unsqueeze(1)
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
# transpose for attention: [B, N, L, C]
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# compute attention (T5 does not use scaling)
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum('bnij,bjnc->binc', attn, v)
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
# output
x = x.reshape(b, -1, n * c)
# add position bias if provided
if pos_bias is not None:
attn = attn + pos_bias
# apply mask
if mask is not None:
if mask.ndim == 2:
# [B, L2] -> [B, 1, 1, L2]
mask = mask[:, None, None, :]
elif mask.ndim == 3:
# [B, L1, L2] -> [B, 1, L1, L2]
mask = mask[:, None, :, :]
# Use very negative value that works well with float16
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
attn = mx.where(mask == 0, min_value, attn)
# softmax and apply attention
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
attn = self.dropout(attn)
# apply attention to values
x = mx.matmul(attn, v) # [B, N, L1, C]
# transpose back and reshape
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
x = x.reshape(b, l1, -1)
# output projection
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
def __init__(self, dim, dim_ffn, dropout=0.0):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = GELU()
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
def __call__(self, x):
gate = self.gate_act(self.gate_proj(x))
x = self.fc1(x) * gate
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
@ -142,7 +148,6 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
@ -150,8 +155,8 @@ class T5SelfAttention(nn.Module):
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5SelfAttention, self).__init__()
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
@ -167,16 +172,15 @@ class T5SelfAttention(nn.Module):
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
def __call__(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.size(1), x.size(1))
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
@ -184,8 +188,8 @@ class T5CrossAttention(nn.Module):
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5CrossAttention, self).__init__()
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
@ -203,14 +207,14 @@ class T5CrossAttention(nn.Module):
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def forward(self,
def __call__(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.size(1), x.size(1))
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
@ -219,9 +223,8 @@ class T5CrossAttention(nn.Module):
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
@ -230,42 +233,55 @@ class T5RelativeEmbedding(nn.Module):
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def forward(self, lq, lk):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
torch.arange(lq, device=device).unsqueeze(1)
def __call__(self, lq, lk):
# Create relative position matrix
positions_q = mx.arange(lq)[:, None]
positions_k = mx.arange(lk)[None, :]
rel_pos = positions_k - positions_q
# Apply bucketing
rel_pos = self._relative_position_bucket(rel_pos)
# Get embeddings
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
0) # [1, N, Lq, Lk]
return rel_pos_embeds.contiguous()
# Reshape to [1, N, Lq, Lk]
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
return rel_pos_embeds
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = 0
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
is_small = rel_pos < max_exact
# For large positions, use log scale
rel_pos_large = max_exact + (
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)).long()
rel_pos_large = torch.min(
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
(num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
# Combine small and large position buckets
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
@ -275,8 +291,8 @@ class T5Encoder(nn.Module):
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Encoder, self).__init__()
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
@ -286,25 +302,25 @@ class T5Encoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
else nn.Embedding(vocab, dim)
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
self.blocks = [
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
])
]
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None):
def __call__(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1),
x.size(1)) if self.shared_pos else None
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
@ -313,7 +329,6 @@ class T5Encoder(nn.Module):
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
@ -323,8 +338,8 @@ class T5Decoder(nn.Module):
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Decoder, self).__init__()
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
@ -334,34 +349,35 @@ class T5Decoder(nn.Module):
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
else nn.Embedding(vocab, dim)
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
self.blocks = [
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
])
]
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.size()
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.shape
# causal mask
if mask is None:
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
mask = mx.tril(mx.ones((1, s, s)))
elif mask.ndim == 2:
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
# Expand mask properly
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1),
x.size(1)) if self.shared_pos else None
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
@ -370,7 +386,6 @@ class T5Decoder(nn.Module):
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
@ -381,8 +396,8 @@ class T5Model(nn.Module):
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Model, self).__init__()
dropout=0.0):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
@ -402,23 +417,63 @@ class T5Model(nn.Module):
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
# initialize weights
self.apply(init_weights)
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def init_mlx_weights(module, key):
"""Initialize weights for T5 model components to match PyTorch initialization"""
def normal(key, shape, std=1.0):
return mx.random.normal(shape, key=key) * std
if isinstance(module, T5LayerNorm):
module.weight = mx.ones_like(module.weight)
elif isinstance(module, nn.Embedding):
key = mx.random.split(key, 1)[0]
module.weight = normal(key, module.weight.shape, std=1.0)
elif isinstance(module, T5FeedForward):
# Match PyTorch initialization
key1, key2, key3 = mx.random.split(key, 3)
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
std=module.dim**-0.5)
module.fc1.weight = normal(key2, module.fc1.weight.shape,
std=module.dim**-0.5)
module.fc2.weight = normal(key3, module.fc2.weight.shape,
std=module.dim_ffn**-0.5)
elif isinstance(module, T5Attention):
# Match PyTorch initialization
key1, key2, key3, key4 = mx.random.split(key, 4)
module.q.weight = normal(key1, module.q.weight.shape,
std=(module.dim * module.dim_attn)**-0.5)
module.k.weight = normal(key2, module.k.weight.shape,
std=module.dim**-0.5)
module.v.weight = normal(key3, module.v.weight.shape,
std=module.dim**-0.5)
module.o.weight = normal(key4, module.o.weight.shape,
std=(module.num_heads * module.dim_attn)**-0.5)
elif isinstance(module, T5RelativeEmbedding):
key = mx.random.split(key, 1)[0]
module.embedding.weight = normal(key, module.embedding.weight.shape,
std=(2 * module.num_buckets * module.num_heads)**-0.5)
elif isinstance(module, nn.Linear):
# Generic linear layer initialization
key = mx.random.split(key, 1)[0]
fan_in = module.weight.shape[1]
bound = 1.0 / math.sqrt(fan_in)
module.weight = mx.random.uniform(low=-bound, high=bound, shape=module.weight.shape, key=key)
return module
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
dtype=torch.float32,
device='cpu',
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
@ -438,11 +493,11 @@ def _t5(name,
model_cls = T5Model
# init model
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
# Initialize weights properly
key = mx.random.key(0)
model = init_mlx_weights(model, key)
# init tokenizer
if return_tokenizer:
@ -464,50 +519,99 @@ def umt5_xxl(**kwargs):
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.1)
dropout=0.0)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
dtype=torch.float32,
device='mps' if torch.backends.mps.is_available() else 'cpu',
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
):
self.text_len = text_len
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
else:
self.model.to(self.device)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path, seq_len=text_len, clean='whitespace')
return_tokenizer=False)
def __call__(self, texts, device):
ids, mask = self.tokenizer(
if checkpoint_path:
logging.info(f'loading {checkpoint_path}')
# Load weights - assuming MLX format checkpoint
weights = mx.load(checkpoint_path)
model.update(tree_unflatten(list(weights.items())))
self.model = model
# init tokenizer
from .tokenizers import HuggingfaceTokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
seq_len=text_len,
clean='whitespace')
def __call__(self, texts):
# Handle single string input
if isinstance(texts, str):
texts = [texts]
# Tokenize texts
tokenizer_output = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
ids = ids.to(device)
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
# Handle different tokenizer output formats
if isinstance(tokenizer_output, tuple):
ids, mask = tokenizer_output
else:
# Assuming dict output with 'input_ids' and 'attention_mask'
ids = tokenizer_output['input_ids']
mask = tokenizer_output['attention_mask']
# Convert to MLX arrays if not already
if not isinstance(ids, mx.array):
ids = mx.array(ids)
if not isinstance(mask, mx.array):
mask = mx.array(mask)
# Get sequence lengths
seq_lens = mx.sum(mask > 0, axis=1)
# Run encoder
context = self.model(ids, mask)
return [u[:v] for u, v in zip(context, seq_lens)]
# Return variable length outputs
# Convert seq_lens to Python list for indexing
if seq_lens.ndim == 0: # Single value
seq_lens_list = [seq_lens.item()]
else:
seq_lens_list = seq_lens.tolist()
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
# Utility function to convert PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
"""Convert PyTorch checkpoint to MLX format"""
import torch
# Load PyTorch checkpoint
pytorch_state = torch.load(pytorch_path, map_location='cpu')
# Convert to numpy then to MLX
mlx_state = {}
for key, value in pytorch_state.items():
if isinstance(value, torch.Tensor):
# Handle the key mapping if needed
mlx_key = key
# Convert tensor to MLX array
mlx_state[mlx_key] = mx.array(value.numpy())
# Save MLX checkpoint (safetensors, so mx.load returns a dict of arrays)
mx.save_safetensors(mlx_path, mlx_state)
return mlx_state
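A minimal usage sketch of the two pieces above, assuming hypothetical file names and that the converted checkpoint's keys already line up with the MLX module names (the key mapping in `convert_pytorch_checkpoint` is currently a pass-through):
```python
# One-time conversion from a PyTorch checkpoint (file names are placeholders).
convert_pytorch_checkpoint(
    'models_t5_umt5-xxl-enc-bf16.pth',   # source PyTorch weights
    'umt5-xxl-enc-mlx.safetensors')      # converted MLX weights

# Encode a prompt; __call__ returns one variable-length context array per input text.
encoder = T5EncoderModel(
    text_len=512,                        # illustrative maximum token length
    checkpoint_path='umt5-xxl-enc-mlx.safetensors',
    tokenizer_path='google/umt5-xxl')
context = encoder(["Lion running under snow in Samarkand"])
print(context[0].shape)                  # (sequence_length, 4096) for the umt5-xxl config
```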

View File

@ -1,617 +0,0 @@
# Modified from transformers.models.t5.modeling_t5 for MLX
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from .tokenizers import HuggingfaceTokenizer
__all__ = [
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
]
def fp16_clamp(x):
if x.dtype == mx.float16:
# Use same clamping as PyTorch for consistency
clamp = 65504.0 # max value for float16
return mx.clip(x, -clamp, clamp)
return x
class GELU(nn.Module):
def __call__(self, x):
return 0.5 * x * (1.0 + mx.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
# Match PyTorch's approach: convert to float32 for stability
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
x_norm = x_float * mx.rsqrt(variance + self.eps)
# Convert back to original dtype
if x.dtype == mx.float16:
x_norm = x_norm.astype(mx.float16)
return self.weight * x_norm
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
assert dim_attn % num_heads == 0
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, l1, _ = x.shape
_, l2, _ = context.shape
n, c = self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, l1, n, c)
k = self.k(context).reshape(b, l2, n, c)
v = self.v(context).reshape(b, l2, n, c)
# transpose for attention: [B, N, L, C]
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# compute attention (T5 does not use scaling)
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
# add position bias if provided
if pos_bias is not None:
attn = attn + pos_bias
# apply mask
if mask is not None:
if mask.ndim == 2:
# [B, L2] -> [B, 1, 1, L2]
mask = mask[:, None, None, :]
elif mask.ndim == 3:
# [B, L1, L2] -> [B, 1, L1, L2]
mask = mask[:, None, :, :]
# Use very negative value that works well with float16
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
attn = mx.where(mask == 0, min_value, attn)
# softmax and apply attention
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
attn = self.dropout(attn)
# apply attention to values
x = mx.matmul(attn, v) # [B, N, L1, C]
# transpose back and reshape
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
x = x.reshape(b, l1, -1)
# output projection
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.0):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = GELU()
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x):
gate = self.gate_act(self.gate_proj(x))
x = self.fc1(x) * gate
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def __call__(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def __call__(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def __call__(self, lq, lk):
# Create relative position matrix
positions_q = mx.arange(lq)[:, None]
positions_k = mx.arange(lk)[None, :]
rel_pos = positions_k - positions_q
# Apply bucketing
rel_pos = self._relative_position_bucket(rel_pos)
# Get embeddings
rel_pos_embeds = self.embedding(rel_pos)
# Reshape to [1, N, Lq, Lk]
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
return rel_pos_embeds
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
# For large positions, use log scale
rel_pos_large = max_exact + (
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
# Combine small and large position buckets
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.shape
# causal mask
if mask is None:
mask = mx.tril(mx.ones((1, s, s)))
elif mask.ndim == 2:
# Expand mask properly
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, encoder_layers, num_buckets,
shared_pos, dropout)
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, decoder_layers, num_buckets,
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def init_mlx_weights(module, key):
"""Initialize weights for T5 model components to match PyTorch initialization"""
def normal(key, shape, std=1.0):
return mx.random.normal(shape, key=key) * std
if isinstance(module, T5LayerNorm):
module.weight = mx.ones_like(module.weight)
elif isinstance(module, nn.Embedding):
key = mx.random.split(key, 1)[0]
module.weight = normal(key, module.weight.shape, std=1.0)
elif isinstance(module, T5FeedForward):
# Match PyTorch initialization
key1, key2, key3 = mx.random.split(key, 3)
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
std=module.dim**-0.5)
module.fc1.weight = normal(key2, module.fc1.weight.shape,
std=module.dim**-0.5)
module.fc2.weight = normal(key3, module.fc2.weight.shape,
std=module.dim_ffn**-0.5)
elif isinstance(module, T5Attention):
# Match PyTorch initialization
key1, key2, key3, key4 = mx.random.split(key, 4)
module.q.weight = normal(key1, module.q.weight.shape,
std=(module.dim * module.dim_attn)**-0.5)
module.k.weight = normal(key2, module.k.weight.shape,
std=module.dim**-0.5)
module.v.weight = normal(key3, module.v.weight.shape,
std=module.dim**-0.5)
module.o.weight = normal(key4, module.o.weight.shape,
std=(module.num_heads * module.dim_attn)**-0.5)
elif isinstance(module, T5RelativeEmbedding):
key = mx.random.split(key, 1)[0]
module.embedding.weight = normal(key, module.embedding.weight.shape,
std=(2 * module.num_buckets * module.num_heads)**-0.5)
elif isinstance(module, nn.Linear):
# Generic linear layer initialization
key = mx.random.split(key, 1)[0]
fan_in = module.weight.shape[1]
bound = 1.0 / math.sqrt(fan_in)
module.weight = mx.random.uniform(low=-bound, high=bound, shape=module.weight.shape, key=key)
return module
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('encoder_layers')
_ = kwargs.pop('decoder_layers')
elif decoder_only:
model_cls = T5Decoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('decoder_layers')
_ = kwargs.pop('encoder_layers')
else:
model_cls = T5Model
# init model
model = model_cls(**kwargs)
# Initialize weights properly
key = mx.random.key(0)
model = init_mlx_weights(model, key)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.0)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
checkpoint_path=None,
tokenizer_path=None,
):
self.text_len = text_len
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False)
if checkpoint_path:
logging.info(f'loading {checkpoint_path}')
# Load weights - assuming MLX format checkpoint
weights = mx.load(checkpoint_path)
model.update(tree_unflatten(list(weights.items())))
self.model = model
# init tokenizer
from .tokenizers import HuggingfaceTokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
seq_len=text_len,
clean='whitespace')
def __call__(self, texts):
# Handle single string input
if isinstance(texts, str):
texts = [texts]
# Tokenize texts
tokenizer_output = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
# Handle different tokenizer output formats
if isinstance(tokenizer_output, tuple):
ids, mask = tokenizer_output
else:
# Assuming dict output with 'input_ids' and 'attention_mask'
ids = tokenizer_output['input_ids']
mask = tokenizer_output['attention_mask']
# Convert to MLX arrays if not already
if not isinstance(ids, mx.array):
ids = mx.array(ids)
if not isinstance(mask, mx.array):
mask = mx.array(mask)
# Get sequence lengths
seq_lens = mx.sum(mask > 0, axis=1)
# Run encoder
context = self.model(ids, mask)
# Return variable length outputs
# Convert seq_lens to Python list for indexing
if seq_lens.ndim == 0: # Single value
seq_lens_list = [seq_lens.item()]
else:
seq_lens_list = seq_lens.tolist()
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
# Utility function to convert PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
"""Convert PyTorch checkpoint to MLX format"""
import torch
# Load PyTorch checkpoint
pytorch_state = torch.load(pytorch_path, map_location='cpu')
# Convert to numpy then to MLX
mlx_state = {}
for key, value in pytorch_state.items():
if isinstance(value, torch.Tensor):
# Handle the key mapping if needed
mlx_key = key
# Convert tensor to MLX array
mlx_state[mlx_key] = mx.array(value.numpy())
# Save MLX checkpoint (safetensors, so mx.load returns a dict of arrays)
mx.save_safetensors(mlx_path, mlx_state)
return mlx_state

View File

@ -1,13 +1,11 @@
# Original PyTorch implementation of Wan VAE
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
from typing import Optional, List, Tuple
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
__all__ = [
'WanVAE',
@ -17,66 +15,89 @@ CACHE_T = 2
debug_line = 0
def debug(name, x):
global debug_line
print(f"LINE {debug_line}: {name}: shape = {tuple(x.shape)}, mean = {x.mean().item():.4f}, std = {x.std().item():.4f}")
debug_line += 1
return x
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
Causal 3d convolution for MLX.
Expects input in BTHWC format (batch, time, height, width, channels).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Padding order: (W, W, H, H, T, 0)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
def __call__(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
padding[4] -= cache_x.shape[1]
result = super().forward(x)
debug("TORCH x after conv3d", result)
# Pad in BTHWC format
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
(padding[0], padding[1]), (0, 0)]
x = mx.pad(x, pad_width)
result = super().__call__(x)
return result
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
def __init__(self, dim, channel_first=False, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.images = images
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
# Just keep as 1D - let broadcasting do its magic
self.gamma = mx.ones((dim,))
self.bias = mx.zeros((dim,)) if bias else 0.
def __call__(self, x):
# F.normalize in PyTorch does L2 normalization, not RMS!
# For NHWC/BTHWC format, normalize along the last axis
# L2 norm: sqrt(sum(x^2))
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
x = x / norm
return x * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
class Upsample(nn.Module):
"""
Fix bfloat16 support for nearest neighbor interpolation.
Upsampling layer that matches PyTorch's behavior.
"""
result = super().forward(x.float()).type_as(x)
debug("TORCH x after upsample", result)
return result
def __init__(self, scale_factor, mode='nearest-exact'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode # mode is now unused, but kept for signature consistency
def __call__(self, x):
# For NHWC format (n, h, w, c)
# NOTE: For an integer scale_factor like 2.0, PyTorch's 'nearest-exact'
# is equivalent to a simple repeat operation. The previous coordinate-based
# sampling was not correct for this model and caused the divergence.
scale_h, scale_w = self.scale_factor
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
return out
class AsymmetricPad(nn.Module):
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
def __init__(self, pad_width: tuple):
super().__init__()
self.pad_width = pad_width
def __call__(self, x):
return mx.pad(x, self.pad_width)
# Update your Resample class to use 'nearest-exact'
class Resample(nn.Module):
def __init__(self, dim, mode):
@ -90,30 +111,42 @@ class Resample(nn.Module):
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
# --- CORRECTED PADDING LOGIC ---
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
# Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2)
# Use the new AsymmetricPad module.
# Pad width for NHWC format is ((N), (H), (W), (C))
# Pad H with (top, bottom) and W with (left, right)
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
# The spatial downsampling part uses the same logic
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# The __call__ method logic remains unchanged from your original code
b, t, h, w, c = x.shape
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
@ -121,84 +154,50 @@ class Resample(nn.Module):
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
cache_x = mx.concatenate([
mx.zeros_like(cache_x), cache_x
], axis=1)
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
debug("TORCH x after time_conv", x)
else:
x = self.time_conv(x, feat_cache[idx])
debug("TORCH x after time_conv with cache", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = x.reshape(b, t, h, w, 2, c)
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
x = x.reshape(b, t * 2, h, w, c)
t = x.shape[1]
x = x.reshape(b * t, h, w, c)
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
debug("TORCH x after resample", x)
_, h_new, w_new, c_new = x.shape
x = x.reshape(b, t, h_new, w_new, c_new)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_cache[idx] = x
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
cache_x = x[:, -1:, :, :, :]
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
@ -209,34 +208,34 @@ class ResidualBlock(nn.Module):
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
CausalConv3d(out_dim, out_dim, 3, padding=1)
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
def __call__(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
debug("TORCH x after shortcut", h)
for layer in self.residual:
for i, layer in enumerate(self.residual.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
debug("TORCH x after residual block", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
debug("TORCH x after residual block", x)
return x + h
@ -255,32 +254,34 @@ class AttentionBlock(nn.Module):
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
self.proj.weight = mx.zeros_like(self.proj.weight)
def forward(self, x):
def __call__(self, x):
# x is in BTHWC format
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
b, t, h, w, c = x.shape
x = x.reshape(b * t, h, w, c) # Combine batch and time
x = self.norm(x)
debug("TORCH x after norm", x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
-1).permute(0, 1, 3,
2).contiguous().chunk(
3, dim=-1)
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
qkv = qkv.reshape(b * t, h * w, 3 * c)
q, k, v = mx.split(qkv, 3, axis=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# Reshape for attention
q = q.reshape(b * t, h * w, c)
k = k.reshape(b * t, h * w, c)
v = v.reshape(b * t, h * w, c)
# Scaled dot product attention
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
scores = (q @ k.transpose(0, 2, 1)) * scale
weights = mx.softmax(scores, axis=-1)
x = weights @ v
x = x.reshape(b * t, h, w, c)
# output
x = self.proj(x)
debug("TORCH x after proj", x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
x = x.reshape(b, t, h, w, c)
return x + identity
@ -321,78 +322,68 @@ class Encoder3d(nn.Module):
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
ResidualBlock(dims[-1], dims[-1], dropout),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1], dropout)
)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], z_dim, 3, padding=1)
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
def __call__(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
debug("TORCH x after conv1 with cache", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
debug("TORCH x after conv1", x)
## downsamples
for i, layer in enumerate(self.downsamples):
for i, layer in enumerate(self.downsamples.layers):
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after downsample layer", x)
else:
x = layer(x)
debug("TORCH x after downsample layer", x)
## middle
for layer in self.middle:
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after downsample layer", x)
else:
x = layer(x)
debug("TORCH x after downsample layer", x)
## head
for layer in self.head:
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
debug("TORCH x after downsample layer", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
debug("TORCH x after downsample layer", x)
return x
@@ -423,8 +414,10 @@ class Decoder3d(nn.Module):
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
@@ -443,77 +436,66 @@ class Decoder3d(nn.Module):
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], 3, 3, padding=1)
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
def __call__(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
debug("TORCH x after conv1 with cache", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
debug("TORCH x after conv1", x)
## middle
for layer in self.middle:
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after middle layer", x)
else:
x = layer(x)
debug("TORCH x after middle layer", x)
## upsamples
for layer in self.upsamples:
for layer in self.upsamples.layers:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
debug("TORCH x after upsample layer", x)
else:
x = layer(x)
debug("TORCH x after upsample layer", x)
## head
for layer in self.head:
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
debug("TORCH x after head layer", x)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
debug("TORCH x after head layer", x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
for name, module in model.named_modules():
if isinstance(module, CausalConv3d):
count += 1
return count
@@ -545,78 +527,85 @@ class WanVAE_(nn.Module):
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
# x is in BTHWC format
self.clear_cache()
## cache
t = x.shape[2]
t = x.shape[1]
iter_ = 1 + (t - 1) // 4
## Split the encode input x by time into 1, 4, 4, 4....
## Split encode input x by time into 1, 4, 4, 4....
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
x[:, :1, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
out = mx.concatenate([out, out_], axis=1)
z = self.conv1(out)
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
if isinstance(scale[0], mx.array):
# Reshape scale for broadcasting in BTHWC format
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
mu = (mu - scale_mean) * scale_std
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
return mu, log_var
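As a quick illustration of the `1 + (t - 1) // 4` schedule used by `encode` above (plain Python, frame count chosen arbitrarily): the first chunk holds one frame, every later chunk holds four.

```python
# Temporal chunking used by encode(): 1 frame, then groups of 4 (t = 17 here).
t = 17
iter_ = 1 + (t - 1) // 4
chunks = [(0, 1)] + [(1 + 4 * (i - 1), 1 + 4 * i) for i in range(1, iter_)]
print(iter_, chunks)   # 5 [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]
```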
def decode(self, z, scale):
# z is in BTHWC format
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
if isinstance(scale[0], mx.array):
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
z = z / scale_std + scale_mean
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
iter_ = z.shape[1]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = mx.concatenate([out, out_], axis=1)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
std = mx.exp(0.5 * log_var)
eps = mx.random.normal(std.shape)
return eps * std + mu
def __call__(self, x):
mu, log_var = self.encode(x, self.scale)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z, self.scale)
return x_recon, mu, log_var
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
mu, log_var = self.encode(imgs, self.scale)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
return mu + std * mx.random.normal(std.shape)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
@@ -628,7 +617,7 @@ class WanVAE_(nn.Module):
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
@@ -644,13 +633,13 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
cfg.update(**kwargs)
# init model
with torch.device('meta'):
model = WanVAE_(**cfg)
# load checkpoint
if pretrained_path:
logging.info(f'loading {pretrained_path}')
model.load_state_dict(
torch.load(pretrained_path, map_location=device), assign=True)
weights = mx.load(pretrained_path)
model.update(tree_unflatten(list(weights.items())))
return model
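Checkpoint loading moves from `torch.load` to `mx.load` plus `tree_unflatten`. A toy sketch of that pattern (the dict below is made up; only the final commented line mirrors the VAE code above):

```python
import mlx.core as mx
from mlx.utils import tree_unflatten

# mx.load returns a flat {dotted.name: array} mapping; tree_unflatten rebuilds
# the nested parameter tree that nn.Module.update() expects.
flat = {"conv1.weight": mx.zeros((4, 3)), "conv1.bias": mx.zeros((4,))}
params = tree_unflatten(list(flat.items()))
print(params)   # {'conv1': {'weight': ..., 'bias': ...}}
# In the VAE: weights = mx.load(pretrained_path); model.update(tree_unflatten(list(weights.items())))
```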
@@ -660,10 +649,8 @@ class WanVAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=torch.float,
device="cpu"):
dtype=mx.float32):
self.dtype = dtype
self.device = device
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
@@ -673,30 +660,59 @@ class WanVAE:
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean, dtype=dtype, device=device)
self.std = torch.tensor(std, dtype=dtype, device=device)
self.mean = mx.array(mean, dtype=dtype)
self.std = mx.array(std, dtype=dtype)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
).eval().requires_grad_(False).to(device)
)
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
Returns: List of encoded videos in [C, T, H, W] format.
"""
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
encoded = []
for video in videos:
# Convert CTHW -> BTHWC
x = mx.expand_dims(video, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Encode
z = self.model.encode(x, self.scale)[0] # Get mu only
# Convert back BTHWC -> CTHW and remove batch dimension
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
z = z.squeeze(0) # Remove batch dimension -> CTHW
encoded.append(z.astype(mx.float32))
return encoded
def decode(self, zs):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
for u in zs
]
"""
zs: A list of latent codes each with shape [C, T, H, W].
Returns: List of decoded videos in [C, T, H, W] format.
"""
decoded = []
for z in zs:
# Convert CTHW -> BTHWC
x = mx.expand_dims(z, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Decode
x = self.model.decode(x, self.scale)
# Convert back BTHWC -> CTHW and remove batch dimension
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
x = x.squeeze(0) # Remove batch dimension -> CTHW
# Clamp values
x = mx.clip(x, -1, 1)
decoded.append(x.astype(mx.float32))
return decoded
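A minimal sketch of the layout round-trip that `encode`/`decode` perform around the model call (array sizes are illustrative):

```python
import mlx.core as mx

# One video in the wrapper's external format: [C, T, H, W], values in [-1, 1].
video = mx.random.uniform(low=-1.0, high=1.0, shape=(3, 5, 64, 64))

x = mx.expand_dims(video, axis=0)    # add batch      -> [1, C, T, H, W]
x = x.transpose(0, 2, 3, 4, 1)       # channels last  -> [1, T, H, W, C] (BTHWC)

# ... self.model.encode / self.model.decode run here in BTHWC ...

y = x.transpose(0, 4, 1, 2, 3)       # back to        -> [1, C, T, H, W]
y = y.squeeze(0)                     # drop batch     -> [C, T, H, W]
assert y.shape == video.shape
```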

View File

@@ -1,719 +0,0 @@
# vae_mlx_final.py
import logging
from typing import Optional, List, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
__all__ = [
'WanVAE',
]
CACHE_T = 2
debug_line = 0
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolution for MLX.
Expects input in BTHWC format (batch, time, height, width, channels).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Padding order: (W, W, H, H, T, 0)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def __call__(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
padding[4] -= cache_x.shape[1]
# Pad in BTHWC format
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
(padding[0], padding[1]), (0, 0)]
x = mx.pad(x, pad_width)
result = super().__call__(x)
return result
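A small illustration of the causal idea behind this layer (not the class itself): for a temporal kernel of size 3, all padding goes in front of the time axis, so an output frame never depends on future frames.

```python
import mlx.core as mx

# Pad 2 frames at the *front* of the time axis (BTHWC) for a kernel of size 3.
x = mx.arange(5, dtype=mx.float32).reshape(1, 5, 1, 1, 1)     # [B, T, H, W, C]
pad_width = [(0, 0), (2, 0), (0, 0), (0, 0), (0, 0)]          # (before, after) per axis
x_padded = mx.pad(x, pad_width)
print(x_padded.shape)              # (1, 7, 1, 1, 1)
print(x_padded[0, :, 0, 0, 0])     # [0, 0, 0, 1, 2, 3, 4]
```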
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=False, images=True, bias=False):
super().__init__()
self.channel_first = channel_first
self.images = images
self.scale = dim**0.5
# Just keep as 1D - let broadcasting do its magic
self.gamma = mx.ones((dim,))
self.bias = mx.zeros((dim,)) if bias else 0.
def __call__(self, x):
# F.normalize in PyTorch does L2 normalization, not RMS!
# For NHWC/BTHWC format, normalize along the last axis
# L2 norm: sqrt(sum(x^2))
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
x = x / norm
return x * self.scale * self.gamma + self.bias
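A tiny numeric check of what this norm computes (illustrative shapes): L2-normalize the channel axis, then rescale by sqrt(dim) and the per-channel gain.

```python
import mlx.core as mx

dim = 8
x = mx.random.normal(shape=(2, 4, 4, dim))                    # [..., C]
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
y = (x / norm) * (dim ** 0.5)                                 # gamma = 1, bias = 0
print(mx.sum(y[0, 0, 0] ** 2))                                # ~= dim (8.0)
```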
class Upsample(nn.Module):
"""
Upsampling layer that matches PyTorch's behavior.
"""
def __init__(self, scale_factor, mode='nearest-exact'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode # mode is now unused, but kept for signature consistency
def __call__(self, x):
# For NHWC format (n, h, w, c)
# NOTE: For an integer scale_factor like 2.0, PyTorch's 'nearest-exact'
# is equivalent to a simple repeat operation. The previous coordinate-based
# sampling was not correct for this model and caused the divergence.
scale_h, scale_w = self.scale_factor
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
return out
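The NOTE above can be checked directly: for an integer factor, nearest-neighbour upsampling is plain repetition along H and W.

```python
import mlx.core as mx

x = mx.arange(4, dtype=mx.float32).reshape(1, 2, 2, 1)     # [N, H, W, C]
up = mx.repeat(mx.repeat(x, 2, axis=1), 2, axis=2)         # -> [1, 4, 4, 1]
print(up[0, :, :, 0])
# [[0, 0, 1, 1],
#  [0, 0, 1, 1],
#  [2, 2, 3, 3],
#  [2, 2, 3, 3]]
```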
class AsymmetricPad(nn.Module):
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
def __init__(self, pad_width: tuple):
super().__init__()
self.pad_width = pad_width
def __call__(self, x):
return mx.pad(x, self.pad_width)
# Update your Resample class to use 'nearest-exact'
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
# --- CORRECTED PADDING LOGIC ---
elif mode == 'downsample2d':
# Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2)
# Use the new AsymmetricPad module.
# Pad width for NHWC format is ((N), (H), (W), (C))
# Pad H with (top, bottom) and W with (left, right)
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
elif mode == 'downsample3d':
# The spatial downsampling part uses the same logic
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
self.resample = nn.Sequential(pad_layer, conv_layer)
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
# The __call__ method logic remains unchanged from your original code
b, t, h, w, c = x.shape
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
cache_x = mx.concatenate([
mx.zeros_like(cache_x), cache_x
], axis=1)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, t, h, w, 2, c)
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
x = x.reshape(b, t * 2, h, w, c)
t = x.shape[1]
x = x.reshape(b * t, h, w, c)
x = self.resample(x)
_, h_new, w_new, c_new = x.shape
x = x.reshape(b, t, h_new, w_new, c_new)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x
feat_idx[0] += 1
else:
cache_x = x[:, -1:, :, :, :]
x = self.time_conv(
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
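The downsample branches above pad one row at the bottom and one column at the right before a stride-2 convolution, mirroring PyTorch's `ZeroPad2d((0, 1, 0, 1))`. A quick sketch of that padding in NHWC:

```python
import mlx.core as mx

x = mx.ones((1, 5, 5, 3))                            # [N, H, W, C]
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])      # bottom row + right column of zeros
print(x.shape)                                       # (1, 6, 6, 3)
```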
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
CausalConv3d(out_dim, out_dim, 3, padding=1)
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def __call__(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for i, layer in enumerate(self.residual.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
self.proj.weight = mx.zeros_like(self.proj.weight)
def __call__(self, x):
# x is in BTHWC format
identity = x
b, t, h, w, c = x.shape
x = x.reshape(b * t, h, w, c) # Combine batch and time
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
qkv = qkv.reshape(b * t, h * w, 3 * c)
q, k, v = mx.split(qkv, 3, axis=-1)
# Reshape for attention
q = q.reshape(b * t, h * w, c)
k = k.reshape(b * t, h * w, c)
v = v.reshape(b * t, h * w, c)
# Scaled dot product attention
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
scores = (q @ k.transpose(0, 2, 1)) * scale
weights = mx.softmax(scores, axis=-1)
x = weights @ v
x = x.reshape(b * t, h, w, c)
# output
x = self.proj(x)
x = x.reshape(b, t, h, w, c)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[-1], dims[-1], dropout),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1], dropout)
)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], z_dim, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for i, layer in enumerate(self.downsamples.layers):
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(dims[-1], images=False),
nn.SiLU(),
CausalConv3d(dims[-1], 3, 3, padding=1)
)
def __call__(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle.layers:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples.layers:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for i, layer in enumerate(self.head.layers):
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:, :, :, :]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate([
feat_cache[idx][:, -1:, :, :, :], cache_x
], axis=1)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for name, module in model.named_modules():
if isinstance(module, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x, scale):
# x is in BTHWC format
self.clear_cache()
## cache
t = x.shape[1]
iter_ = 1 + (t - 1) // 4
## Split encode input x by time into 1, 4, 4, 4....
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :1, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = mx.concatenate([out, out_], axis=1)
z = self.conv1(out)
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
if isinstance(scale[0], mx.array):
# Reshape scale for broadcasting in BTHWC format
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
mu = (mu - scale_mean) * scale_std
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu, log_var
def decode(self, z, scale):
# z is in BTHWC format
self.clear_cache()
if isinstance(scale[0], mx.array):
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
z = z / scale_std + scale_mean
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[1]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, i:i + 1, :, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = mx.concatenate([out, out_], axis=1)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = mx.exp(0.5 * log_var)
eps = mx.random.normal(std.shape)
return eps * std + mu
def __call__(self, x):
mu, log_var = self.encode(x, self.scale)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z, self.scale)
return x_recon, mu, log_var
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs, self.scale)
if deterministic:
return mu
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
return mu + std * mx.random.normal(std.shape)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
# params
cfg = dict(
dim=96,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0)
cfg.update(**kwargs)
# init model
model = WanVAE_(**cfg)
# load checkpoint
if pretrained_path:
logging.info(f'loading {pretrained_path}')
weights = mx.load(pretrained_path)
model.update(tree_unflatten(list(weights.items())))
return model
class WanVAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=mx.float32):
self.dtype = dtype
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = mx.array(mean, dtype=dtype)
self.std = mx.array(std, dtype=dtype)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
)
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
Returns: List of encoded videos in [C, T, H, W] format.
"""
encoded = []
for video in videos:
# Convert CTHW -> BTHWC
x = mx.expand_dims(video, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Encode
z = self.model.encode(x, self.scale)[0] # Get mu only
# Convert back BTHWC -> CTHW and remove batch dimension
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
z = z.squeeze(0) # Remove batch dimension -> CTHW
encoded.append(z.astype(mx.float32))
return encoded
def decode(self, zs):
"""
zs: A list of latent codes each with shape [C, T, H, W].
Returns: List of decoded videos in [C, T, H, W] format.
"""
decoded = []
for z in zs:
# Convert CTHW -> BTHWC
x = mx.expand_dims(z, axis=0) # Add batch dimension
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
# Decode
x = self.model.decode(x, self.scale)
# Convert back BTHWC -> CTHW and remove batch dimension
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
x = x.squeeze(0) # Remove batch dimension -> CTHW
# Clamp values
x = mx.clip(x, -1, 1)
decoded.append(x.astype(mx.float32))
return decoded

View File

@@ -4,7 +4,7 @@ from safetensors.torch import save_file
from pathlib import Path
import json
from wan.modules.t5_mlx import T5Model
from wan.modules.t5 import T5Model
def convert_pickle_to_safetensors(

View File

@@ -11,11 +11,11 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .modules.model_mlx import WanModel
from .modules.t5_mlx import T5EncoderModel
from .modules.vae_mlx import WanVAE
from .utils.fm_solvers_mlx import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
from .utils.fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .wan_model_io import load_wan_from_safetensors

View File

@@ -1,6 +1,6 @@
from .fm_solvers_mlx import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
retrieve_timesteps)
from .fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
__all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',