mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00

Implement Wan2.2

video/Wan2.2/wan/modules/t5.py (new file, 616 lines)

@@ -0,0 +1,616 @@
# MLX implementation for t5.py
import logging
import math
from typing import Optional, Tuple, List

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten

from .tokenizers import HuggingfaceTokenizer

__all__ = [
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
]


def fp16_clamp(x):
    if x.dtype == mx.float16:
        # Use same clamping as PyTorch for consistency
        clamp = 65504.0  # max value for float16
        return mx.clip(x, -clamp, clamp)
    return x


class GELU(nn.Module):

    def __call__(self, x):
        return 0.5 * x * (1.0 + mx.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))


class T5LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x):
        # Match PyTorch's approach: convert to float32 for stability
        x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
        variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
        x_norm = x_float * mx.rsqrt(variance + self.eps)
        # Convert back to original dtype
        if x.dtype == mx.float16:
            x_norm = x_norm.astype(mx.float16)
        return self.weight * x_norm


class T5Attention(nn.Module):

    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
        assert dim_attn % num_heads == 0
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None.
        mask: [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, l1, _ = x.shape
        _, l2, _ = context.shape
        n, c = self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, l1, n, c)
        k = self.k(context).reshape(b, l2, n, c)
        v = self.v(context).reshape(b, l2, n, c)

        # transpose for attention: [B, N, L, C]
        q = mx.transpose(q, (0, 2, 1, 3))
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

        # compute attention (T5 does not use scaling)
        attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2)))  # [B, N, L1, L2]

        # add position bias if provided
        if pos_bias is not None:
            attn = attn + pos_bias

        # apply mask
        if mask is not None:
            if mask.ndim == 2:
                # [B, L2] -> [B, 1, 1, L2]
                mask = mask[:, None, None, :]
            elif mask.ndim == 3:
                # [B, L1, L2] -> [B, 1, L1, L2]
                mask = mask[:, None, :, :]
            # Use very negative value that works well with float16
            min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
            attn = mx.where(mask == 0, min_value, attn)

        # softmax and apply attention
        attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
        attn = self.dropout(attn)

        # apply attention to values
        x = mx.matmul(attn, v)  # [B, N, L1, C]

        # transpose back and reshape
        x = mx.transpose(x, (0, 2, 1, 3))  # [B, L1, N, C]
        x = x.reshape(b, l1, -1)

        # output projection
        x = self.o(x)
        x = self.dropout(x)
        return x


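# Shape sketch (illustrative, not part of the original implementation): a
# self-attention call with 8 heads over a length-16 sequence. `pos_bias` is
# expected to broadcast against the [B, N, L1, L2] attention scores.
#
#     attn = T5Attention(dim=512, dim_attn=512, num_heads=8)
#     x = mx.zeros((2, 16, 512))          # [B, L1, C]
#     bias = mx.zeros((1, 8, 16, 16))     # [1, N, L1, L2]
#     y = attn(x, pos_bias=bias)          # -> [2, 16, 512]

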
class T5FeedForward(nn.Module):

    def __init__(self, dim, dim_ffn, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
        self.gate_act = GELU()
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x):
        gate = self.gate_act(self.gate_proj(x))
        x = self.fc1(x) * gate
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def __call__(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.shape[1], x.shape[1])
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False)

    def __call__(self,
                 x,
                 mask=None,
                 encoder_states=None,
                 encoder_mask=None,
                 pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.shape[1], x.shape[1])
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def __call__(self, lq, lk):
        # Create relative position matrix
        positions_q = mx.arange(lq)[:, None]
        positions_k = mx.arange(lk)[None, :]
        rel_pos = positions_k - positions_q

        # Apply bucketing
        rel_pos = self._relative_position_bucket(rel_pos)

        # Get embeddings: [Lq, Lk, N]
        rel_pos_embeds = self.embedding(rel_pos)

        # Reshape to [1, N, Lq, Lk]
        rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
        rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)

        return rel_pos_embeds

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
            rel_pos = mx.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
            rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))

        # exact buckets for small distances, log-scale buckets for large ones
        max_exact = num_buckets // 2
        is_small = rel_pos < max_exact

        # For large positions, use log scale
        rel_pos_large = max_exact + (
            mx.log(rel_pos.astype(mx.float32) / max_exact) /
            math.log(self.max_dist / max_exact) *
            (num_buckets - max_exact)
        ).astype(mx.int32)
        rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)

        # Combine small and large position buckets
        rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
        return rel_buckets


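# Worked example of the bucketing above (illustrative): with num_buckets=32
# and bidirectional=True, each direction gets 16 buckets and max_exact=8.
# Distances 0-7 map directly to buckets 0-7, larger distances are placed
# logarithmically into buckets 8-15 (capped at 15), and positions where the
# key comes after the query are offset by +16 into buckets 16-31.

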
class T5Encoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)

        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = [
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.shape[1],
                               x.shape[1]) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)

        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = [
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                             shared_pos, dropout) for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.shape

        # causal mask
        if mask is None:
            mask = mx.tril(mx.ones((1, s, s)))
        elif mask.ndim == 2:
            # combine the padding mask with a causal mask: [B, S] -> [B, S, S]
            mask = mx.tril(mx.broadcast_to(mx.expand_dims(mask, 1), (b, s, s)))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.shape[1],
                               x.shape[1]) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):

    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, encoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, decoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


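# Shape sketch (illustrative, not part of the original file; the hyperparameters
# below are arbitrary small values): the full encoder-decoder forward pass
# returns per-token logits over the vocabulary.
#
#     model = T5Model(vocab_size=32128, dim=512, dim_attn=512, dim_ffn=1024,
#                     num_heads=8, encoder_layers=2, decoder_layers=2,
#                     num_buckets=32)
#     enc_ids = mx.zeros((1, 16), dtype=mx.int32)   # [B, L_enc]
#     enc_mask = mx.ones((1, 16), dtype=mx.int32)
#     dec_ids = mx.zeros((1, 8), dtype=mx.int32)    # [B, L_dec]
#     logits = model(enc_ids, enc_mask, dec_ids, None)   # -> [1, 8, 32128]

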
def init_mlx_weights(module, key):
    """Initialize weights for T5 model components to match PyTorch initialization."""

    def normal(key, shape, std=1.0):
        return mx.random.normal(shape=shape, key=key) * std

    if isinstance(module, T5LayerNorm):
        module.weight = mx.ones_like(module.weight)
    elif isinstance(module, nn.Embedding):
        key = mx.random.split(key, 1)[0]
        module.weight = normal(key, module.weight.shape, std=1.0)
    elif isinstance(module, T5FeedForward):
        # Match PyTorch initialization
        key1, key2, key3 = mx.random.split(key, 3)
        module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
                                         std=module.dim**-0.5)
        module.fc1.weight = normal(key2, module.fc1.weight.shape,
                                   std=module.dim**-0.5)
        module.fc2.weight = normal(key3, module.fc2.weight.shape,
                                   std=module.dim_ffn**-0.5)
    elif isinstance(module, T5Attention):
        # Match PyTorch initialization
        key1, key2, key3, key4 = mx.random.split(key, 4)
        module.q.weight = normal(key1, module.q.weight.shape,
                                 std=(module.dim * module.dim_attn)**-0.5)
        module.k.weight = normal(key2, module.k.weight.shape,
                                 std=module.dim**-0.5)
        module.v.weight = normal(key3, module.v.weight.shape,
                                 std=module.dim**-0.5)
        module.o.weight = normal(key4, module.o.weight.shape,
                                 std=(module.num_heads * module.dim_attn)**-0.5)
    elif isinstance(module, T5RelativeEmbedding):
        key = mx.random.split(key, 1)[0]
        module.embedding.weight = normal(
            key, module.embedding.weight.shape,
            std=(2 * module.num_buckets * module.num_heads)**-0.5)
    elif isinstance(module, nn.Linear):
        # Generic linear layer initialization
        key = mx.random.split(key, 1)[0]
        fan_in = module.weight.shape[1]
        bound = 1.0 / math.sqrt(fan_in)
        module.weight = mx.random.uniform(
            low=-bound, high=bound, shape=module.weight.shape, key=key)

    return module


def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs={},
        **kwargs):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        _ = kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        _ = kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    model = model_cls(**kwargs)

    # Initialize weights properly
    key = mx.random.key(0)
    model = init_mlx_weights(model, key)

    # init tokenizer
    if return_tokenizer:
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.0)
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)


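# Usage sketch (illustrative, not part of the original file): building the
# encoder-only UMT5-XXL used as the text encoder. With encoder_only=True the
# decoder_layers entry is dropped and a T5Encoder is returned.
#
#     encoder = umt5_xxl(encoder_only=True, return_tokenizer=False)
#     ids = mx.zeros((1, 512), dtype=mx.int32)     # [B, L]
#     mask = mx.ones((1, 512), dtype=mx.int32)     # [B, L]
#     states = encoder(ids, mask)                  # -> [1, 512, 4096]

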
class T5EncoderModel:

    def __init__(
        self,
        text_len,
        checkpoint_path=None,
        tokenizer_path=None,
    ):
        self.text_len = text_len
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False)

        if checkpoint_path:
            logging.info(f'loading {checkpoint_path}')
            # Load weights - assuming MLX format checkpoint
            weights = mx.load(checkpoint_path)
            model.update(tree_unflatten(list(weights.items())))

        self.model = model

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
            seq_len=text_len,
            clean='whitespace')

    def __call__(self, texts):
        # Handle single string input
        if isinstance(texts, str):
            texts = [texts]

        # Tokenize texts
        tokenizer_output = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)

        # Handle different tokenizer output formats
        if isinstance(tokenizer_output, tuple):
            ids, mask = tokenizer_output
        else:
            # Assuming dict output with 'input_ids' and 'attention_mask'
            ids = tokenizer_output['input_ids']
            mask = tokenizer_output['attention_mask']

        # Convert to MLX arrays if not already
        if not isinstance(ids, mx.array):
            ids = mx.array(ids)
        if not isinstance(mask, mx.array):
            mask = mx.array(mask)

        # Get sequence lengths
        seq_lens = mx.sum(mask > 0, axis=1)

        # Run encoder
        context = self.model(ids, mask)

        # Return variable length outputs
        # Convert seq_lens to a Python list for indexing
        if seq_lens.ndim == 0:  # single value
            seq_lens_list = [seq_lens.item()]
        else:
            seq_lens_list = seq_lens.tolist()

        return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]


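# Usage sketch (illustrative; the checkpoint path is a hypothetical placeholder):
# encoding a prompt into variable-length text embeddings.
#
#     text_encoder = T5EncoderModel(
#         text_len=512,
#         checkpoint_path='umt5-xxl-enc.safetensors',  # hypothetical path
#         tokenizer_path='google/umt5-xxl')
#     embeddings = text_encoder(["a cat surfing a wave at sunset"])
#     # embeddings[0] has shape [num_tokens, 4096]

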
# Utility function to convert a PyTorch checkpoint to MLX
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
    """Convert a PyTorch checkpoint to MLX format."""
    import torch

    # Load PyTorch checkpoint
    pytorch_state = torch.load(pytorch_path, map_location='cpu')

    # Convert to numpy, then to MLX
    mlx_state = {}
    for key, value in pytorch_state.items():
        if isinstance(value, torch.Tensor):
            # Handle the key mapping if needed
            mlx_key = key
            # Convert tensor to MLX array
            mlx_state[mlx_key] = mx.array(value.numpy())

    # Save the MLX checkpoint (a dict of arrays) in safetensors format
    mx.save_safetensors(mlx_path, mlx_state)

    return mlx_state
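# Usage sketch (illustrative; both file names are hypothetical placeholders):
# converting the original PyTorch encoder weights into a file that
# T5EncoderModel can load via mx.load.
#
#     convert_pytorch_checkpoint(
#         'umt5-xxl-enc.pth',           # hypothetical source checkpoint
#         'umt5-xxl-enc.safetensors')   # hypothetical output path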