mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
787 lines
27 KiB
Python
787 lines
27 KiB
Python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
# MLX Implementation of WAN Model - True 1:1 Port from PyTorch
|
|
|
|
import math
|
|
from typing import List, Tuple, Optional
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import numpy as np
|
|
|
|
__all__ = ['WanModel']
|
|
|
|
|
|
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
|
"""Generate sinusoidal position embeddings."""
|
|
assert dim % 2 == 0
|
|
half = dim // 2
|
|
position = position.astype(mx.float32)
|
|
|
|
# Calculate sinusoidal embeddings
|
|
div_term = mx.power(10000, mx.arange(half).astype(mx.float32) / half)
|
|
sinusoid = mx.expand_dims(position, 1) / mx.expand_dims(div_term, 0)
|
|
|
|
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
|
|
|
|
|
|
def rope_params(max_seq_len: int, dim: int, theta: float = 10000) -> mx.array:
|
|
"""Generate RoPE (Rotary Position Embedding) parameters."""
|
|
assert dim % 2 == 0
|
|
positions = mx.arange(max_seq_len)
|
|
freqs = mx.arange(0, dim, 2).astype(mx.float32) / dim
|
|
freqs = 1.0 / mx.power(theta, freqs)
|
|
|
|
# Outer product
|
|
freqs = mx.expand_dims(positions, 1) * mx.expand_dims(freqs, 0)
|
|
|
|
# Convert to complex representation
|
|
return mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
|
|
|
|
|
|
def rope_apply(x: mx.array, grid_sizes: mx.array, freqs: mx.array) -> mx.array:
|
|
"""Apply rotary position embeddings to input tensor."""
|
|
n, c_half = x.shape[2], x.shape[3] // 2
|
|
|
|
# Split frequencies for different dimensions
|
|
c_split = c_half - 2 * (c_half // 3)
|
|
freqs_splits = [
|
|
freqs[:, :c_split],
|
|
freqs[:, c_split:c_split + c_half // 3],
|
|
freqs[:, c_split + c_half // 3:]
|
|
]
|
|
|
|
output = []
|
|
for i in range(grid_sizes.shape[0]):
|
|
f, h, w = int(grid_sizes[i, 0]), int(grid_sizes[i, 1]), int(grid_sizes[i, 2])
|
|
seq_len = f * h * w
|
|
|
|
# Extract sequence for current sample
|
|
x_i = x[i, :seq_len].astype(mx.float32)
|
|
x_i = x_i.reshape(seq_len, n, -1, 2)
|
|
|
|
# Prepare frequency tensors
|
|
freqs_f = freqs_splits[0][:f].reshape(f, 1, 1, -1, 2)
|
|
freqs_f = mx.broadcast_to(freqs_f, (f, h, w, freqs_f.shape[-2], 2))
|
|
|
|
freqs_h = freqs_splits[1][:h].reshape(1, h, 1, -1, 2)
|
|
freqs_h = mx.broadcast_to(freqs_h, (f, h, w, freqs_h.shape[-2], 2))
|
|
|
|
freqs_w = freqs_splits[2][:w].reshape(1, 1, w, -1, 2)
|
|
freqs_w = mx.broadcast_to(freqs_w, (f, h, w, freqs_w.shape[-2], 2))
|
|
|
|
# Concatenate and reshape frequencies
|
|
freqs_i = mx.concatenate([freqs_f, freqs_h, freqs_w], axis=-2)
|
|
freqs_i = freqs_i.reshape(seq_len, 1, -1, 2)
|
|
|
|
# Apply rotary embedding
|
|
x_real = x_i[..., 0]
|
|
x_imag = x_i[..., 1]
|
|
freqs_cos = freqs_i[..., 0]
|
|
freqs_sin = freqs_i[..., 1]
|
|
|
|
x_rotated_real = x_real * freqs_cos - x_imag * freqs_sin
|
|
x_rotated_imag = x_real * freqs_sin + x_imag * freqs_cos
|
|
|
|
x_i = mx.stack([x_rotated_real, x_rotated_imag], axis=-1).reshape(seq_len, n, -1)
|
|
|
|
# Concatenate with remaining sequence if any
|
|
if seq_len < x.shape[1]:
|
|
x_i = mx.concatenate([x_i, x[i, seq_len:]], axis=0)
|
|
|
|
output.append(x_i)
|
|
|
|
return mx.stack(output).astype(x.dtype)
|
|
|
|
|
|
class WanRMSNorm(nn.Module):
|
|
"""Root Mean Square Layer Normalization."""
|
|
|
|
def __init__(self, dim: int, eps: float = 1e-5):
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.eps = eps
|
|
self.weight = mx.ones((dim,))
|
|
|
|
def __call__(self, x: mx.array) -> mx.array:
|
|
# RMS normalization
|
|
variance = mx.mean(mx.square(x.astype(mx.float32)), axis=-1, keepdims=True)
|
|
x_normed = x * mx.rsqrt(variance + self.eps)
|
|
return (x_normed * self.weight).astype(x.dtype)
|
|
|
|
|
|
class WanLayerNorm(nn.Module):
|
|
"""Layer normalization."""
|
|
|
|
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.eps = eps
|
|
self.elementwise_affine = elementwise_affine
|
|
|
|
if elementwise_affine:
|
|
self.weight = mx.ones((dim,))
|
|
self.bias = mx.zeros((dim,))
|
|
|
|
def __call__(self, x: mx.array) -> mx.array:
|
|
# Standard layer normalization
|
|
x_float = x.astype(mx.float32)
|
|
mean = mx.mean(x_float, axis=-1, keepdims=True)
|
|
variance = mx.var(x_float, axis=-1, keepdims=True)
|
|
x_normed = (x_float - mean) * mx.rsqrt(variance + self.eps)
|
|
|
|
if self.elementwise_affine:
|
|
x_normed = x_normed * self.weight + self.bias
|
|
|
|
return x_normed.astype(x.dtype)
|
|
|
|
|
|
def mlx_attention(
|
|
q: mx.array,
|
|
k: mx.array,
|
|
v: mx.array,
|
|
q_lens: Optional[mx.array] = None,
|
|
k_lens: Optional[mx.array] = None,
|
|
dropout_p: float = 0.,
|
|
softmax_scale: Optional[float] = None,
|
|
q_scale: Optional[float] = None,
|
|
causal: bool = False,
|
|
window_size: Tuple[int, int] = (-1, -1),
|
|
deterministic: bool = False,
|
|
dtype: Optional[type] = None,
|
|
) -> mx.array:
|
|
"""
|
|
MLX implementation of scaled dot-product attention.
|
|
"""
|
|
# Get shapes
|
|
b, lq, n, d = q.shape
|
|
_, lk, _, _ = k.shape
|
|
|
|
# Scale queries if needed
|
|
if q_scale is not None:
|
|
q = q * q_scale
|
|
|
|
# Compute attention scores
|
|
q = q.transpose(0, 2, 1, 3) # [b, n, lq, d]
|
|
k = k.transpose(0, 2, 1, 3) # [b, n, lk, d]
|
|
v = v.transpose(0, 2, 1, 3) # [b, n, lk, d]
|
|
|
|
# Compute attention scores
|
|
scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) # [b, n, lq, lk]
|
|
|
|
# Apply softmax scale if provided
|
|
if softmax_scale is not None:
|
|
scores = scores * softmax_scale
|
|
else:
|
|
# Default scaling by sqrt(d)
|
|
scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))
|
|
|
|
# Create attention mask
|
|
attn_mask = None
|
|
|
|
# Apply window size masking if specified
|
|
if window_size != (-1, -1):
|
|
left_window, right_window = window_size
|
|
window_mask = mx.zeros((lq, lk))
|
|
for i in range(lq):
|
|
start = max(0, i - left_window)
|
|
end = min(lk, i + right_window + 1)
|
|
window_mask[i, start:end] = 1
|
|
attn_mask = window_mask
|
|
|
|
# Apply causal masking if needed
|
|
if causal:
|
|
causal_mask = mx.tril(mx.ones((lq, lk)), k=0)
|
|
if attn_mask is None:
|
|
attn_mask = causal_mask
|
|
else:
|
|
attn_mask = mx.logical_and(attn_mask, causal_mask)
|
|
|
|
# Apply attention mask if present
|
|
if attn_mask is not None:
|
|
attn_mask = attn_mask.astype(scores.dtype)
|
|
scores = scores * attn_mask + (1 - attn_mask) * -1e4
|
|
|
|
# Apply attention mask if lengths are provided
|
|
if q_lens is not None or k_lens is not None:
|
|
if q_lens is not None:
|
|
mask = mx.arange(lq)[None, :] < q_lens[:, None]
|
|
mask = mask.astype(scores.dtype)
|
|
scores = scores * mask[:, None, :, None] + (1 - mask[:, None, :, None]) * -1e4
|
|
if k_lens is not None:
|
|
mask = mx.arange(lk)[None, :] < k_lens[:, None]
|
|
mask = mask.astype(scores.dtype)
|
|
scores = scores * mask[:, None, None, :] + (1 - mask[:, None, None, :]) * -1e4
|
|
|
|
# Apply softmax
|
|
max_scores = mx.max(scores, axis=-1, keepdims=True)
|
|
scores = scores - max_scores
|
|
exp_scores = mx.exp(scores)
|
|
sum_exp = mx.sum(exp_scores, axis=-1, keepdims=True)
|
|
attn = exp_scores / (sum_exp + 1e-6)
|
|
|
|
# Apply dropout if needed
|
|
if dropout_p > 0 and not deterministic:
|
|
raise NotImplementedError("Dropout not implemented in MLX version")
|
|
|
|
# Compute output
|
|
out = mx.matmul(attn, v) # [b, n, lq, d]
|
|
out = out.transpose(0, 2, 1, 3) # [b, lq, n, d]
|
|
|
|
return out
|
|
|
|
|
|
class WanSelfAttention(nn.Module):
|
|
"""Self-attention module with RoPE and optional QK normalization."""
|
|
|
|
def __init__(self, dim: int, num_heads: int, window_size: Tuple[int, int] = (-1, -1),
|
|
qk_norm: bool = True, eps: float = 1e-6):
|
|
super().__init__()
|
|
assert dim % num_heads == 0
|
|
|
|
self.dim = dim
|
|
self.num_heads = num_heads
|
|
self.head_dim = dim // num_heads
|
|
self.window_size = window_size
|
|
self.qk_norm = qk_norm
|
|
self.eps = eps
|
|
|
|
# Linear projections
|
|
self.q = nn.Linear(dim, dim)
|
|
self.k = nn.Linear(dim, dim)
|
|
self.v = nn.Linear(dim, dim)
|
|
self.o = nn.Linear(dim, dim)
|
|
|
|
# Normalization layers
|
|
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
|
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
|
|
|
def __call__(self, x: mx.array, seq_lens: mx.array, grid_sizes: mx.array,
|
|
freqs: mx.array) -> mx.array:
|
|
b, s = x.shape[0], x.shape[1]
|
|
|
|
# Compute Q, K, V
|
|
q = self.q(x)
|
|
k = self.k(x)
|
|
v = self.v(x)
|
|
|
|
if self.qk_norm:
|
|
q = self.norm_q(q)
|
|
k = self.norm_k(k)
|
|
|
|
# Reshape for multi-head attention
|
|
q = q.reshape(b, s, self.num_heads, self.head_dim)
|
|
k = k.reshape(b, s, self.num_heads, self.head_dim)
|
|
v = v.reshape(b, s, self.num_heads, self.head_dim)
|
|
|
|
# Apply RoPE
|
|
q = rope_apply(q, grid_sizes, freqs)
|
|
k = rope_apply(k, grid_sizes, freqs)
|
|
|
|
# Apply attention
|
|
x = mlx_attention(q, k, v, k_lens=seq_lens, window_size=self.window_size)
|
|
|
|
# Reshape and project output
|
|
x = x.reshape(b, s, self.dim)
|
|
x = self.o(x)
|
|
|
|
return x
|
|
|
|
|
|
class WanT2VCrossAttention(WanSelfAttention):
|
|
"""Text-to-video cross attention."""
|
|
|
|
def __call__(self, x: mx.array, context: mx.array, context_lens: mx.array) -> mx.array:
|
|
b = x.shape[0]
|
|
|
|
# Compute queries from x
|
|
q = self.q(x)
|
|
if self.qk_norm:
|
|
q = self.norm_q(q)
|
|
q = q.reshape(b, -1, self.num_heads, self.head_dim)
|
|
|
|
# Compute keys and values from context
|
|
k = self.k(context)
|
|
v = self.v(context)
|
|
if self.qk_norm:
|
|
k = self.norm_k(k)
|
|
k = k.reshape(b, -1, self.num_heads, self.head_dim)
|
|
v = v.reshape(b, -1, self.num_heads, self.head_dim)
|
|
|
|
# Apply attention
|
|
x = mlx_attention(q, k, v, k_lens=context_lens)
|
|
|
|
# Reshape and project output
|
|
x = x.reshape(b, -1, self.dim)
|
|
x = self.o(x)
|
|
|
|
return x
|
|
|
|
|
|
class WanI2VCrossAttention(WanSelfAttention):
|
|
"""Image-to-video cross attention."""
|
|
|
|
def __init__(self, dim: int, num_heads: int, window_size: Tuple[int, int] = (-1, -1),
|
|
qk_norm: bool = True, eps: float = 1e-6):
|
|
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
|
|
|
self.k_img = nn.Linear(dim, dim)
|
|
self.v_img = nn.Linear(dim, dim)
|
|
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
|
|
|
def __call__(self, x: mx.array, context: mx.array, context_lens: mx.array) -> mx.array:
|
|
# Split context into image and text parts
|
|
context_img = context[:, :257]
|
|
context = context[:, 257:]
|
|
|
|
b = x.shape[0]
|
|
|
|
# Compute queries
|
|
q = self.q(x)
|
|
if self.qk_norm:
|
|
q = self.norm_q(q)
|
|
q = q.reshape(b, -1, self.num_heads, self.head_dim)
|
|
|
|
# Compute keys and values for text
|
|
k = self.k(context)
|
|
v = self.v(context)
|
|
if self.qk_norm:
|
|
k = self.norm_k(k)
|
|
k = k.reshape(b, -1, self.num_heads, self.head_dim)
|
|
v = v.reshape(b, -1, self.num_heads, self.head_dim)
|
|
|
|
# Compute keys and values for image
|
|
k_img = self.k_img(context_img)
|
|
v_img = self.v_img(context_img)
|
|
if self.qk_norm:
|
|
k_img = self.norm_k_img(k_img)
|
|
k_img = k_img.reshape(b, -1, self.num_heads, self.head_dim)
|
|
v_img = v_img.reshape(b, -1, self.num_heads, self.head_dim)
|
|
|
|
# Apply attention
|
|
img_x = mlx_attention(q, k_img, v_img, k_lens=None)
|
|
x = mlx_attention(q, k, v, k_lens=context_lens)
|
|
|
|
# Combine and project
|
|
img_x = img_x.reshape(b, -1, self.dim)
|
|
x = x.reshape(b, -1, self.dim)
|
|
x = x + img_x
|
|
x = self.o(x)
|
|
|
|
return x
|
|
|
|
|
|
WAN_CROSSATTENTION_CLASSES = {
|
|
't2v_cross_attn': WanT2VCrossAttention,
|
|
'i2v_cross_attn': WanI2VCrossAttention,
|
|
}
|
|
|
|
|
|
class WanAttentionBlock(nn.Module):
|
|
"""Transformer block with self-attention, cross-attention, and FFN."""
|
|
|
|
def __init__(self, cross_attn_type: str, dim: int, ffn_dim: int, num_heads: int,
|
|
window_size: Tuple[int, int] = (-1, -1), qk_norm: bool = True,
|
|
cross_attn_norm: bool = False, eps: float = 1e-6):
|
|
super().__init__()
|
|
|
|
self.dim = dim
|
|
self.ffn_dim = ffn_dim
|
|
self.num_heads = num_heads
|
|
self.window_size = window_size
|
|
self.qk_norm = qk_norm
|
|
self.cross_attn_norm = cross_attn_norm
|
|
self.eps = eps
|
|
|
|
# Layers
|
|
self.norm1 = WanLayerNorm(dim, eps)
|
|
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
|
|
|
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
|
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
|
|
dim, num_heads, (-1, -1), qk_norm, eps)
|
|
|
|
self.norm2 = WanLayerNorm(dim, eps)
|
|
|
|
# FFN - use a list instead of Sequential to match PyTorch exactly!
|
|
self.ffn = [
|
|
nn.Linear(dim, ffn_dim),
|
|
nn.GELU(),
|
|
nn.Linear(ffn_dim, dim)
|
|
]
|
|
|
|
# Modulation parameters
|
|
self.modulation = mx.random.normal((1, 6, dim)) / math.sqrt(dim)
|
|
|
|
def __call__(self, x: mx.array, e: mx.array, seq_lens: mx.array,
|
|
grid_sizes: mx.array, freqs: mx.array, context: mx.array,
|
|
context_lens: Optional[mx.array]) -> mx.array:
|
|
# Apply modulation
|
|
e = (self.modulation + e).astype(mx.float32)
|
|
e_chunks = [mx.squeeze(chunk, axis=1) for chunk in mx.split(e, 6, axis=1)]
|
|
|
|
# Self-attention with modulation
|
|
y = self.norm1(x).astype(mx.float32)
|
|
y = y * (1 + e_chunks[1]) + e_chunks[0]
|
|
y = self.self_attn(y, seq_lens, grid_sizes, freqs)
|
|
x = x + y * e_chunks[2]
|
|
|
|
# Cross-attention
|
|
if self.cross_attn_norm and isinstance(self.norm3, WanLayerNorm):
|
|
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
|
else:
|
|
x = x + self.cross_attn(x, context, context_lens)
|
|
|
|
# FFN with modulation
|
|
y = self.norm2(x).astype(mx.float32)
|
|
y = y * (1 + e_chunks[4]) + e_chunks[3]
|
|
|
|
# Apply FFN layers manually
|
|
y = self.ffn[0](y) # Linear
|
|
y = self.ffn[1](y) # GELU
|
|
y = self.ffn[2](y) # Linear
|
|
|
|
x = x + y * e_chunks[5]
|
|
|
|
return x
|
|
|
|
|
|
class Head(nn.Module):
|
|
"""Output head for final projection."""
|
|
|
|
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int],
|
|
eps: float = 1e-6):
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.out_dim = out_dim
|
|
self.patch_size = patch_size
|
|
self.eps = eps
|
|
|
|
# Output projection
|
|
out_features = int(np.prod(patch_size)) * out_dim
|
|
self.norm = WanLayerNorm(dim, eps)
|
|
self.head = nn.Linear(dim, out_features)
|
|
|
|
# Modulation
|
|
self.modulation = mx.random.normal((1, 2, dim)) / math.sqrt(dim)
|
|
|
|
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
|
# Apply modulation
|
|
e = (self.modulation + mx.expand_dims(e, 1)).astype(mx.float32)
|
|
e_chunks = mx.split(e, 2, axis=1)
|
|
|
|
# Apply normalization and projection with modulation
|
|
x = self.norm(x) * (1 + e_chunks[1]) + e_chunks[0]
|
|
x = self.head(x)
|
|
|
|
return x
|
|
|
|
|
|
class MLPProj(nn.Module):
|
|
"""MLP projection for image embeddings."""
|
|
|
|
def __init__(self, in_dim: int, out_dim: int):
|
|
super().__init__()
|
|
|
|
# Use a list to match PyTorch Sequential indexing
|
|
self.proj = [
|
|
nn.LayerNorm(in_dim),
|
|
nn.Linear(in_dim, in_dim),
|
|
nn.GELU(),
|
|
nn.Linear(in_dim, out_dim),
|
|
nn.LayerNorm(out_dim)
|
|
]
|
|
|
|
def __call__(self, image_embeds: mx.array) -> mx.array:
|
|
x = image_embeds
|
|
for layer in self.proj:
|
|
x = layer(x)
|
|
return x
|
|
|
|
|
|
class WanModel(nn.Module):
|
|
"""
|
|
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
|
MLX implementation - True 1:1 port from PyTorch.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model_type: str = 't2v',
|
|
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
|
text_len: int = 512,
|
|
in_dim: int = 16,
|
|
dim: int = 2048,
|
|
ffn_dim: int = 8192,
|
|
freq_dim: int = 256,
|
|
text_dim: int = 4096,
|
|
out_dim: int = 16,
|
|
num_heads: int = 16,
|
|
num_layers: int = 32,
|
|
window_size: Tuple[int, int] = (-1, -1),
|
|
qk_norm: bool = True,
|
|
cross_attn_norm: bool = True,
|
|
eps: float = 1e-6
|
|
):
|
|
super().__init__()
|
|
|
|
assert model_type in ['t2v', 'i2v']
|
|
self.model_type = model_type
|
|
|
|
# Store configuration
|
|
self.patch_size = patch_size
|
|
self.text_len = text_len
|
|
self.in_dim = in_dim
|
|
self.dim = dim
|
|
self.ffn_dim = ffn_dim
|
|
self.freq_dim = freq_dim
|
|
self.text_dim = text_dim
|
|
self.out_dim = out_dim
|
|
self.num_heads = num_heads
|
|
self.num_layers = num_layers
|
|
self.window_size = window_size
|
|
self.qk_norm = qk_norm
|
|
self.cross_attn_norm = cross_attn_norm
|
|
self.eps = eps
|
|
|
|
# Embeddings
|
|
self.patch_embedding = nn.Conv3d(
|
|
in_dim, dim,
|
|
kernel_size=patch_size,
|
|
stride=patch_size
|
|
)
|
|
|
|
# Use lists instead of Sequential to match PyTorch!
|
|
self.text_embedding = [
|
|
nn.Linear(text_dim, dim),
|
|
nn.GELU(),
|
|
nn.Linear(dim, dim)
|
|
]
|
|
|
|
self.time_embedding = [
|
|
nn.Linear(freq_dim, dim),
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim)
|
|
]
|
|
|
|
self.time_projection = [
|
|
nn.SiLU(),
|
|
nn.Linear(dim, dim * 6)
|
|
]
|
|
|
|
# Transformer blocks
|
|
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
|
self.blocks = [
|
|
WanAttentionBlock(
|
|
cross_attn_type, dim, ffn_dim, num_heads,
|
|
window_size, qk_norm, cross_attn_norm, eps
|
|
)
|
|
for _ in range(num_layers)
|
|
]
|
|
|
|
# Output head
|
|
self.head = Head(dim, out_dim, patch_size, eps)
|
|
|
|
# Precompute RoPE frequencies
|
|
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
|
d = dim // num_heads
|
|
self.freqs = mx.concatenate([
|
|
rope_params(1024, d - 4 * (d // 6)),
|
|
rope_params(1024, 2 * (d // 6)),
|
|
rope_params(1024, 2 * (d // 6))
|
|
], axis=1)
|
|
|
|
# Image embedding for i2v
|
|
if model_type == 'i2v':
|
|
self.img_emb = MLPProj(1280, dim)
|
|
|
|
# Initialize weights
|
|
self.init_weights()
|
|
|
|
def __call__(
|
|
self,
|
|
x: List[mx.array],
|
|
t: mx.array,
|
|
context: List[mx.array],
|
|
seq_len: int,
|
|
clip_fea: Optional[mx.array] = None,
|
|
y: Optional[List[mx.array]] = None
|
|
) -> List[mx.array]:
|
|
"""
|
|
Forward pass through the diffusion model.
|
|
|
|
Args:
|
|
x: List of input video tensors [C_in, F, H, W]
|
|
t: Diffusion timesteps [B]
|
|
context: List of text embeddings [L, C]
|
|
seq_len: Maximum sequence length
|
|
clip_fea: CLIP image features for i2v mode
|
|
y: Conditional video inputs for i2v mode
|
|
|
|
Returns:
|
|
List of denoised video tensors [C_out, F, H/8, W/8]
|
|
"""
|
|
if self.model_type == 'i2v':
|
|
assert clip_fea is not None and y is not None
|
|
|
|
# Concatenate conditional inputs if provided
|
|
if y is not None:
|
|
x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]
|
|
|
|
# Patch embedding
|
|
x = [mx.transpose(mx.expand_dims(u, 0), (0, 2, 3, 4, 1)) for u in x]
|
|
x = [self.patch_embedding(u) for u in x]
|
|
# Transpose back from MLX format (N, D, H, W, C) to (N, C, D, H, W) for the rest of the model
|
|
x = [mx.transpose(u, (0, 4, 1, 2, 3)) for u in x]
|
|
grid_sizes = mx.array([[u.shape[2], u.shape[3], u.shape[4]] for u in x])
|
|
|
|
# Flatten spatial dimensions
|
|
x = [mx.transpose(u.reshape(u.shape[0], u.shape[1], -1), (0, 2, 1)) for u in x]
|
|
seq_lens = mx.array([u.shape[1] for u in x])
|
|
|
|
# Pad sequences to max length
|
|
x_padded = []
|
|
for u in x:
|
|
if u.shape[1] < seq_len:
|
|
padding = mx.zeros((1, seq_len - u.shape[1], u.shape[2]))
|
|
u = mx.concatenate([u, padding], axis=1)
|
|
x_padded.append(u)
|
|
x = mx.concatenate(x_padded, axis=0)
|
|
|
|
# Time embeddings - apply layers manually
|
|
e = sinusoidal_embedding_1d(self.freq_dim, t).astype(mx.float32)
|
|
e = self.time_embedding[0](e) # Linear
|
|
e = self.time_embedding[1](e) # SiLU
|
|
e = self.time_embedding[2](e) # Linear
|
|
|
|
# Time projection
|
|
e = self.time_projection[0](e) # SiLU
|
|
e0 = self.time_projection[1](e).reshape(-1, 6, self.dim) # Linear
|
|
|
|
# Process context
|
|
context_lens = None
|
|
context_padded = []
|
|
for u in context:
|
|
if u.shape[0] < self.text_len:
|
|
padding = mx.zeros((self.text_len - u.shape[0], u.shape[1]))
|
|
u = mx.concatenate([u, padding], axis=0)
|
|
context_padded.append(u)
|
|
context = mx.stack(context_padded)
|
|
|
|
# Apply text embedding layers manually
|
|
context = self.text_embedding[0](context) # Linear
|
|
context = self.text_embedding[1](context) # GELU
|
|
context = self.text_embedding[2](context) # Linear
|
|
|
|
# Add image embeddings for i2v
|
|
if clip_fea is not None:
|
|
context_clip = self.img_emb(clip_fea)
|
|
context = mx.concatenate([context_clip, context], axis=1)
|
|
|
|
# Apply transformer blocks
|
|
for block in self.blocks:
|
|
x = block(
|
|
x, e0, seq_lens, grid_sizes, self.freqs,
|
|
context, context_lens
|
|
)
|
|
|
|
# Apply output head
|
|
x = self.head(x, e)
|
|
|
|
# Unpatchify
|
|
x = self.unpatchify(x, grid_sizes)
|
|
|
|
return [u.astype(mx.float32) for u in x]
|
|
|
|
def unpatchify(self, x: mx.array, grid_sizes: mx.array) -> List[mx.array]:
|
|
"""Reconstruct video tensors from patch embeddings."""
|
|
c = self.out_dim
|
|
out = []
|
|
|
|
for i in range(grid_sizes.shape[0]):
|
|
f, h, w = int(grid_sizes[i, 0]), int(grid_sizes[i, 1]), int(grid_sizes[i, 2])
|
|
seq_len = f * h * w
|
|
|
|
# Extract relevant sequence
|
|
u = x[i, :seq_len]
|
|
|
|
# Reshape to grid with patches
|
|
pf, ph, pw = self.patch_size
|
|
u = u.reshape(f, h, w, pf, ph, pw, c)
|
|
|
|
# Rearrange dimensions
|
|
u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))
|
|
|
|
# Combine patches
|
|
u = u.reshape(c, f * pf, h * ph, w * pw)
|
|
|
|
out.append(u)
|
|
|
|
return out
|
|
|
|
def init_weights(self):
|
|
"""Initialize model parameters using Xavier/He initialization."""
|
|
# Note: MLX doesn't have nn.init like PyTorch, so we manually initialize
|
|
|
|
# Helper function for Xavier uniform initialization
|
|
def xavier_uniform(shape):
|
|
bound = mx.sqrt(6.0 / (shape[0] + shape[1]))
|
|
return mx.random.uniform(low=-bound, high=bound, shape=shape)
|
|
|
|
# Initialize linear layers in blocks
|
|
for block in self.blocks:
|
|
# Self attention
|
|
block.self_attn.q.weight = xavier_uniform(block.self_attn.q.weight.shape)
|
|
block.self_attn.k.weight = xavier_uniform(block.self_attn.k.weight.shape)
|
|
block.self_attn.v.weight = xavier_uniform(block.self_attn.v.weight.shape)
|
|
block.self_attn.o.weight = xavier_uniform(block.self_attn.o.weight.shape)
|
|
|
|
# Cross attention
|
|
block.cross_attn.q.weight = xavier_uniform(block.cross_attn.q.weight.shape)
|
|
block.cross_attn.k.weight = xavier_uniform(block.cross_attn.k.weight.shape)
|
|
block.cross_attn.v.weight = xavier_uniform(block.cross_attn.v.weight.shape)
|
|
block.cross_attn.o.weight = xavier_uniform(block.cross_attn.o.weight.shape)
|
|
|
|
# FFN layers - now it's a list!
|
|
block.ffn[0].weight = xavier_uniform(block.ffn[0].weight.shape)
|
|
block.ffn[2].weight = xavier_uniform(block.ffn[2].weight.shape)
|
|
|
|
# Modulation
|
|
block.modulation = mx.random.normal(
|
|
shape=(1, 6, self.dim),
|
|
scale=1.0 / math.sqrt(self.dim)
|
|
)
|
|
|
|
# Special initialization for embeddings
|
|
# Patch embedding - Xavier uniform
|
|
weight_shape = self.patch_embedding.weight.shape
|
|
fan_in = weight_shape[1] * np.prod(self.patch_size)
|
|
fan_out = weight_shape[0]
|
|
bound = mx.sqrt(6.0 / (fan_in + fan_out))
|
|
self.patch_embedding.weight = mx.random.uniform(
|
|
low=-bound,
|
|
high=bound,
|
|
shape=weight_shape
|
|
)
|
|
|
|
# Text embedding - normal distribution with std=0.02
|
|
self.text_embedding[0].weight = mx.random.normal(shape=self.text_embedding[0].weight.shape, scale=0.02)
|
|
self.text_embedding[2].weight = mx.random.normal(shape=self.text_embedding[2].weight.shape, scale=0.02)
|
|
|
|
# Time embedding - normal distribution with std=0.02
|
|
self.time_embedding[0].weight = mx.random.normal(shape=self.time_embedding[0].weight.shape, scale=0.02)
|
|
self.time_embedding[2].weight = mx.random.normal(shape=self.time_embedding[2].weight.shape, scale=0.02)
|
|
|
|
# Output head - initialize to zeros
|
|
self.head.head.weight = mx.zeros(self.head.head.weight.shape)
|
|
|
|
# Head modulation
|
|
self.head.modulation = mx.random.normal(
|
|
shape=(1, 2, self.dim),
|
|
scale=1.0 / math.sqrt(self.dim)
|
|
)
|
|
|
|
# Initialize i2v specific layers if present
|
|
if self.model_type == 'i2v':
|
|
for i in [1, 3]: # Linear layers in the proj list
|
|
if isinstance(self.img_emb.proj[i], nn.Linear):
|
|
self.img_emb.proj[i].weight = xavier_uniform(self.img_emb.proj[i].weight.shape) |