Clean up
@@ -66,7 +66,7 @@ This repository currently supports two Text-to-Video models (1.3B and 14B) and t

##### (1) Example:
```
python generate_mlx.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4
python generate.py --task t2v-1.3B --size "480*832" --frame_num 16 --sample_steps 25 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --prompt "Lion running under snow in Samarkand" --save_file output_video_mlx.mp4
```

(Seven binary image assets changed in this commit; sizes before: 1.7 MiB, 516 KiB, 871 KiB, 294 KiB, 1.5 MiB, 628 KiB, 208 KiB.)
@@ -10,7 +10,7 @@ from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils_mlx_own import cache_video, cache_image, str2bool
from wan.utils.utils import cache_video, cache_image, str2bool

EXAMPLE_PROMPT = {
    "t2v-1.3B": {
@@ -1,6 +0,0 @@

Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify the maximum number of GPUs you want to use.

```bash
bash ./test.sh <local model dir> <gpu number>
```
@@ -1,113 +0,0 @@
#!/bin/bash


if [ "$#" -eq 2 ]; then
    MODEL_DIR=$(realpath "$1")
    GPUS=$2
else
    echo "Usage: $0 <local model dir> <gpu number>"
    exit 1
fi

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1

PY_FILE=./generate.py


function t2v_1_3B() {
    T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
    python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}

function t2v_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
    python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}



function t2i_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
    python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}


function i2v_14B_480p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS

    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"

    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
    fi
}


function i2v_14B_720p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"

    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
    python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR

    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}


t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
@@ -1,2 +1,2 @@
from . import configs, modules
from .text2video_mlx import WanT2V
from .text2video import WanT2V
@@ -1,7 +1,7 @@
from .model_mlx import WanModel
from .t5_mlx import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae_mlx import WanVAE
from .vae import WanVAE

__all__ = [
    'WanVAE',
@@ -1,11 +1,13 @@
# Modified from transformers.models.t5.modeling_t5
# Modified from transformers.models.t5.modeling_t5 for MLX
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten

from .tokenizers import HuggingfaceTokenizer

@@ -18,59 +20,41 @@ __all__ = [


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    if x.dtype == mx.float16:
        # Use same clamping as PyTorch for consistency
        clamp = 65504.0  # max value for float16
        return mx.clip(x, -clamp, clamp)
    return x


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)


class GELU(nn.Module):

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
    def __call__(self, x):
        return 0.5 * x * (1.0 + mx.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))


class T5LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.weight = mx.ones((dim,))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                            self.eps)
        if self.weight.dtype in [torch.float16, torch.float32]:
            x = x.type_as(self.weight)
        return self.weight * x
    def __call__(self, x):
        # Match PyTorch's approach: convert to float32 for stability
        x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
        variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
        x_norm = x_float * mx.rsqrt(variance + self.eps)
        # Convert back to original dtype
        if x.dtype == mx.float16:
            x_norm = x_norm.astype(mx.float16)
        return self.weight * x_norm

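The MLX `T5LayerNorm` above reduces to the usual T5 RMS normalization: each token vector is scaled by the reciprocal root of its mean square, then multiplied by a learned weight. A minimal standalone sketch of that arithmetic with plain `mlx.core` ops (shapes are invented for illustration, not taken from the diff):

```python
import mlx.core as mx

# RMS-normalize a random activation the same way T5LayerNorm.__call__ does,
# then check that every token vector ends up with unit root-mean-square.
eps = 1e-6
x = mx.random.normal((2, 4, 8))                       # [batch, tokens, dim]
variance = mx.mean(mx.square(x), axis=-1, keepdims=True)
x_norm = x * mx.rsqrt(variance + eps)
print(mx.sqrt(mx.mean(mx.square(x_norm), axis=-1)))   # ~1.0 for every token
```
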
class T5Attention(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
||||
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
|
||||
assert dim_attn % num_heads == 0
|
||||
super(T5Attention, self).__init__()
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
@ -83,7 +67,7 @@ class T5Attention(nn.Module):
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, context=None, mask=None, pos_bias=None):
|
||||
def __call__(self, x, context=None, mask=None, pos_bias=None):
|
||||
"""
|
||||
x: [B, L1, C].
|
||||
context: [B, L2, C] or None.
|
||||
@ -91,50 +75,72 @@ class T5Attention(nn.Module):
|
||||
"""
|
||||
# check inputs
|
||||
context = x if context is None else context
|
||||
b, n, c = x.size(0), self.num_heads, self.head_dim
|
||||
b, l1, _ = x.shape
|
||||
_, l2, _ = context.shape
|
||||
n, c = self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).view(b, -1, n, c)
|
||||
k = self.k(context).view(b, -1, n, c)
|
||||
v = self.v(context).view(b, -1, n, c)
|
||||
q = self.q(x).reshape(b, l1, n, c)
|
||||
k = self.k(context).reshape(b, l2, n, c)
|
||||
v = self.v(context).reshape(b, l2, n, c)
|
||||
|
||||
# attention bias
|
||||
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
||||
if pos_bias is not None:
|
||||
attn_bias += pos_bias
|
||||
if mask is not None:
|
||||
assert mask.ndim in [2, 3]
|
||||
mask = mask.view(b, 1, 1,
|
||||
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
||||
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
||||
# transpose for attention: [B, N, L, C]
|
||||
q = mx.transpose(q, (0, 2, 1, 3))
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
# compute attention (T5 does not use scaling)
|
||||
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
||||
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
||||
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
|
||||
|
||||
# output
|
||||
x = x.reshape(b, -1, n * c)
|
||||
# add position bias if provided
|
||||
if pos_bias is not None:
|
||||
attn = attn + pos_bias
|
||||
|
||||
# apply mask
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
# [B, L2] -> [B, 1, 1, L2]
|
||||
mask = mask[:, None, None, :]
|
||||
elif mask.ndim == 3:
|
||||
# [B, L1, L2] -> [B, 1, L1, L2]
|
||||
mask = mask[:, None, :, :]
|
||||
# Use very negative value that works well with float16
|
||||
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
|
||||
attn = mx.where(mask == 0, min_value, attn)
|
||||
|
||||
# softmax and apply attention
|
||||
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# apply attention to values
|
||||
x = mx.matmul(attn, v) # [B, N, L1, C]
|
||||
|
||||
# transpose back and reshape
|
||||
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
|
||||
x = x.reshape(b, l1, -1)
|
||||
|
||||
# output projection
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
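To make the mask handling in `T5Attention.__call__` above concrete, here is a small standalone sketch (shapes and values are invented) of how a 2-D padding mask is broadcast over the attention logits and becomes near-zero probability after the softmax:

```python
import mlx.core as mx

# Toy [B, L2] padding mask -> [B, 1, 1, L2] additive bias, mirroring the
# masking branch above.
b, n, l1, l2 = 1, 2, 3, 3
attn = mx.zeros((b, n, l1, l2))        # raw attention logits
mask = mx.array([[1, 1, 0]])           # last key position is padding
attn = mx.where(mask[:, None, None, :] == 0, -1e9, attn)
probs = mx.softmax(attn, axis=-1)
print(probs[0, 0])                     # the padded key column gets ~0 weight
```
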
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_ffn, dropout=0.1):
|
||||
super(T5FeedForward, self).__init__()
|
||||
def __init__(self, dim, dim_ffn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
|
||||
# layers
|
||||
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = GELU()
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x) * self.gate(x)
|
||||
def __call__(self, x):
|
||||
gate = self.gate_act(self.gate_proj(x))
|
||||
x = self.fc1(x) * gate
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
@ -142,7 +148,6 @@ class T5FeedForward(nn.Module):
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
@ -150,8 +155,8 @@ class T5SelfAttention(nn.Module):
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.1):
|
||||
super(T5SelfAttention, self).__init__()
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
@ -167,16 +172,15 @@ class T5SelfAttention(nn.Module):
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True)
|
||||
|
||||
def forward(self, x, mask=None, pos_bias=None):
|
||||
def __call__(self, x, mask=None, pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.size(1), x.size(1))
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5CrossAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
@ -184,8 +188,8 @@ class T5CrossAttention(nn.Module):
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.1):
|
||||
super(T5CrossAttention, self).__init__()
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
@ -203,14 +207,14 @@ class T5CrossAttention(nn.Module):
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False)
|
||||
|
||||
def forward(self,
|
||||
def __call__(self,
|
||||
x,
|
||||
mask=None,
|
||||
encoder_states=None,
|
||||
encoder_mask=None,
|
||||
pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.size(1), x.size(1))
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.cross_attn(
|
||||
self.norm2(x), context=encoder_states, mask=encoder_mask))
|
||||
@ -219,9 +223,8 @@ class T5CrossAttention(nn.Module):
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
||||
super(T5RelativeEmbedding, self).__init__()
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
@ -230,42 +233,55 @@ class T5RelativeEmbedding(nn.Module):
|
||||
# layers
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def forward(self, lq, lk):
|
||||
device = self.embedding.weight.device
|
||||
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
||||
# torch.arange(lq).unsqueeze(1).to(device)
|
||||
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
|
||||
torch.arange(lq, device=device).unsqueeze(1)
|
||||
def __call__(self, lq, lk):
|
||||
# Create relative position matrix
|
||||
positions_q = mx.arange(lq)[:, None]
|
||||
positions_k = mx.arange(lk)[None, :]
|
||||
rel_pos = positions_k - positions_q
|
||||
|
||||
# Apply bucketing
|
||||
rel_pos = self._relative_position_bucket(rel_pos)
|
||||
|
||||
# Get embeddings
|
||||
rel_pos_embeds = self.embedding(rel_pos)
|
||||
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
|
||||
0) # [1, N, Lq, Lk]
|
||||
return rel_pos_embeds.contiguous()
|
||||
|
||||
# Reshape to [1, N, Lq, Lk]
|
||||
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
|
||||
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
|
||||
|
||||
return rel_pos_embeds
|
||||
|
||||
def _relative_position_bucket(self, rel_pos):
|
||||
# preprocess
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = (rel_pos > 0).long() * num_buckets
|
||||
rel_pos = torch.abs(rel_pos)
|
||||
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = 0
|
||||
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
||||
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
|
||||
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
# embeddings for small and large positions
|
||||
max_exact = num_buckets // 2
|
||||
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
# For large positions, use log scale
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
|
||||
math.log(self.max_dist / max_exact) *
|
||||
(num_buckets - max_exact)).long()
|
||||
rel_pos_large = torch.min(
|
||||
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
||||
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
||||
(num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
|
||||
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
|
||||
|
||||
# Combine small and large position buckets
|
||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
|
||||
|
||||
return rel_buckets
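As a quick illustration of the relative-position matrix that feeds `_relative_position_bucket` above (a standalone sketch, not part of the diff):

```python
import mlx.core as mx

# The matrix built in T5RelativeEmbedding.__call__: entry (i, j) holds j - i,
# i.e. how far key j sits from query i.
lq = lk = 4
rel_pos = mx.arange(lk)[None, :] - mx.arange(lq)[:, None]
print(rel_pos)
# [[ 0,  1,  2,  3],
#  [-1,  0,  1,  2],
#  [-2, -1,  0,  1],
#  [-3, -2, -1,  0]]
```
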
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
@ -275,8 +291,8 @@ class T5Encoder(nn.Module):
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.1):
|
||||
super(T5Encoder, self).__init__()
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
@ -286,25 +302,25 @@ class T5Encoder(nn.Module):
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
||||
else nn.Embedding(vocab, dim)
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = nn.ModuleList([
|
||||
self.blocks = [
|
||||
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
])
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
# initialize weights
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, ids, mask=None):
|
||||
def __call__(self, ids, mask=None):
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.size(1),
|
||||
x.size(1)) if self.shared_pos else None
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
@ -313,7 +329,6 @@ class T5Encoder(nn.Module):
|
||||
|
||||
|
||||
class T5Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
@ -323,8 +338,8 @@ class T5Decoder(nn.Module):
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.1):
|
||||
super(T5Decoder, self).__init__()
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
@ -334,34 +349,35 @@ class T5Decoder(nn.Module):
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
||||
else nn.Embedding(vocab, dim)
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = nn.ModuleList([
|
||||
self.blocks = [
|
||||
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
])
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
# initialize weights
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
||||
b, s = ids.size()
|
||||
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
||||
b, s = ids.shape
|
||||
|
||||
# causal mask
|
||||
if mask is None:
|
||||
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
|
||||
mask = mx.tril(mx.ones((1, s, s)))
|
||||
elif mask.ndim == 2:
|
||||
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
|
||||
# Expand mask properly
|
||||
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
|
||||
|
||||
# layers
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.size(1),
|
||||
x.size(1)) if self.shared_pos else None
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
@ -370,7 +386,6 @@ class T5Decoder(nn.Module):
|
||||
|
||||
|
||||
class T5Model(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
dim,
|
||||
@ -381,8 +396,8 @@ class T5Model(nn.Module):
|
||||
decoder_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.1):
|
||||
super(T5Model, self).__init__()
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
@ -402,23 +417,63 @@ class T5Model(nn.Module):
|
||||
shared_pos, dropout)
|
||||
self.head = nn.Linear(dim, vocab_size, bias=False)
|
||||
|
||||
# initialize weights
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
||||
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
||||
x = self.encoder(encoder_ids, encoder_mask)
|
||||
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def init_mlx_weights(module, key):
|
||||
"""Initialize weights for T5 model components to match PyTorch initialization"""
|
||||
|
||||
def normal(key, shape, std=1.0):
|
||||
return mx.random.normal(shape, key=key) * std
|
||||
|
||||
if isinstance(module, T5LayerNorm):
|
||||
module.weight = mx.ones_like(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.weight = normal(key, module.weight.shape, std=1.0)
|
||||
elif isinstance(module, T5FeedForward):
|
||||
# Match PyTorch initialization
|
||||
key1, key2, key3 = mx.random.split(key, 3)
|
||||
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc1.weight = normal(key2, module.fc1.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc2.weight = normal(key3, module.fc2.weight.shape,
|
||||
std=module.dim_ffn**-0.5)
|
||||
elif isinstance(module, T5Attention):
|
||||
# Match PyTorch initialization
|
||||
key1, key2, key3, key4 = mx.random.split(key, 4)
|
||||
module.q.weight = normal(key1, module.q.weight.shape,
|
||||
std=(module.dim * module.dim_attn)**-0.5)
|
||||
module.k.weight = normal(key2, module.k.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.v.weight = normal(key3, module.v.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.o.weight = normal(key4, module.o.weight.shape,
|
||||
std=(module.num_heads * module.dim_attn)**-0.5)
|
||||
elif isinstance(module, T5RelativeEmbedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.embedding.weight = normal(key, module.embedding.weight.shape,
|
||||
std=(2 * module.num_buckets * module.num_heads)**-0.5)
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Generic linear layer initialization
|
||||
key = mx.random.split(key, 1)[0]
|
||||
fan_in = module.weight.shape[1]
|
||||
bound = 1.0 / math.sqrt(fan_in)
|
||||
module.weight = mx.random.uniform(-bound, bound, module.weight.shape, key=key)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _t5(name,
|
||||
encoder_only=False,
|
||||
decoder_only=False,
|
||||
return_tokenizer=False,
|
||||
tokenizer_kwargs={},
|
||||
dtype=torch.float32,
|
||||
device='cpu',
|
||||
**kwargs):
|
||||
# sanity check
|
||||
assert not (encoder_only and decoder_only)
|
||||
@ -438,11 +493,11 @@ def _t5(name,
|
||||
model_cls = T5Model
|
||||
|
||||
# init model
|
||||
with torch.device(device):
|
||||
model = model_cls(**kwargs)
|
||||
|
||||
# set device
|
||||
model = model.to(dtype=dtype, device=device)
|
||||
# Initialize weights properly
|
||||
key = mx.random.key(0)
|
||||
model = init_mlx_weights(model, key)
|
||||
|
||||
# init tokenizer
|
||||
if return_tokenizer:
|
||||
@ -464,50 +519,99 @@ def umt5_xxl(**kwargs):
|
||||
decoder_layers=24,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
dropout=0.1)
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
return _t5('umt5-xxl', **cfg)
|
||||
|
||||
|
||||
class T5EncoderModel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_len,
|
||||
dtype=torch.float32,
|
||||
device='mps' if torch.backends.mps.is_available() else 'cpu',
|
||||
checkpoint_path=None,
|
||||
tokenizer_path=None,
|
||||
shard_fn=None,
|
||||
):
|
||||
self.text_len = text_len
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
# init model
|
||||
model = umt5_xxl(
|
||||
encoder_only=True,
|
||||
return_tokenizer=False,
|
||||
dtype=dtype,
|
||||
device=device).eval().requires_grad_(False)
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
|
||||
self.model = model
|
||||
if shard_fn is not None:
|
||||
self.model = shard_fn(self.model, sync_module_states=False)
|
||||
else:
|
||||
self.model.to(self.device)
|
||||
# init tokenizer
|
||||
self.tokenizer = HuggingfaceTokenizer(
|
||||
name=tokenizer_path, seq_len=text_len, clean='whitespace')
|
||||
return_tokenizer=False)
|
||||
|
||||
def __call__(self, texts, device):
|
||||
ids, mask = self.tokenizer(
|
||||
if checkpoint_path:
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
# Load weights - assuming MLX format checkpoint
|
||||
weights = mx.load(checkpoint_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
self.model = model
|
||||
|
||||
# init tokenizer
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
self.tokenizer = HuggingfaceTokenizer(
|
||||
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
|
||||
seq_len=text_len,
|
||||
clean='whitespace')
|
||||
|
||||
def __call__(self, texts):
|
||||
# Handle single string input
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
# Tokenize texts
|
||||
tokenizer_output = self.tokenizer(
|
||||
texts, return_mask=True, add_special_tokens=True)
|
||||
ids = ids.to(device)
|
||||
mask = mask.to(device)
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
# Handle different tokenizer output formats
|
||||
if isinstance(tokenizer_output, tuple):
|
||||
ids, mask = tokenizer_output
|
||||
else:
|
||||
# Assuming dict output with 'input_ids' and 'attention_mask'
|
||||
ids = tokenizer_output['input_ids']
|
||||
mask = tokenizer_output['attention_mask']
|
||||
|
||||
# Convert to MLX arrays if not already
|
||||
if not isinstance(ids, mx.array):
|
||||
ids = mx.array(ids)
|
||||
if not isinstance(mask, mx.array):
|
||||
mask = mx.array(mask)
|
||||
|
||||
# Get sequence lengths
|
||||
seq_lens = mx.sum(mask > 0, axis=1)
|
||||
|
||||
# Run encoder
|
||||
context = self.model(ids, mask)
|
||||
return [u[:v] for u, v in zip(context, seq_lens)]
|
||||
|
||||
# Return variable length outputs
|
||||
# Convert seq_lens to Python list for indexing
|
||||
if seq_lens.ndim == 0: # Single value
|
||||
seq_lens_list = [seq_lens.item()]
|
||||
else:
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
|
||||
|
||||
|
||||
# Utility function to convert PyTorch checkpoint to MLX
|
||||
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
|
||||
"""Convert PyTorch checkpoint to MLX format"""
|
||||
import torch
|
||||
|
||||
# Load PyTorch checkpoint
|
||||
pytorch_state = torch.load(pytorch_path, map_location='cpu')
|
||||
|
||||
# Convert to numpy then to MLX
|
||||
mlx_state = {}
|
||||
for key, value in pytorch_state.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Handle the key mapping if needed
|
||||
mlx_key = key
|
||||
# Convert tensor to MLX array
|
||||
mlx_state[mlx_key] = mx.array(value.numpy())
|
||||
|
||||
# Save MLX checkpoint
|
||||
mx.save_safetensors(mlx_path, mlx_state)
|
||||
|
||||
return mlx_state
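For orientation, a hypothetical end-to-end sketch of how the pieces in this file fit together. The file names and `text_len` are invented, and the constructor call assumes the MLX-style `T5EncoderModel` signature shown above (`text_len`, `checkpoint_path`, `tokenizer_path`); it is not part of the diff:

```python
# Hypothetical usage sketch; paths are illustrative only.
mlx_ckpt = "umt5-xxl-enc-mlx.safetensors"            # invented output path
convert_pytorch_checkpoint("umt5-xxl-enc.pth", mlx_ckpt)

encoder = T5EncoderModel(
    text_len=512,
    checkpoint_path=mlx_ckpt,
    tokenizer_path="google/umt5-xxl")
context = encoder(["Lion running under snow in Samarkand"])
print(context[0].shape)                              # (seq_len, 4096) for umt5-xxl
```
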
|
@ -1,617 +0,0 @@
|
||||
# Modified from transformers.models.t5.modeling_t5 for MLX
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
|
||||
__all__ = [
|
||||
'T5Model',
|
||||
'T5Encoder',
|
||||
'T5Decoder',
|
||||
'T5EncoderModel',
|
||||
]
|
||||
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == mx.float16:
|
||||
# Use same clamping as PyTorch for consistency
|
||||
clamp = 65504.0 # max value for float16
|
||||
return mx.clip(x, -clamp, clamp)
|
||||
return x
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __call__(self, x):
|
||||
return 0.5 * x * (1.0 + mx.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x):
|
||||
# Match PyTorch's approach: convert to float32 for stability
|
||||
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
|
||||
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
|
||||
x_norm = x_float * mx.rsqrt(variance + self.eps)
|
||||
# Convert back to original dtype
|
||||
if x.dtype == mx.float16:
|
||||
x_norm = x_norm.astype(mx.float16)
|
||||
return self.weight * x_norm
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
|
||||
assert dim_attn % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def __call__(self, x, context=None, mask=None, pos_bias=None):
|
||||
"""
|
||||
x: [B, L1, C].
|
||||
context: [B, L2, C] or None.
|
||||
mask: [B, L2] or [B, L1, L2] or None.
|
||||
"""
|
||||
# check inputs
|
||||
context = x if context is None else context
|
||||
b, l1, _ = x.shape
|
||||
_, l2, _ = context.shape
|
||||
n, c = self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).reshape(b, l1, n, c)
|
||||
k = self.k(context).reshape(b, l2, n, c)
|
||||
v = self.v(context).reshape(b, l2, n, c)
|
||||
|
||||
# transpose for attention: [B, N, L, C]
|
||||
q = mx.transpose(q, (0, 2, 1, 3))
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
# compute attention (T5 does not use scaling)
|
||||
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
|
||||
|
||||
# add position bias if provided
|
||||
if pos_bias is not None:
|
||||
attn = attn + pos_bias
|
||||
|
||||
# apply mask
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
# [B, L2] -> [B, 1, 1, L2]
|
||||
mask = mask[:, None, None, :]
|
||||
elif mask.ndim == 3:
|
||||
# [B, L1, L2] -> [B, 1, L1, L2]
|
||||
mask = mask[:, None, :, :]
|
||||
# Use very negative value that works well with float16
|
||||
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
|
||||
attn = mx.where(mask == 0, min_value, attn)
|
||||
|
||||
# softmax and apply attention
|
||||
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# apply attention to values
|
||||
x = mx.matmul(attn, v) # [B, N, L1, C]
|
||||
|
||||
# transpose back and reshape
|
||||
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
|
||||
x = x.reshape(b, l1, -1)
|
||||
|
||||
# output projection
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_ffn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
|
||||
# layers
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = GELU()
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def __call__(self, x):
|
||||
gate = self.gate_act(self.gate_proj(x))
|
||||
x = self.fc1(x) * gate
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True)
|
||||
|
||||
def __call__(self, x, mask=None, pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5CrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm3 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False)
|
||||
|
||||
def __call__(self,
|
||||
x,
|
||||
mask=None,
|
||||
encoder_states=None,
|
||||
encoder_mask=None,
|
||||
pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.shape[1], x.shape[1])
|
||||
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.cross_attn(
|
||||
self.norm2(x), context=encoder_states, mask=encoder_mask))
|
||||
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
|
||||
# layers
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def __call__(self, lq, lk):
|
||||
# Create relative position matrix
|
||||
positions_q = mx.arange(lq)[:, None]
|
||||
positions_k = mx.arange(lk)[None, :]
|
||||
rel_pos = positions_k - positions_q
|
||||
|
||||
# Apply bucketing
|
||||
rel_pos = self._relative_position_bucket(rel_pos)
|
||||
|
||||
# Get embeddings
|
||||
rel_pos_embeds = self.embedding(rel_pos)
|
||||
|
||||
# Reshape to [1, N, Lq, Lk]
|
||||
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
|
||||
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
|
||||
|
||||
return rel_pos_embeds
|
||||
|
||||
def _relative_position_bucket(self, rel_pos):
|
||||
# preprocess
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = mx.array(rel_pos > 0, dtype=mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
|
||||
rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
# embeddings for small and large positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
# For large positions, use log scale
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(mx.array(rel_pos, dtype=mx.float32) / max_exact) /
|
||||
math.log(self.max_dist / max_exact) *
|
||||
(num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
|
||||
rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
|
||||
|
||||
# Combine small and large position buckets
|
||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
|
||||
|
||||
return rel_buckets
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = [
|
||||
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids, mask=None):
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5Decoder(nn.Module):
|
||||
def __init__(self,
|
||||
vocab,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
if isinstance(vocab, nn.Embedding):
|
||||
self.token_embedding = vocab
|
||||
else:
|
||||
self.token_embedding = nn.Embedding(vocab, dim)
|
||||
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=False) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = [
|
||||
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
||||
b, s = ids.shape
|
||||
|
||||
# causal mask
|
||||
if mask is None:
|
||||
mask = mx.tril(mx.ones((1, s, s)))
|
||||
elif mask.ndim == 2:
|
||||
# Expand mask properly
|
||||
mask = mx.tril(mx.expand_dims(mask, 1).broadcast_to((b, s, s)))
|
||||
|
||||
# layers
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.shape[1],
|
||||
x.shape[1]) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5Model(nn.Module):
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
encoder_layers,
|
||||
decoder_layers,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.encoder_layers = encoder_layers
|
||||
self.decoder_layers = decoder_layers
|
||||
self.num_buckets = num_buckets
|
||||
|
||||
# layers
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
||||
num_heads, encoder_layers, num_buckets,
|
||||
shared_pos, dropout)
|
||||
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
|
||||
num_heads, decoder_layers, num_buckets,
|
||||
shared_pos, dropout)
|
||||
self.head = nn.Linear(dim, vocab_size, bias=False)
|
||||
|
||||
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
||||
x = self.encoder(encoder_ids, encoder_mask)
|
||||
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def init_mlx_weights(module, key):
|
||||
"""Initialize weights for T5 model components to match PyTorch initialization"""
|
||||
|
||||
def normal(key, shape, std=1.0):
|
||||
return mx.random.normal(key, shape) * std
|
||||
|
||||
if isinstance(module, T5LayerNorm):
|
||||
module.weight = mx.ones_like(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.weight = normal(key, module.weight.shape, std=1.0)
|
||||
elif isinstance(module, T5FeedForward):
|
||||
# Match PyTorch initialization
|
||||
key1, key2, key3 = mx.random.split(key, 3)
|
||||
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc1.weight = normal(key2, module.fc1.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.fc2.weight = normal(key3, module.fc2.weight.shape,
|
||||
std=module.dim_ffn**-0.5)
|
||||
elif isinstance(module, T5Attention):
|
||||
# Match PyTorch initialization
|
||||
key1, key2, key3, key4 = random.split(key, 4)
|
||||
module.q.weight = normal(key1, module.q.weight.shape,
|
||||
std=(module.dim * module.dim_attn)**-0.5)
|
||||
module.k.weight = normal(key2, module.k.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.v.weight = normal(key3, module.v.weight.shape,
|
||||
std=module.dim**-0.5)
|
||||
module.o.weight = normal(key4, module.o.weight.shape,
|
||||
std=(module.num_heads * module.dim_attn)**-0.5)
|
||||
elif isinstance(module, T5RelativeEmbedding):
|
||||
key = mx.random.split(key, 1)[0]
|
||||
module.embedding.weight = normal(key, module.embedding.weight.shape,
|
||||
std=(2 * module.num_buckets * module.num_heads)**-0.5)
|
||||
elif isinstance(module, nn.Linear):
|
||||
# Generic linear layer initialization
|
||||
key = mx.random.split(key, 1)[0]
|
||||
fan_in = module.weight.shape[1]
|
||||
bound = 1.0 / math.sqrt(fan_in)
|
||||
module.weight = mx.random.uniform(key, module.weight.shape, -bound, bound)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _t5(name,
|
||||
encoder_only=False,
|
||||
decoder_only=False,
|
||||
return_tokenizer=False,
|
||||
tokenizer_kwargs={},
|
||||
**kwargs):
|
||||
# sanity check
|
||||
assert not (encoder_only and decoder_only)
|
||||
|
||||
# params
|
||||
if encoder_only:
|
||||
model_cls = T5Encoder
|
||||
kwargs['vocab'] = kwargs.pop('vocab_size')
|
||||
kwargs['num_layers'] = kwargs.pop('encoder_layers')
|
||||
_ = kwargs.pop('decoder_layers')
|
||||
elif decoder_only:
|
||||
model_cls = T5Decoder
|
||||
kwargs['vocab'] = kwargs.pop('vocab_size')
|
||||
kwargs['num_layers'] = kwargs.pop('decoder_layers')
|
||||
_ = kwargs.pop('encoder_layers')
|
||||
else:
|
||||
model_cls = T5Model
|
||||
|
||||
# init model
|
||||
model = model_cls(**kwargs)
|
||||
|
||||
# Initialize weights properly
|
||||
key = mx.random.key(0)
|
||||
model = init_mlx_weights(model, key)
|
||||
|
||||
# init tokenizer
|
||||
if return_tokenizer:
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
|
||||
return model, tokenizer
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def umt5_xxl(**kwargs):
|
||||
cfg = dict(
|
||||
vocab_size=256384,
|
||||
dim=4096,
|
||||
dim_attn=4096,
|
||||
dim_ffn=10240,
|
||||
num_heads=64,
|
||||
encoder_layers=24,
|
||||
decoder_layers=24,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
return _t5('umt5-xxl', **cfg)
|
||||
|
||||
|
||||
class T5EncoderModel:
|
||||
def __init__(
|
||||
self,
|
||||
text_len,
|
||||
checkpoint_path=None,
|
||||
tokenizer_path=None,
|
||||
):
|
||||
self.text_len = text_len
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
# init model
|
||||
model = umt5_xxl(
|
||||
encoder_only=True,
|
||||
return_tokenizer=False)
|
||||
|
||||
if checkpoint_path:
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
# Load weights - assuming MLX format checkpoint
|
||||
weights = mx.load(checkpoint_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
self.model = model
|
||||
|
||||
# init tokenizer
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
self.tokenizer = HuggingfaceTokenizer(
|
||||
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
|
||||
seq_len=text_len,
|
||||
clean='whitespace')
|
||||
|
||||
def __call__(self, texts):
|
||||
# Handle single string input
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
# Tokenize texts
|
||||
tokenizer_output = self.tokenizer(
|
||||
texts, return_mask=True, add_special_tokens=True)
|
||||
|
||||
# Handle different tokenizer output formats
|
||||
if isinstance(tokenizer_output, tuple):
|
||||
ids, mask = tokenizer_output
|
||||
else:
|
||||
# Assuming dict output with 'input_ids' and 'attention_mask'
|
||||
ids = tokenizer_output['input_ids']
|
||||
mask = tokenizer_output['attention_mask']
|
||||
|
||||
# Convert to MLX arrays if not already
|
||||
if not isinstance(ids, mx.array):
|
||||
ids = mx.array(ids)
|
||||
if not isinstance(mask, mx.array):
|
||||
mask = mx.array(mask)
|
||||
|
||||
# Get sequence lengths
|
||||
seq_lens = mx.sum(mask > 0, axis=1)
|
||||
|
||||
# Run encoder
|
||||
context = self.model(ids, mask)
|
||||
|
||||
# Return variable length outputs
|
||||
# Convert seq_lens to Python list for indexing
|
||||
if seq_lens.ndim == 0: # Single value
|
||||
seq_lens_list = [seq_lens.item()]
|
||||
else:
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
|
||||
|
||||
|
||||
# Utility function to convert PyTorch checkpoint to MLX
|
||||
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
|
||||
"""Convert PyTorch checkpoint to MLX format"""
|
||||
import torch
|
||||
|
||||
# Load PyTorch checkpoint
|
||||
pytorch_state = torch.load(pytorch_path, map_location='cpu')
|
||||
|
||||
# Convert to numpy then to MLX
|
||||
mlx_state = {}
|
||||
for key, value in pytorch_state.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Handle the key mapping if needed
|
||||
mlx_key = key
|
||||
# Convert tensor to MLX array
|
||||
mlx_state[mlx_key] = mx.array(value.numpy())
|
||||
|
||||
# Save MLX checkpoint
|
||||
mx.save(mlx_path, mlx_state)
|
||||
|
||||
return mlx_state
|
@ -1,13 +1,11 @@
|
||||
# Original PyTorch implementation of Wan VAE
|
||||
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
__all__ = [
|
||||
'WanVAE',
|
||||
@ -17,66 +15,89 @@ CACHE_T = 2
|
||||
|
||||
debug_line = 0
|
||||
|
||||
def debug(name, x):
|
||||
global debug_line
|
||||
print(f"LINE {debug_line}: {name}: shape = {tuple(x.shape)}, mean = {x.mean().item():.4f}, std = {x.std().item():.4f}")
|
||||
debug_line += 1
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3d convolusion.
|
||||
Causal 3d convolution for MLX.
|
||||
Expects input in BTHWC format (batch, time, height, width, channels).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Padding order: (W, W, H, H, T, 0)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
def __call__(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = F.pad(x, padding)
|
||||
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
|
||||
padding[4] -= cache_x.shape[1]
|
||||
|
||||
result = super().forward(x)
|
||||
debug("TORCH x after conv3d", result)
|
||||
# Pad in BTHWC format
|
||||
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
|
||||
(padding[0], padding[1]), (0, 0)]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
result = super().__call__(x)
|
||||
return result
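A small standalone sketch of the padding scheme used by the MLX `CausalConv3d` above: in BTHWC layout only the leading ("past") side of the time axis is padded, so a frame never convolves over future frames (sizes are illustrative):

```python
import mlx.core as mx

# Causal temporal padding in BTHWC layout, as applied in CausalConv3d.__call__.
x = mx.zeros((1, 4, 8, 8, 3))                        # B, T, H, W, C
pad_width = [(0, 0), (2, 0), (1, 1), (1, 1), (0, 0)] # time gets (past=2, future=0)
print(mx.pad(x, pad_width).shape)                    # (1, 6, 10, 10, 3)
```
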
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
def __init__(self, dim, channel_first=False, images=True, bias=False):
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.images = images
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(
|
||||
x, dim=(1 if self.channel_first else
|
||||
-1)) * self.scale * self.gamma + self.bias
|
||||
# Just keep as 1D - let broadcasting do its magic
|
||||
self.gamma = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,)) if bias else 0.
|
||||
|
||||
def __call__(self, x):
|
||||
# F.normalize in PyTorch does L2 normalization, not RMS!
|
||||
# For NHWC/BTHWC format, normalize along the last axis
|
||||
# L2 norm: sqrt(sum(x^2))
|
||||
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
|
||||
x = x / norm
|
||||
return x * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
Upsampling layer that matches PyTorch's behavior.
|
||||
"""
|
||||
result = super().forward(x.float()).type_as(x)
|
||||
debug("TORCH x after upsample", result)
|
||||
return result
|
||||
def __init__(self, scale_factor, mode='nearest-exact'):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode # mode is now unused, but kept for signature consistency
|
||||
|
||||
def __call__(self, x):
|
||||
# For NHWC format (n, h, w, c)
|
||||
|
||||
# NOTE: For an integer scale_factor like 2.0, PyTorch's 'nearest-exact'
|
||||
# is equivalent to a simple repeat operation. The previous coordinate-based
|
||||
# sampling was not correct for this model and caused the divergence.
|
||||
|
||||
scale_h, scale_w = self.scale_factor
|
||||
|
||||
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
|
||||
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
|
||||
|
||||
return out
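For integer scale factors, the repeat-based upsampling above is equivalent to nearest-neighbour interpolation; a standalone sketch with a tiny NHWC tensor (values invented):

```python
import mlx.core as mx

# 2x nearest-neighbour upsampling via repeat, as in Upsample.__call__ above.
x = mx.arange(4).reshape(1, 2, 2, 1)                 # NHWC
y = mx.repeat(mx.repeat(x, 2, axis=1), 2, axis=2)
print(y.shape)                                       # (1, 4, 4, 1)
print(y[0, :, :, 0])                                 # each pixel becomes a 2x2 block
```
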
|
||||
|
||||
class AsymmetricPad(nn.Module):
|
||||
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
|
||||
def __init__(self, pad_width: tuple):
|
||||
super().__init__()
|
||||
self.pad_width = pad_width
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.pad(x, self.pad_width)
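# --- Illustrative sketch (not part of the diff): NHWC analogue of ZeroPad2d((0, 1, 0, 1)). ---
# A minimal check, assuming mlx is installed. PyTorch's (left, right, top, bottom) =
# (0, 1, 0, 1) becomes one extra row at the bottom and one extra column on the right.
import mlx.core as mx

pad = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
print(pad(mx.ones((1, 2, 2, 3))).shape)   # (1, 3, 3, 3): H and W each grow by one, on one side only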
|
||||
|
||||
# Resample uses the 'nearest-exact' Upsample defined above
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
@@ -90,30 +111,42 @@ class Resample(nn.Module):
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
# --- CORRECTED PADDING LOGIC ---
|
||||
elif mode == 'downsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
# Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2)
|
||||
# Use the new AsymmetricPad module.
|
||||
# Pad width for NHWC format is ((N), (H), (W), (C))
|
||||
# Pad H with (top, bottom) and W with (left, right)
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
elif mode == 'downsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
# The spatial downsampling part uses the same logic
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# The control flow below mirrors the original PyTorch forward pass
|
||||
b, t, h, w, c = x.shape
|
||||
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@@ -121,84 +154,50 @@ class Resample(nn.Module):
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
mx.zeros_like(cache_x), cache_x
|
||||
], axis=1)
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
debug("TORCH x after time_conv", x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
debug("TORCH x after time_conv with cache", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = x.reshape(b, t, h, w, 2, c)
|
||||
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
|
||||
x = x.reshape(b, t * 2, h, w, c)
|
||||
|
||||
t = x.shape[1]
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
debug("TORCH x after resample", x)
|
||||
|
||||
_, h_new, w_new, c_new = x.shape
|
||||
x = x.reshape(b, t, h_new, w_new, c_new)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
cache_x = x[:, -1:, :, :, :]
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
@@ -209,34 +208,34 @@ class ResidualBlock(nn.Module):
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||
RMS_norm(in_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
debug("TORCH x after shortcut", h)
|
||||
for layer in self.residual:
|
||||
|
||||
for i, layer in enumerate(self.residual.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
debug("TORCH x after residual block", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after residual block", x)
|
||||
return x + h
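# --- Illustrative sketch (not part of the diff): residual block shape contract. ---
# A minimal check, assuming mlx is installed; sizes here are arbitrary. The block maps
# BTHWC tensors from in_dim to out_dim channels, with a 1x1x1 causal conv shortcut
# whenever the two widths differ.
import mlx.core as mx

block = ResidualBlock(in_dim=64, out_dim=96)
x = mx.random.normal((1, 1, 8, 8, 64))   # a single frame in BTHWC
print(block(x).shape)                    # (1, 1, 8, 8, 96)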
|
||||
|
||||
|
||||
@@ -255,32 +254,34 @@ class AttentionBlock(nn.Module):
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
self.proj.weight = mx.zeros_like(self.proj.weight)
|
||||
|
||||
def forward(self, x):
|
||||
def __call__(self, x):
|
||||
# x is in BTHWC format
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
b, t, h, w, c = x.shape
|
||||
x = x.reshape(b * t, h, w, c) # Combine batch and time
|
||||
x = self.norm(x)
|
||||
debug("TORCH x after norm", x)
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
|
||||
-1).permute(0, 1, 3,
|
||||
2).contiguous().chunk(
|
||||
3, dim=-1)
|
||||
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
|
||||
qkv = qkv.reshape(b * t, h * w, 3 * c)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
||||
# Reshape for attention
|
||||
q = q.reshape(b * t, h * w, c)
|
||||
k = k.reshape(b * t, h * w, c)
|
||||
v = v.reshape(b * t, h * w, c)
|
||||
|
||||
# Scaled dot product attention
|
||||
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
|
||||
scores = (q @ k.transpose(0, 2, 1)) * scale
|
||||
weights = mx.softmax(scores, axis=-1)
|
||||
x = weights @ v
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
debug("TORCH x after proj", x)
|
||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
||||
x = x.reshape(b, t, h, w, c)
|
||||
return x + identity
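# --- Illustrative sketch (not part of the diff): cost and contract of this block. ---
# A minimal check, assuming mlx is installed. The block runs single-head
# softmax(q k^T / sqrt(c)) v over the h*w spatial tokens of each frame, so the score
# matrix is (b*t, h*w, h*w); with the default attn_scales=[] it is only used in the
# middle blocks, at the lowest spatial resolution.
import mlx.core as mx

attn = AttentionBlock(dim=8)
x = mx.random.normal((1, 2, 4, 4, 8))    # BTHWC: 2 frames, 4x4 spatial, 8 channels
print(attn(x).shape)                     # (1, 2, 4, 4, 8): residual, shape-preserving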
|
||||
|
||||
|
||||
@@ -321,78 +322,68 @@ class Encoder3d(nn.Module):
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[
|
||||
i] else 'downsample2d'
|
||||
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
ResidualBlock(dims[-1], dims[-1], dropout),
|
||||
AttentionBlock(dims[-1]),
|
||||
ResidualBlock(dims[-1], dims[-1], dropout)
|
||||
)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], z_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
debug("TORCH x after conv1 with cache", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
debug("TORCH x after conv1", x)
|
||||
|
||||
## downsamples
|
||||
for i, layer in enumerate(self.downsamples):
|
||||
for i, layer in enumerate(self.downsamples.layers):
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
debug("TORCH x after downsample layer", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after downsample layer", x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -423,8 +414,10 @@ class Decoder3d(nn.Module):
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout)
|
||||
)
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
@@ -443,77 +436,66 @@ class Decoder3d(nn.Module):
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
debug("TORCH x after conv1 with cache", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
debug("TORCH x after conv1", x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after middle layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after middle layer", x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
for layer in self.upsamples.layers:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
debug("TORCH x after upsample layer", x)
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after upsample layer", x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
debug("TORCH x after head layer", x)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
debug("TORCH x after head layer", x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -545,78 +527,85 @@ class WanVAE_(nn.Module):
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x, scale):
|
||||
# x is in BTHWC format
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
t = x.shape[1]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
## Split encode input x by time into 1, 4, 4, 4....
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
x[:, :1, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
|
||||
z = self.conv1(out)
|
||||
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
|
||||
|
||||
if isinstance(scale[0], mx.array):
|
||||
# Reshape scale for broadcasting in BTHWC format
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
mu = (mu - scale_mean) * scale_std
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
|
||||
return mu
|
||||
return mu, log_var
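# --- Illustrative sketch (not part of the diff): the 1, 4, 4, 4, ... temporal split. ---
# Plain Python, mirroring the loop above: frame 0 is encoded alone, then frames are
# consumed four at a time while the feature cache carries causal state across chunks.
t = 13                                    # e.g. a 4n+1 frame clip
chunks = [(0, 1)] + [(1 + 4 * i, 1 + 4 * (i + 1)) for i in range((t - 1) // 4)]
print(chunks)                             # [(0, 1), (1, 5), (5, 9), (9, 13)]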
|
||||
|
||||
def decode(self, z, scale):
|
||||
# z is in BTHWC format
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
if isinstance(scale[0], mx.array):
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
z = z / scale_std + scale_mean
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
iter_ = z.shape[1]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
std = mx.exp(0.5 * log_var)
|
||||
eps = mx.random.normal(std.shape)
|
||||
return eps * std + mu
|
||||
|
||||
def __call__(self, x):
|
||||
mu, log_var = self.encode(x, self.scale)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z, self.scale)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
mu, log_var = self.encode(imgs, self.scale)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
|
||||
return mu + std * mx.random.normal(std.shape)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
@@ -628,7 +617,7 @@ class WanVAE_(nn.Module):
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
|
||||
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
|
||||
"""
|
||||
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
|
||||
"""
|
||||
@@ -644,13 +633,13 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
|
||||
cfg.update(**kwargs)
|
||||
|
||||
# init model
|
||||
with torch.device('meta'):
|
||||
model = WanVAE_(**cfg)
|
||||
|
||||
# load checkpoint
|
||||
if pretrained_path:
|
||||
logging.info(f'loading {pretrained_path}')
|
||||
model.load_state_dict(
|
||||
torch.load(pretrained_path, map_location=device), assign=True)
|
||||
weights = mx.load(pretrained_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
return model
|
||||
|
||||
@@ -660,10 +649,8 @@ class WanVAE:
|
||||
def __init__(self,
|
||||
z_dim=16,
|
||||
vae_pth='cache/vae_step_411000.pth',
|
||||
dtype=torch.float,
|
||||
device="cpu"):
|
||||
dtype=mx.float32):
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
@@ -673,30 +660,59 @@ class WanVAE:
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = torch.tensor(mean, dtype=dtype, device=device)
|
||||
self.std = torch.tensor(std, dtype=dtype, device=device)
|
||||
self.mean = mx.array(mean, dtype=dtype)
|
||||
self.std = mx.array(std, dtype=dtype)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = _video_vae(
|
||||
pretrained_path=vae_pth,
|
||||
z_dim=z_dim,
|
||||
).eval().requires_grad_(False).to(device)
|
||||
)
|
||||
|
||||
def encode(self, videos):
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
Returns: List of encoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
with amp.autocast(dtype=self.dtype):
|
||||
return [
|
||||
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
||||
for u in videos
|
||||
]
|
||||
encoded = []
|
||||
for video in videos:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(video, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Encode
|
||||
z = self.model.encode(x, self.scale)[0] # Get mu only
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
z = z.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
encoded.append(z.astype(mx.float32))
|
||||
|
||||
return encoded
|
||||
|
||||
def decode(self, zs):
|
||||
with amp.autocast(dtype=self.dtype):
|
||||
return [
|
||||
self.model.decode(u.unsqueeze(0),
|
||||
self.scale).float().clamp_(-1, 1).squeeze(0)
|
||||
for u in zs
|
||||
]
|
||||
"""
|
||||
zs: A list of latent codes each with shape [C, T, H, W].
|
||||
Returns: List of decoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
decoded = []
|
||||
for z in zs:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(z, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Decode
|
||||
x = self.model.decode(x, self.scale)
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
x = x.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
# Clamp values
|
||||
x = mx.clip(x, -1, 1)
|
||||
|
||||
decoded.append(x.astype(mx.float32))
|
||||
|
||||
return decoded
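# --- Illustrative usage sketch (not part of the diff). ---
# The checkpoint path below is a placeholder; shapes assume this VAE's 4x temporal /
# 8x spatial compression, so a 13-frame 64x64 clip maps to a 4-frame 8x8 latent.
import mlx.core as mx

vae = WanVAE(z_dim=16, vae_pth="path/to/Wan2.1_VAE.safetensors")   # placeholder path
video = mx.random.normal((3, 13, 64, 64))   # one clip in [C, T, H, W]
z = vae.encode([video])[0]                  # -> (16, 4, 8, 8), CTHW latent
recon = vae.decode([z])[0]                  # -> (3, 13, 64, 64), clipped to [-1, 1]
print(z.shape, recon.shape)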
|
@@ -1,719 +0,0 @@
|
||||
# vae_mlx_final.py
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
__all__ = [
|
||||
'WanVAE',
|
||||
]
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
debug_line = 0
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3d convolution for MLX.
|
||||
Expects input in BTHWC format (batch, time, height, width, channels).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Padding order: (W, W, H, H, T, 0)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def __call__(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=1) # Concat along time axis
|
||||
padding[4] -= cache_x.shape[1]
|
||||
|
||||
# Pad in BTHWC format
|
||||
pad_width = [(0, 0), (padding[4], padding[5]), (padding[2], padding[3]),
|
||||
(padding[0], padding[1]), (0, 0)]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
result = super().__call__(x)
|
||||
return result
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=False, images=True, bias=False):
|
||||
super().__init__()
|
||||
self.channel_first = channel_first
|
||||
self.images = images
|
||||
self.scale = dim**0.5
|
||||
|
||||
# Just keep as 1D - let broadcasting do its magic
|
||||
self.gamma = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,)) if bias else 0.
|
||||
|
||||
def __call__(self, x):
|
||||
# F.normalize in PyTorch does L2 normalization, not RMS!
|
||||
# For NHWC/BTHWC format, normalize along the last axis
|
||||
# L2 norm: sqrt(sum(x^2))
|
||||
norm = mx.sqrt(mx.sum(x * x, axis=-1, keepdims=True) + 1e-6)
|
||||
x = x / norm
|
||||
return x * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
Upsampling layer that matches PyTorch's behavior.
|
||||
"""
|
||||
def __init__(self, scale_factor, mode='nearest-exact'):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode # mode is now unused, but kept for signature consistency
|
||||
|
||||
def __call__(self, x):
|
||||
# For NHWC format (n, h, w, c)
|
||||
|
||||
# NOTE: For an integer scale_factor like 2.0, PyTorch's 'nearest-exact'
|
||||
# is equivalent to a simple repeat operation. The previous coordinate-based
|
||||
# sampling was not correct for this model and caused the divergence.
|
||||
|
||||
scale_h, scale_w = self.scale_factor
|
||||
|
||||
out = mx.repeat(x, int(scale_h), axis=1) # Repeat along H dimension
|
||||
out = mx.repeat(out, int(scale_w), axis=2) # Repeat along W dimension
|
||||
|
||||
return out
|
||||
|
||||
class AsymmetricPad(nn.Module):
|
||||
"""A module to apply asymmetric padding, compatible with nn.Sequential."""
|
||||
def __init__(self, pad_width: tuple):
|
||||
super().__init__()
|
||||
self.pad_width = pad_width
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.pad(x, self.pad_width)
|
||||
|
||||
# Update your Resample class to use 'nearest-exact'
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
# --- CORRECTED PADDING LOGIC ---
|
||||
elif mode == 'downsample2d':
|
||||
# Replicate PyTorch's ZeroPad2d((0, 1, 0, 1)) + Conv2d(stride=2)
|
||||
# Use the new AsymmetricPad module.
|
||||
# Pad width for NHWC format is ((N), (H), (W), (C))
|
||||
# Pad H with (top, bottom) and W with (left, right)
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
elif mode == 'downsample3d':
|
||||
# The spatial downsampling part uses the same logic
|
||||
pad_layer = AsymmetricPad(pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
|
||||
conv_layer = nn.Conv2d(dim, dim, 3, stride=(2, 2), padding=0)
|
||||
self.resample = nn.Sequential(pad_layer, conv_layer)
|
||||
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# The __call__ method logic remains unchanged from your original code
|
||||
b, t, h, w, c = x.shape
|
||||
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = mx.concatenate([
|
||||
mx.zeros_like(cache_x), cache_x
|
||||
], axis=1)
|
||||
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, t, h, w, 2, c)
|
||||
x = mx.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
|
||||
x = x.reshape(b, t * 2, h, w, c)
|
||||
|
||||
t = x.shape[1]
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
x = self.resample(x)
|
||||
|
||||
_, h_new, w_new, c_new = x.shape
|
||||
x = x.reshape(b, t, h_new, w_new, c_new)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, -1:, :, :, :]
|
||||
x = self.time_conv(
|
||||
mx.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
|
||||
for i, layer in enumerate(self.residual.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
self.proj.weight = mx.zeros_like(self.proj.weight)
|
||||
|
||||
def __call__(self, x):
|
||||
# x is in BTHWC format
|
||||
identity = x
|
||||
b, t, h, w, c = x.shape
|
||||
x = x.reshape(b * t, h, w, c) # Combine batch and time
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
qkv = self.to_qkv(x) # Output: (b*t, h, w, 3*c)
|
||||
qkv = qkv.reshape(b * t, h * w, 3 * c)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Reshape for attention
|
||||
q = q.reshape(b * t, h * w, c)
|
||||
k = k.reshape(b * t, h * w, c)
|
||||
v = v.reshape(b * t, h * w, c)
|
||||
|
||||
# Scaled dot product attention
|
||||
scale = 1.0 / mx.sqrt(mx.array(c, dtype=q.dtype))
|
||||
scores = (q @ k.transpose(0, 2, 1)) * scale
|
||||
weights = mx.softmax(scores, axis=-1)
|
||||
x = weights @ v
|
||||
x = x.reshape(b * t, h, w, c)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = x.reshape(b, t, h, w, c)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[-1], dims[-1], dropout),
|
||||
AttentionBlock(dims[-1]),
|
||||
ResidualBlock(dims[-1], dims[-1], dropout)
|
||||
)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], z_dim, 3, padding=1)
|
||||
)
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for i, layer in enumerate(self.downsamples.layers):
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout)
|
||||
)
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(dims[-1], images=False),
|
||||
nn.SiLU(),
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1)
|
||||
)
|
||||
|
||||
def __call__(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle.layers:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples.layers:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for i, layer in enumerate(self.head.layers):
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, -CACHE_T:, :, :, :]
|
||||
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([
|
||||
feat_cache[idx][:, -1:, :, :, :], cache_x
|
||||
], axis=1)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class WanVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x, scale):
|
||||
# x is in BTHWC format
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[1]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## Split encode input x by time into 1, 4, 4, 4....
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :1, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, 1 + 4 * (i - 1):1 + 4 * i, :, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
|
||||
z = self.conv1(out)
|
||||
mu, log_var = mx.split(z, 2, axis=-1) # Split along channel dimension
|
||||
|
||||
if isinstance(scale[0], mx.array):
|
||||
# Reshape scale for broadcasting in BTHWC format
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
mu = (mu - scale_mean) * scale_std
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
|
||||
return mu, log_var
|
||||
|
||||
def decode(self, z, scale):
|
||||
# z is in BTHWC format
|
||||
self.clear_cache()
|
||||
if isinstance(scale[0], mx.array):
|
||||
scale_mean = scale[0].reshape(1, 1, 1, 1, self.z_dim)
|
||||
scale_std = scale[1].reshape(1, 1, 1, 1, self.z_dim)
|
||||
z = z / scale_std + scale_mean
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[1]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, i:i + 1, :, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = mx.concatenate([out, out_], axis=1)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = mx.exp(0.5 * log_var)
|
||||
eps = mx.random.normal(std.shape)
|
||||
return eps * std + mu
|
||||
|
||||
def __call__(self, x):
|
||||
mu, log_var = self.encode(x, self.scale)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z, self.scale)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs, self.scale)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = mx.exp(0.5 * mx.clip(log_var, -30.0, 20.0))
|
||||
return mu + std * mx.random.normal(std.shape)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
def _video_vae(pretrained_path=None, z_dim=None, **kwargs):
|
||||
"""
|
||||
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
|
||||
"""
|
||||
# params
|
||||
cfg = dict(
|
||||
dim=96,
|
||||
z_dim=z_dim,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0)
|
||||
cfg.update(**kwargs)
|
||||
|
||||
# init model
|
||||
model = WanVAE_(**cfg)
|
||||
|
||||
# load checkpoint
|
||||
if pretrained_path:
|
||||
logging.info(f'loading {pretrained_path}')
|
||||
weights = mx.load(pretrained_path)
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class WanVAE:
|
||||
|
||||
def __init__(self,
|
||||
z_dim=16,
|
||||
vae_pth='cache/vae_step_411000.pth',
|
||||
dtype=mx.float32):
|
||||
self.dtype = dtype
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
std = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = mx.array(mean, dtype=dtype)
|
||||
self.std = mx.array(std, dtype=dtype)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = _video_vae(
|
||||
pretrained_path=vae_pth,
|
||||
z_dim=z_dim,
|
||||
)
|
||||
|
||||
def encode(self, videos):
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
Returns: List of encoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
encoded = []
|
||||
for video in videos:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(video, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Encode
|
||||
z = self.model.encode(x, self.scale)[0] # Get mu only
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
z = z.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
z = z.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
encoded.append(z.astype(mx.float32))
|
||||
|
||||
return encoded
|
||||
|
||||
def decode(self, zs):
|
||||
"""
|
||||
zs: A list of latent codes each with shape [C, T, H, W].
|
||||
Returns: List of decoded videos in [C, T, H, W] format.
|
||||
"""
|
||||
decoded = []
|
||||
for z in zs:
|
||||
# Convert CTHW -> BTHWC
|
||||
x = mx.expand_dims(z, axis=0) # Add batch dimension
|
||||
x = x.transpose(0, 2, 3, 4, 1) # BCTHW -> BTHWC
|
||||
|
||||
# Decode
|
||||
x = self.model.decode(x, self.scale)
|
||||
|
||||
# Convert back BTHWC -> CTHW and remove batch dimension
|
||||
x = x.transpose(0, 4, 1, 2, 3) # BTHWC -> BCTHW
|
||||
x = x.squeeze(0) # Remove batch dimension -> CTHW
|
||||
|
||||
# Clamp values
|
||||
x = mx.clip(x, -1, 1)
|
||||
|
||||
decoded.append(x.astype(mx.float32))
|
||||
|
||||
return decoded
|
@@ -4,7 +4,7 @@ from safetensors.torch import save_file
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from wan.modules.t5_mlx import T5Model
|
||||
from wan.modules.t5 import T5Model
|
||||
|
||||
|
||||
def convert_pickle_to_safetensors(
|
||||
|
@@ -11,11 +11,11 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .modules.model_mlx import WanModel
|
||||
from .modules.t5_mlx import T5EncoderModel
|
||||
from .modules.vae_mlx import WanVAE
|
||||
from .utils.fm_solvers_mlx import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
||||
from .utils.fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler
|
||||
from .modules.model import WanModel
|
||||
from .modules.t5 import T5EncoderModel
|
||||
from .modules.vae import WanVAE
|
||||
from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
from .wan_model_io import load_wan_from_safetensors
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from .fm_solvers_mlx import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
|
||||
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
|
||||
retrieve_timesteps)
|
||||
from .fm_solvers_unipc_mlx import FlowUniPCMultistepScheduler
|
||||
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
|
||||
__all__ = [
|
||||
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
|
||||
|