mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
660 lines
22 KiB
Python
660 lines
22 KiB
Python
![]() |
# MLX implementation of model.py
|
||
|
import math
|
||
|
from typing import List, Tuple, Optional
|
||
|
|
||
|
import mlx.core as mx
|
||
|
import mlx.nn as nn
|
||
|
import numpy as np
|
||
|
|
||
|
# Public API of this module: only the top-level diffusion backbone is exported;
# the attention/norm/RoPE helpers below are internal building blocks.
__all__ = ['WanModel']
|
||
|
|
||
|
|
||
|
def sinusoidal_embedding_1d(dim, position):
    """Build [cos | sin] sinusoidal embeddings for a 1-D array of positions.

    Args:
        dim (int): Output width; must be even. The first ``dim // 2`` columns
            hold cosines, the last ``dim // 2`` hold sines.
        position (mx.array): 1-D array of positions (e.g. diffusion
            timesteps); it is cast to float32 first.

    Returns:
        mx.array: Shape [len(position), dim]. Column ``j`` of each half uses
        angle ``position * 10000 ** (-j / (dim // 2))``.
    """
    assert dim % 2 == 0
    half_dim = dim // 2

    pos = position.astype(mx.float32)
    inv_freq = mx.power(10000, -mx.arange(half_dim).astype(mx.float32) / half_dim)

    # Outer product of positions and inverse frequencies -> [N, half_dim].
    angles = pos[:, None] @ inv_freq[None, :]
    return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=1)
