# mlx-examples/video/Wan2.1/wan/modules/t5.py
# Modified from transformers.models.t5.modeling_t5 for MLX
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
from typing import Optional, Tuple, List
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten
from .tokenizers import HuggingfaceTokenizer
__all__ = [
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
]
def fp16_clamp(x):
if x.dtype == mx.float16:
        # clamp to the finite float16 range to avoid overflow to inf,
        # mirroring the clamping in the reference PyTorch implementation
        clamp = 65504.0  # largest finite float16 value
return mx.clip(x, -clamp, clamp)
return x
class GELU(nn.Module):
    """Tanh approximation of GELU (the 'gelu_new' activation used in T5's gated-GELU FFN)."""

    def __call__(self, x):
        return 0.5 * x * (1.0 + mx.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x):
# Match PyTorch's approach: convert to float32 for stability
x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
x_norm = x_float * mx.rsqrt(variance + self.eps)
# Convert back to original dtype
if x.dtype == mx.float16:
x_norm = x_norm.astype(mx.float16)
return self.weight * x_norm
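# Illustrative sketch: T5LayerNorm is an RMS-style norm, so it rescales by the
# root-mean-square of the features without subtracting the mean. The toy shapes
# below are arbitrary.
def _example_t5_layernorm():
    x = mx.random.normal((2, 4, 8))  # [B, L, C]
    y = T5LayerNorm(8)(x)
    return y.shape  # (2, 4, 8), same shape as the input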
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
assert dim_attn % num_heads == 0
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, l1, _ = x.shape
_, l2, _ = context.shape
n, c = self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, l1, n, c)
k = self.k(context).reshape(b, l2, n, c)
v = self.v(context).reshape(b, l2, n, c)
# transpose for attention: [B, N, L, C]
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# compute attention (T5 does not use scaling)
attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) # [B, N, L1, L2]
# add position bias if provided
if pos_bias is not None:
attn = attn + pos_bias
# apply mask
if mask is not None:
if mask.ndim == 2:
# [B, L2] -> [B, 1, 1, L2]
mask = mask[:, None, None, :]
elif mask.ndim == 3:
# [B, L1, L2] -> [B, 1, L1, L2]
mask = mask[:, None, :, :]
# Use very negative value that works well with float16
min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
attn = mx.where(mask == 0, min_value, attn)
# softmax and apply attention
attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
attn = self.dropout(attn)
# apply attention to values
x = mx.matmul(attn, v) # [B, N, L1, C]
# transpose back and reshape
x = mx.transpose(x, (0, 2, 1, 3)) # [B, L1, N, C]
x = x.reshape(b, l1, -1)
# output projection
x = self.o(x)
x = self.dropout(x)
return x
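# Illustrative sketch of the attention shapes: self-attention when `context` is
# None, cross-attention otherwise. The toy dimensions below are arbitrary.
def _example_t5_attention():
    attn = T5Attention(dim=32, dim_attn=32, num_heads=4)
    x = mx.random.normal((2, 5, 32))  # [B, L1, C] queries
    ctx = mx.random.normal((2, 7, 32))  # [B, L2, C] keys/values
    mask = mx.ones((2, 7))  # [B, L2] padding mask (1 = keep)
    y = attn(x, context=ctx, mask=mask)
    return y.shape  # (2, 5, 32)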
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.0):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = GELU()
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def __call__(self, x):
gate = self.gate_act(self.gate_proj(x))
x = self.fc1(x) * gate
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
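# Illustrative sketch: the feed-forward block is T5's gated-GELU variant,
# fc2(fc1(x) * GELU(gate_proj(x))). The toy dimensions below are arbitrary.
def _example_t5_ffn():
    ffn = T5FeedForward(dim=32, dim_ffn=64)
    x = mx.random.normal((2, 5, 32))  # [B, L, C]
    return ffn(x).shape  # (2, 5, 32)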
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def __call__(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def __call__(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.shape[1], x.shape[1])
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def __call__(self, lq, lk):
# Create relative position matrix
positions_q = mx.arange(lq)[:, None]
positions_k = mx.arange(lk)[None, :]
rel_pos = positions_k - positions_q
# Apply bucketing
rel_pos = self._relative_position_bucket(rel_pos)
# Get embeddings
rel_pos_embeds = self.embedding(rel_pos)
# Reshape to [1, N, Lq, Lk]
rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)
return rel_pos_embeds
def _relative_position_bucket(self, rel_pos):
# preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
            rel_pos = mx.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = mx.zeros_like(rel_pos).astype(mx.int32)
            rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))
        # half of the buckets cover exact (small) relative positions
        max_exact = num_buckets // 2
        is_small = rel_pos < max_exact
        # the remaining buckets cover larger distances on a log scale up to max_dist
        rel_pos_large = max_exact + (
            mx.log(rel_pos.astype(mx.float32) / max_exact) /
            math.log(self.max_dist / max_exact) *
            (num_buckets - max_exact)
        ).astype(mx.int32)
        rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)
        # combine exact and log-scale buckets
        rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
return rel_buckets
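# Illustrative sketch: the relative-position bias carries one value per head for
# every (query, key) offset, bucketed as above. The toy sizes below are arbitrary.
def _example_relative_bias():
    rel = T5RelativeEmbedding(num_buckets=32, num_heads=4, bidirectional=True)
    bias = rel(6, 6)  # [1, N, Lq, Lk], added to the attention logits
    return bias.shape  # (1, 4, 6, 6)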
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
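# Illustrative sketch of a tiny encoder forward pass; the toy configuration below
# is arbitrary and unrelated to the umt5-xxl sizes used later in this file.
def _example_tiny_encoder():
    enc = T5Encoder(vocab=100, dim=32, dim_attn=32, dim_ffn=64, num_heads=4,
                    num_layers=2, num_buckets=8, shared_pos=False)
    ids = mx.random.randint(0, 100, (2, 6))  # [B, L] token ids
    mask = mx.ones((2, 6))  # [B, L] attention mask (1 = keep)
    return enc(ids, mask).shape  # (2, 6, 32)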
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
if isinstance(vocab, nn.Embedding):
self.token_embedding = vocab
else:
self.token_embedding = nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = [
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.shape
# causal mask
        if mask is None:
            mask = mx.tril(mx.ones((1, s, s)))
        elif mask.ndim == 2:
            # combine the [B, S] padding mask with a causal [B, S, S] mask
            mask = mx.tril(
                mx.broadcast_to(mx.expand_dims(mask, 1), (b, s, s)))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.shape[1],
x.shape[1]) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.0):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, encoder_layers, num_buckets,
shared_pos, dropout)
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, decoder_layers, num_buckets,
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
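# Illustrative sketch of a tiny encoder-decoder forward pass with an arbitrary
# toy configuration; the head returns decoder logits over the vocabulary.
def _example_tiny_t5():
    model = T5Model(vocab_size=100, dim=32, dim_attn=32, dim_ffn=64, num_heads=4,
                    encoder_layers=2, decoder_layers=2, num_buckets=8,
                    shared_pos=False)
    enc_ids = mx.random.randint(0, 100, (2, 6))
    dec_ids = mx.random.randint(0, 100, (2, 3))
    logits = model(enc_ids, mx.ones((2, 6)), dec_ids, mx.ones((2, 3)))
    return logits.shape  # (2, 3, 100)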
def init_mlx_weights(module, key):
"""Initialize weights for T5 model components to match PyTorch initialization"""
    def normal(key, shape, std=1.0):
        # mx.random.normal takes the shape first; the PRNG key is a keyword argument
        return mx.random.normal(shape=shape, key=key) * std
if isinstance(module, T5LayerNorm):
module.weight = mx.ones_like(module.weight)
elif isinstance(module, nn.Embedding):
key = mx.random.split(key, 1)[0]
module.weight = normal(key, module.weight.shape, std=1.0)
elif isinstance(module, T5FeedForward):
# Match PyTorch initialization
key1, key2, key3 = mx.random.split(key, 3)
module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
std=module.dim**-0.5)
module.fc1.weight = normal(key2, module.fc1.weight.shape,
std=module.dim**-0.5)
module.fc2.weight = normal(key3, module.fc2.weight.shape,
std=module.dim_ffn**-0.5)
elif isinstance(module, T5Attention):
# Match PyTorch initialization
        key1, key2, key3, key4 = mx.random.split(key, 4)
module.q.weight = normal(key1, module.q.weight.shape,
std=(module.dim * module.dim_attn)**-0.5)
module.k.weight = normal(key2, module.k.weight.shape,
std=module.dim**-0.5)
module.v.weight = normal(key3, module.v.weight.shape,
std=module.dim**-0.5)
module.o.weight = normal(key4, module.o.weight.shape,
std=(module.num_heads * module.dim_attn)**-0.5)
elif isinstance(module, T5RelativeEmbedding):
key = mx.random.split(key, 1)[0]
module.embedding.weight = normal(key, module.embedding.weight.shape,
std=(2 * module.num_buckets * module.num_heads)**-0.5)
elif isinstance(module, nn.Linear):
# Generic linear layer initialization
key = mx.random.split(key, 1)[0]
fan_in = module.weight.shape[1]
bound = 1.0 / math.sqrt(fan_in)
        module.weight = mx.random.uniform(
            low=-bound, high=bound, shape=module.weight.shape, key=key)
return module
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('encoder_layers')
_ = kwargs.pop('decoder_layers')
elif decoder_only:
model_cls = T5Decoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('decoder_layers')
_ = kwargs.pop('encoder_layers')
else:
model_cls = T5Model
# init model
model = model_cls(**kwargs)
    # initialize every submodule to roughly match the reference PyTorch
    # initialization (pretrained checkpoints loaded later override these values)
    modules = model.modules()
    keys = mx.random.split(mx.random.key(0), len(modules))
    for module, subkey in zip(modules, keys):
        init_mlx_weights(module, subkey)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.0)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
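# Illustrative sketch: _t5 can also build small debug-sized variants by overriding
# the config; the numbers below are arbitrary, and only the umt5-xxl configuration
# above matches the released checkpoints.
def _example_tiny_umt5():
    return _t5('umt5-xxl', encoder_only=True, vocab_size=100, dim=32,
               dim_attn=32, dim_ffn=64, num_heads=4, encoder_layers=2,
               decoder_layers=2, num_buckets=8, shared_pos=False, dropout=0.0)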
class T5EncoderModel:
def __init__(
self,
text_len,
checkpoint_path=None,
tokenizer_path=None,
):
self.text_len = text_len
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False)
if checkpoint_path:
logging.info(f'loading {checkpoint_path}')
# Load weights - assuming MLX format checkpoint
weights = mx.load(checkpoint_path)
model.update(tree_unflatten(list(weights.items())))
self.model = model
# init tokenizer
from .tokenizers import HuggingfaceTokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
seq_len=text_len,
clean='whitespace')
def __call__(self, texts):
# Handle single string input
if isinstance(texts, str):
texts = [texts]
# Tokenize texts
tokenizer_output = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
# Handle different tokenizer output formats
if isinstance(tokenizer_output, tuple):
ids, mask = tokenizer_output
else:
# Assuming dict output with 'input_ids' and 'attention_mask'
ids = tokenizer_output['input_ids']
mask = tokenizer_output['attention_mask']
# Convert to MLX arrays if not already
if not isinstance(ids, mx.array):
ids = mx.array(ids)
if not isinstance(mask, mx.array):
mask = mx.array(mask)
# Get sequence lengths
seq_lens = mx.sum(mask > 0, axis=1)
# Run encoder
context = self.model(ids, mask)
# Return variable length outputs
# Convert seq_lens to Python list for indexing
if seq_lens.ndim == 0: # Single value
seq_lens_list = [seq_lens.item()]
else:
seq_lens_list = seq_lens.tolist()
return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]
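# Illustrative usage sketch. The checkpoint path is hypothetical; it should point
# at an MLX-format UMT5-XXL encoder checkpoint (see the converter below).
def _example_encode_prompts(checkpoint_path):
    encoder = T5EncoderModel(text_len=512, checkpoint_path=checkpoint_path)
    contexts = encoder(['a cat playing piano', 'a drone shot of a waterfall'])
    # one [seq_len_i, 4096] array per prompt, trimmed to its real token length
    return [c.shape for c in contexts]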
# Utility function to convert PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
    """Convert a PyTorch checkpoint to an MLX-loadable .safetensors file."""
    import torch
    # load the PyTorch checkpoint on CPU
    pytorch_state = torch.load(pytorch_path, map_location='cpu')
    # convert each tensor to an MLX array via numpy
    mlx_state = {}
    for key, value in pytorch_state.items():
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu()
            if value.dtype == torch.bfloat16:
                # numpy has no bfloat16, so round-trip through float32
                value = value.to(torch.float32)
            mlx_state[key] = mx.array(value.numpy())
    # mx.save() writes a single array; save_safetensors() writes the flat
    # {name: array} dict that mx.load() returns in T5EncoderModel above
    # (mlx_path should end in '.safetensors')
    mx.save_safetensors(mlx_path, mlx_state)
    return mlx_state
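# Illustrative usage sketch with hypothetical file paths: convert a PyTorch
# checkpoint once, then point T5EncoderModel at the resulting safetensors file.
def _example_convert_and_load(pytorch_path, mlx_path):
    convert_pytorch_checkpoint(pytorch_path, mlx_path)
    return T5EncoderModel(text_len=512, checkpoint_path=mlx_path)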