|
||
|
|
||
|
|
||
|
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute a rotary-embedding table of (cos, sin) pairs.

    Args:
        max_seq_len (int): Number of positions (rows) in the table.
        dim (int): Rotary width; must be even. The table has ``dim // 2``
            frequency columns.
        theta (float): Frequency base.

    Returns:
        mx.array: Shape [max_seq_len, dim // 2, 2]; entry ``[p, j]`` is
        ``(cos(p * w_j), sin(p * w_j))`` with ``w_j = theta ** (-2j / dim)``.
        The trailing pair plays the role of (real, imag) of a unit complex
        number.
    """
    assert dim % 2 == 0

    inv_freq = 1.0 / mx.power(theta, mx.arange(0, dim, 2).astype(mx.float32) / dim)
    pos = mx.arange(max_seq_len).astype(mx.float32)

    # [max_seq_len, 1] @ [1, dim // 2] -> per-position rotation angles.
    angles = pos[:, None] @ inv_freq[None, :]
    return mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1)
|
||
|
|
||
|
|
||
|
def rope_apply(x, grid_sizes, freqs):
    """Apply factorized 3-D (frame, height, width) rotary embeddings to ``x``.

    Args:
        x (mx.array): Shape [B, L, N, D]; ``x.shape[2]`` is read as the head
            count and each head's ``D`` channels are treated as ``D // 2``
            (real, imag) pairs.
        grid_sizes (mx.array): Shape [B, 3]; row ``b`` holds the (F, H, W)
            token grid of sample ``b``.
        freqs (mx.array): Table of (cos, sin) pairs with at least ``D // 2``
            columns, e.g. the concatenation of ``rope_params`` outputs. Its
            columns are split into frame / height / width groups of sizes
            ``c - 2 * (c // 3)``, ``c // 3``, ``c // 3`` with ``c = D // 2``.

    Returns:
        mx.array: Shape [B, L, N, D]. The first F*H*W tokens of each sample
        are rotated; any remaining (padding) tokens are copied through.
    """
    num_heads, half_dim = x.shape[2], x.shape[3] // 2

    # Column ranges of `freqs` owned by the frame, height and width axes.
    sizes = [half_dim - 2 * (half_dim // 3), half_dim // 3, half_dim // 3]
    bounds = [0, sizes[0], sizes[0] + sizes[1], sizes[0] + sizes[1] + sizes[2]]
    freq_f, freq_h, freq_w = (
        freqs[:, lo:hi, :] for lo, hi in zip(bounds[:-1], bounds[1:]))

    rotated = []
    for sample, (frames, height, width) in enumerate(grid_sizes.tolist()):
        n_tok = frames * height * width

        # Give each axis' rows singleton extents on the other two axes,
        # broadcast over the whole (F, H, W) grid and lay them side by side.
        f_part = freq_f[:frames].reshape(frames, 1, 1, -1, 2)
        h_part = freq_h[:height].reshape(1, height, 1, -1, 2)
        w_part = freq_w[:width].reshape(1, 1, width, -1, 2)
        rot = mx.concatenate(
            [mx.broadcast_to(part, (frames, height, width, part.shape[3], 2))
             for part in (f_part, h_part, w_part)],
            axis=3,
        ).reshape(n_tok, 1, half_dim, 2)

        # Complex multiply (x_re + i x_im) * (cos + i sin), pairwise.
        pairs = x[sample, :n_tok].reshape(n_tok, num_heads, half_dim, 2)
        x_re, x_im = pairs[..., 0], pairs[..., 1]
        cos_t, sin_t = rot[..., 0], rot[..., 1]
        out = mx.stack(
            [x_re * cos_t - x_im * sin_t, x_re * sin_t + x_im * cos_t],
            axis=-1,
        ).reshape(n_tok, num_heads, -1)

        # Tokens past the grid (sequence padding) are left unrotated.
        if x.shape[1] > n_tok:
            out = mx.concatenate([out, x[sample, n_tok:]], axis=0)

        rotated.append(out)

    return mx.stack(rotated)
|
||
|
|
||
|
|
||
|
class WanRMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    Args:
        dim (int): Size of the normalized (last) axis and of ``weight``.
        eps (float): Added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x):
        """
        Args:
            x(Array): Shape [B, L, C]

        Returns:
            Array of the same shape. The normalization itself is computed in
            float32 and cast back to ``x.dtype`` before scaling by ``weight``.
        """
        # Squaring and averaging thousands of half-precision channels in
        # bf16/fp16 loses precision, so reduce in float32 and cast back.
        return self._norm(x.astype(mx.float32)).astype(x.dtype) * self.weight

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps) along the last axis.
        return x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
|
||
|
|
||
|
|
||
|
class WanLayerNorm(nn.LayerNorm):
    """Layer normalization over the last axis, non-affine by default.

    Args:
        dim (int): Size of the normalized (last) axis.
        eps (float): Numerical-stability epsilon.
        affine (bool): Forwarded as ``affine`` to ``nn.LayerNorm`` (learned
            elementwise scale/shift when true).
    """

    def __init__(self, dim, eps=1e-6, affine=False):
        super().__init__(affine=affine, eps=eps, dims=dim)

    def __call__(self, x):
        """
        Args:
            x(Array): Shape [B, L, C]

        Returns:
            Array of the same shape, normalized over the last axis.
        """
        normalized = super().__call__(x)
        return normalized
|
||
|
|
||
|
|
||
|
def mlx_attention(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    q_lens: Optional[mx.array] = None,
    k_lens: Optional[mx.array] = None,
    dropout_p: float = 0.,
    softmax_scale: Optional[float] = None,
    q_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    deterministic: bool = False,
    dtype: Optional[type] = None,
) -> mx.array:
    """Dense (non-fused) multi-head attention with optional masking.

    Args:
        q: Queries, shape [B, Lq, N, D].
        k: Keys, shape [B, Lk, N, D].
        v: Values, shape [B, Lk, N, D].
        q_lens: Optional [B] int array; logits of query rows at index
            ``>= q_lens[b]`` are replaced by the mask fill value.
        k_lens: Optional [B] int array; keys at index ``>= k_lens[b]`` are
            masked out.
        dropout_p: Must be 0 unless ``deterministic`` is true; dropout is not
            implemented and raises ``NotImplementedError``.
        softmax_scale: Logit multiplier; defaults to ``1 / sqrt(D)``.
        q_scale: Optional multiplier applied to ``q`` before the dot product.
        causal: If true, query ``i`` may only attend to keys ``j <= i``.
        window_size: ``(left, right)`` local window; query ``i`` may attend
            to keys in ``[i - left, i + right]``. A negative value on either
            side means that side is unbounded, so ``(-1, -1)`` is global
            attention.
        deterministic: If true, dropout is skipped.
        dtype: Accepted for API compatibility; unused.

    Returns:
        Array of shape [B, Lq, N, D].

    Masked logits are set to -1e4 rather than -inf, so a fully masked row
    yields a uniform distribution instead of NaN.
    """
    if dropout_p > 0 and not deterministic:
        raise NotImplementedError("Dropout not implemented in MLX version")

    lq, d = q.shape[1], q.shape[3]
    lk = k.shape[1]

    if q_scale is not None:
        q = q * q_scale

    # [B, L, N, D] -> [B, N, L, D]
    q = q.transpose(0, 2, 1, 3)
    k = k.transpose(0, 2, 1, 3)
    v = v.transpose(0, 2, 1, 3)

    scores = mx.matmul(q, k.transpose(0, 1, 3, 2))  # [B, N, Lq, Lk]
    if softmax_scale is not None:
        scores = scores * softmax_scale
    else:
        # Default scaling by sqrt(d)
        scores = scores / mx.sqrt(mx.array(d, dtype=scores.dtype))

    # Boolean keep-mask broadcastable to [B, N, Lq, Lk]; True = may attend.
    keep = None

    def _combine(mask, extra):
        return extra if mask is None else mx.logical_and(mask, extra)

    left_window, right_window = window_size
    if causal or left_window >= 0 or right_window >= 0:
        # Vectorized position grids instead of a per-row Python loop.
        q_pos = mx.arange(lq)[:, None]
        k_pos = mx.arange(lk)[None, :]
        if left_window >= 0:
            keep = _combine(keep, k_pos >= q_pos - left_window)
        if right_window >= 0:
            keep = _combine(keep, k_pos <= q_pos + right_window)
        if causal:
            keep = _combine(keep, k_pos <= q_pos)

    if q_lens is not None:
        q_valid = mx.arange(lq)[None, :] < q_lens[:, None]  # [B, Lq]
        keep = _combine(keep, q_valid[:, None, :, None])
    if k_lens is not None:
        k_valid = mx.arange(lk)[None, :] < k_lens[:, None]  # [B, Lk]
        keep = _combine(keep, k_valid[:, None, None, :])

    if keep is not None:
        scores = mx.where(keep, scores, -1e4)

    # mx.softmax is the numerically stable form (no ad-hoc epsilon needed).
    attn = mx.softmax(scores, axis=-1)

    out = mx.matmul(attn, v)  # [B, N, Lq, D]
    return out.transpose(0, 2, 1, 3)  # [B, Lq, N, D]
|
||
|
|
||
|
class WanSelfAttention(nn.Module):
    """Multi-head self-attention with optional RMS q/k norm and 3-D RoPE.

    Args:
        dim (int): Model width; must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        window_size (tuple): ``(left, right)`` window passed to
            ``mlx_attention``.
        qk_norm (bool): If true, ``WanRMSNorm`` is applied to the projected
            queries and keys (over the full width, before splitting heads).
        eps (float): Epsilon for the q/k norms.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # projections
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        # optional q/k normalization
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def __call__(self, x, seq_lens, grid_sizes, freqs):
        """
        Args:
            x(Array): Shape [B, L, C]
            seq_lens(Array): Shape [B]; keys past ``seq_lens[b]`` are masked.
            grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]

        Returns:
            Array of shape [B, L, C].
        """
        batch, length = x.shape[0], x.shape[1]
        split_heads = (batch, length, self.num_heads, self.head_dim)

        query = self.norm_q(self.q(x)).reshape(*split_heads)
        key = self.norm_k(self.k(x)).reshape(*split_heads)
        value = self.v(x).reshape(*split_heads)

        attended = mlx_attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size,
        )

        # merge heads and project back to model width
        return self.o(attended.reshape(batch, length, -1))
|
||
|
|
||
|
|
||
|
class WanCrossAttention(WanSelfAttention):
    """Cross-attention: queries from ``x``, keys and values from ``context``.

    Reuses the projections and q/k norms built by ``WanSelfAttention``; no
    rotary embedding and no window are applied here.
    """

    def __call__(self, x, context, context_lens):
        """
        Args:
            x(Array): Shape [B, L1, C]
            context(Array): Shape [B, L2, C]
            context_lens(Array): Shape [B], or None for no key masking

        Returns:
            Array of shape [B, L1, C].
        """
        batch = x.shape[0]
        heads, head_dim = self.num_heads, self.head_dim

        query = self.norm_q(self.q(x)).reshape(batch, -1, heads, head_dim)
        key = self.norm_k(self.k(context)).reshape(batch, -1, heads, head_dim)
        value = self.v(context).reshape(batch, -1, heads, head_dim)

        attended = mlx_attention(query, key, value, k_lens=context_lens)

        # merge heads and project back to model width
        return self.o(attended.reshape(batch, -1, self.dim))
|
||
|
|
||
|
|
||
|
class WanAttentionBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated FFN.

    Args:
        dim (int): Model width.
        ffn_dim (int): Hidden width of the feed-forward network.
        num_heads (int): Number of attention heads.
        window_size (tuple): Self-attention window, see ``mlx_attention``.
        qk_norm (bool): Enable RMS normalization of queries/keys.
        cross_attn_norm (bool): If true, apply an affine ``WanLayerNorm``
            before cross-attention; otherwise the input is passed unchanged.
        eps (float): Epsilon for the normalization layers.
    """

    def __init__(self,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers (creation order kept stable for reproducible random init)
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
                                            eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            # Wan's FFN uses the tanh-approximate GELU; in MLX that is
            # approx='precise' (the default 'none' is the exact erf form).
            nn.GELU(approx='precise'),
            nn.Linear(ffn_dim, dim))

        # modulation: six learned offsets added to the per-token vectors in e
        self.modulation = mx.random.normal((1, 6, dim)) / dim**0.5

    def __call__(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        """
        Args:
            x(Array): Shape [B, L, C]
            e(Array): Shape [B, L1, 6, C]
            seq_lens(Array): Shape [B], length of each sequence in batch
            grid_sizes(Array): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Array): Rope freqs, shape [1024, C / num_heads / 2, 2]
            context(Array): Text embeddings, shape [B, L2, C]
            context_lens(Array): Shape [B], or None

        Returns:
            Array of shape [B, L, C].
        """
        # Split into six [B, L, C] vectors, used as shift / scale / gate for
        # the self-attention branch and then for the FFN branch.
        (shift_sa, scale_sa, gate_sa,
         shift_ff, scale_ff, gate_ff) = (
            mx.squeeze(chunk, axis=2)
            for chunk in mx.split(self.modulation + e, 6, axis=2))

        # self-attention
        y = self.self_attn(
            self.norm1(x) * (1 + scale_sa) + shift_sa,
            seq_lens, grid_sizes, freqs)
        x = x + y * gate_sa

        # cross-attention & ffn
        x = x + self.cross_attn(self.norm3(x), context, context_lens)
        y = self.ffn(self.norm2(x) * (1 + scale_ff) + shift_ff)
        x = x + y * gate_ff

        return x
|
||
|
|
||
|
|
||
|
class Head(nn.Module):
    """Output projection with (shift, scale) modulation derived from ``e``.

    Args:
        dim (int): Input width.
        out_dim (int): Output channels per voxel; the linear layer emits
            ``out_dim * prod(patch_size)`` features per token.
        patch_size (tuple): Patch extents whose product multiplies ``out_dim``.
        eps (float): Epsilon of the non-affine layer norm.
    """

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        projected_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_dim)

        # modulation: two learned offsets (shift, scale) added to e
        self.modulation = mx.random.normal((1, 2, dim)) / dim**0.5

    def __call__(self, x, e):
        """
        Args:
            x(Array): Shape [B, L1, C]
            e(Array): Shape [B, L1, C]

        Returns:
            Array of shape [B, L1, out_dim * prod(patch_size)].
        """
        shift, scale = (
            mx.squeeze(part, axis=2)
            for part in mx.split(
                self.modulation + mx.expand_dims(e, axis=2), 2, axis=2))
        modulated = self.norm(x) * (1 + scale) + shift
        return self.head(modulated)
|
||
|
|
||
|
|
||
|
class WanModel(nn.Module):
    """
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        """
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video), 'i2v' (image-to-video)
                or 'ti2v'. 'i2v' requires the conditional input ``y``.
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Length to which every text embedding is zero-padded
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in), including any channels of ``y``
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """

        super().__init__()

        assert model_type in ['t2v', 'i2v', 'ti2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            # tanh-approximate GELU, as in Wan ('precise' in MLX terms)
            nn.GELU(approx='precise'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = [
            WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
                              cross_attn_norm, eps) for _ in range(num_layers)
        ]

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers: (cos, sin) RoPE table of shape [1024, d // 2, 2], with its
        # columns split between the frame, height and width axes.
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], axis=1)

        # initialize weights
        self.init_weights()

    def __call__(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
    ):
        """
        Forward pass through the diffusion model

        Args:
            x (List[Array]):
                List of input video latents, each with shape [C_in, F, H, W]
            t (Array):
                Diffusion timesteps, shape [B] (broadcast to every token) or
                [B, seq_len] (per-token)
            context (List[Array]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Token length every sample is zero-padded to; must be at least
                the number of patches of the largest sample
            y (List[Array], *optional*):
                Conditional video inputs for image-to-video mode, concatenated
                to ``x`` along the channel axis

        Returns:
            List[Array]:
                Denoised tensors of shape [C_out, F_p * p_t, H_p * p_h, W_p * p_w],
                i.e. the input latent size when it is divisible by patch_size
        """
        if self.model_type == 'i2v':
            assert y is not None

        if y is not None:
            x = [mx.concatenate([u, v], axis=0) for u, v in zip(x, y)]

        # embeddings: [C, F, H, W] -> channels-last [1, F, H, W, C] for Conv3d
        x = [self.patch_embedding(mx.expand_dims(mx.transpose(u, (1, 2, 3, 0)), axis=0)) for u in x]

        grid_sizes = mx.stack(
            [mx.array(u.shape[1:4], dtype=mx.int32) for u in x])

        x = [u.reshape(u.shape[0], -1, u.shape[-1]) for u in x]

        seq_lens = mx.array([u.shape[1] for u in x], dtype=mx.int32)
        assert max(u.shape[1] for u in x) <= seq_len

        # Pad sequences; zeros take u's dtype so half-precision batches are
        # not silently promoted to float32 by the concatenation.
        x_padded = []
        for u in x:
            pad_len = seq_len - u.shape[1]
            if pad_len > 0:
                padding = mx.zeros((u.shape[0], pad_len, u.shape[2]), dtype=u.dtype)
                u = mx.concatenate([u, padding], axis=1)
            x_padded.append(u)
        x = mx.concatenate(x_padded, axis=0)

        # time embeddings
        if t.ndim == 1:
            t = mx.broadcast_to(t[:, None], (t.shape[0], seq_len))
        bt = t.shape[0]
        t = t.flatten()
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).reshape(bt, seq_len, -1))
        e0 = self.time_projection(e).reshape(bt, seq_len, 6, self.dim)

        # context (zero-padded to text_len, again keeping the input dtype)
        context_lens = None
        context_padded = []
        for u in context:
            pad_len = self.text_len - u.shape[0]
            if pad_len > 0:
                padding = mx.zeros((pad_len, u.shape[1]), dtype=u.dtype)
                u = mx.concatenate([u, padding], axis=0)
            context_padded.append(u)
        context = self.text_embedding(mx.stack(context_padded))

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x

    def unpatchify(self, x, grid_sizes):
        """
        Reconstruct video tensors from patch embeddings.

        Args:
            x (Array):
                Batched patch features of shape [B, L, C_out * prod(patch_size)];
                only the first F_p * H_p * W_p tokens of each sample are used
            grid_sizes (Array):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Array]:
                Reconstructed tensors with shape
                [C_out, F_p * p_t, H_p * p_h, W_p * p_w]
        """
        c = self.out_dim
        out = []
        for i, v in enumerate(grid_sizes.tolist()):
            seq_len = math.prod(v)
            u = x[i, :seq_len].reshape(*v, *self.patch_size, c)
            # Rearrange dimensions: (f, h, w, p, q, r, c) -> (c, f*p, h*q, w*r)
            u = mx.transpose(u, (6, 0, 3, 1, 4, 2, 5))
            u = u.reshape(c, v[0] * self.patch_size[0],
                          v[1] * self.patch_size[1],
                          v[2] * self.patch_size[2])
            out.append(u)
        return out

    def init_weights(self):
        """
        Initialize embedding and output parameters.

        The patch embedding uses Xavier/Glorot uniform initialization (fan_in =
        C_in * prod(patch_size), fan_out = dim); the text and time embedding
        linears use N(0, 0.02) weights and zero bias; the output head is zeroed.
        """
        # Xavier uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
        fan_in = self.in_dim * math.prod(self.patch_size)
        fan_out = self.dim
        bound = math.sqrt(6.0 / (fan_in + fan_out))
        self.patch_embedding.weight = mx.random.uniform(
            low=-bound, high=bound, shape=self.patch_embedding.weight.shape)

        # Text then time embeddings: normal(0, 0.02) on the two Linear layers
        # (indices 0 and 2; index 1 is the activation).
        for seq in (self.text_embedding, self.time_embedding):
            layers = list(seq.layers)
            for i in (0, 2):
                layer = layers[i]
                layer.weight = mx.random.normal(shape=layer.weight.shape) * 0.02
                if hasattr(layer, 'bias') and layer.bias is not None:
                    layer.bias = mx.zeros(layer.bias.shape)

        # Initialize output head to zeros
        self.head.head.weight = mx.zeros(self.head.head.weight.shape)
        if hasattr(self.head.head, 'bias') and self.head.head.bias is not None:
            self.head.head.bias = mx.zeros(self.head.head.bias.shape)
|