# MLX implementation of t5.py (T5/UMT5 encoder-decoder)
import logging
import math
from typing import Optional, Tuple, List

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_unflatten

from .tokenizers import HuggingfaceTokenizer

__all__ = [
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
]
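
# Example usage (a sketch; the checkpoint path below is hypothetical and its
# parameter names must already match this module's parameter names):
#
#   encoder = T5EncoderModel(
#       text_len=512,
#       checkpoint_path='umt5-xxl-enc.safetensors',
#       tokenizer_path='google/umt5-xxl')
#   context = encoder(['a moonlit harbor, oil painting'])  # list of [len, dim]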


def fp16_clamp(x):
    if x.dtype == mx.float16:
        # Use the same clamping as the PyTorch reference for consistency
        clamp = 65504.0  # max finite value for float16
        return mx.clip(x, -clamp, clamp)
    return x


class GELU(nn.Module):
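    """Tanh approximation of GELU (the "gelu_new"-style activation used by
    T5 v1.1-family models)."""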

    def __call__(self, x):
        return 0.5 * x * (1.0 + mx.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * mx.power(x, 3.0))))


class T5LayerNorm(nn.Module):
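    """T5-style layer norm (RMSNorm): rescales by the root mean square of the
    last dimension with a learned gain, without mean subtraction or bias."""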

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x):
        # Match the PyTorch reference: compute statistics in float32 for stability
        x_float = x.astype(mx.float32) if x.dtype == mx.float16 else x
        variance = mx.mean(mx.square(x_float), axis=-1, keepdims=True)
        x_norm = x_float * mx.rsqrt(variance + self.eps)
        # Convert back to the original dtype
        if x.dtype == mx.float16:
            x_norm = x_norm.astype(mx.float16)
        return self.weight * x_norm


class T5Attention(nn.Module):
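    """Multi-head attention as used in T5: the logits are not scaled by
    1/sqrt(head_dim) (T5 folds the scale into the weight initialization), and
    an optional relative-position bias is added to the logits."""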

    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
        assert dim_attn % num_heads == 0
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None.
        mask: [B, L2] or [B, L1, L2] or None.
        pos_bias: [1, N, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, l1, _ = x.shape
        _, l2, _ = context.shape
        n, c = self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, l1, n, c)
        k = self.k(context).reshape(b, l2, n, c)
        v = self.v(context).reshape(b, l2, n, c)

        # transpose for attention: [B, N, L, C]
        q = mx.transpose(q, (0, 2, 1, 3))
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

        # compute attention logits (T5 does not scale by 1/sqrt(head_dim))
        attn = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2)))  # [B, N, L1, L2]

        # add relative position bias if provided
        if pos_bias is not None:
            attn = attn + pos_bias

        # apply mask
        if mask is not None:
            if mask.ndim == 2:
                # [B, L2] -> [B, 1, 1, L2]
                mask = mask[:, None, None, :]
            elif mask.ndim == 3:
                # [B, L1, L2] -> [B, 1, L1, L2]
                mask = mask[:, None, :, :]
            # Use a very negative value that is representable in float16
            min_value = -65504.0 if attn.dtype == mx.float16 else -1e9
            attn = mx.where(mask == 0, min_value, attn)

        # softmax in float32 for stability, then cast back
        attn = mx.softmax(attn.astype(mx.float32), axis=-1).astype(attn.dtype)
        attn = self.dropout(attn)

        # apply attention to values
        x = mx.matmul(attn, v)  # [B, N, L1, C]

        # transpose back and reshape
        x = mx.transpose(x, (0, 2, 1, 3))  # [B, L1, N, C]
        x = x.reshape(b, l1, -1)

        # output projection
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):
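    """Gated feed-forward block (as in T5 v1.1): GELU(gate_proj(x)) gates
    fc1(x) before the output projection fc2."""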

    def __init__(self, dim, dim_ffn, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
        self.gate_act = GELU()
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x):
        gate = self.gate_act(self.gate_proj(x))
        x = self.fc1(x) * gate
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
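    """Encoder block: pre-norm self-attention followed by a pre-norm gated
    feed-forward, each wrapped in a residual connection and fp16-clamped."""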

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def __call__(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.shape[1], x.shape[1])
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):
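    """Decoder block: pre-norm self-attention, pre-norm cross-attention over
    the encoder states, then a pre-norm gated feed-forward, each wrapped in a
    residual connection and fp16-clamped."""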

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False)

    def __call__(self,
                 x,
                 mask=None,
                 encoder_states=None,
                 encoder_mask=None,
                 pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.shape[1], x.shape[1])
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):
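    """Relative position bias. Relative distances are mapped to num_buckets
    buckets: half cover small distances exactly, the rest cover larger
    distances on a log scale up to max_dist. Each bucket stores one learned
    bias per attention head, returned with shape [1, N, Lq, Lk]."""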

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def __call__(self, lq, lk):
        # Create the relative position matrix [Lq, Lk]
        positions_q = mx.arange(lq)[:, None]
        positions_k = mx.arange(lk)[None, :]
        rel_pos = positions_k - positions_q

        # Apply bucketing
        rel_pos = self._relative_position_bucket(rel_pos)

        # Look up per-head biases -> [Lq, Lk, N]
        rel_pos_embeds = self.embedding(rel_pos)

        # Reshape to [1, N, Lq, Lk]
        rel_pos_embeds = mx.transpose(rel_pos_embeds, (2, 0, 1))
        rel_pos_embeds = mx.expand_dims(rel_pos_embeds, 0)

        return rel_pos_embeds

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
            rel_pos = mx.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
            rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos))

        # exact buckets for small distances
        max_exact = num_buckets // 2
        is_small = rel_pos < max_exact

        # log-scale buckets for larger distances
        rel_pos_large = max_exact + (
            mx.log(rel_pos.astype(mx.float32) / max_exact) /
            math.log(self.max_dist / max_exact) *
            (num_buckets - max_exact)
        ).astype(mx.int32)
        rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1)

        # Combine small and large position buckets
        rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large)
        return rel_buckets


class T5Encoder(nn.Module):
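    """Stack of T5SelfAttention blocks over token embeddings, using either a
    single shared relative-position bias or one bias per block."""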

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)

        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = [
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.shape[1],
                               x.shape[1]) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):
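    """Stack of T5CrossAttention blocks with a causal self-attention mask,
    attending to the encoder output via cross-attention."""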

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)

        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = [
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                             shared_pos, dropout) for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.shape

        # causal mask
        if mask is None:
            mask = mx.tril(mx.ones((1, s, s)))
        elif mask.ndim == 2:
            # combine the padding mask [B, S] with a causal mask -> [B, S, S]
            mask = mx.tril(mx.broadcast_to(mx.expand_dims(mask, 1), (b, s, s)))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.shape[1],
                               x.shape[1]) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):
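    """Full T5 encoder-decoder with a token embedding shared between encoder
    and decoder and a separate linear head over the vocabulary."""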

    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, encoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, decoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def __call__(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def init_mlx_weights(module, key):
    """Initialize weights for T5 model components to match the PyTorch
    initialization. Intended to be applied per-module (PyTorch ``apply``
    style); loaded pretrained checkpoints overwrite these values."""

    def normal(key, shape, std=1.0):
        # mx.random.normal takes the shape first and the PRNG key as a keyword
        return mx.random.normal(shape, key=key) * std

    if isinstance(module, T5LayerNorm):
        module.weight = mx.ones_like(module.weight)
    elif isinstance(module, nn.Embedding):
        key = mx.random.split(key, 1)[0]
        module.weight = normal(key, module.weight.shape, std=1.0)
    elif isinstance(module, T5FeedForward):
        # Match PyTorch initialization
        key1, key2, key3 = mx.random.split(key, 3)
        module.gate_proj.weight = normal(key1, module.gate_proj.weight.shape,
                                         std=module.dim**-0.5)
        module.fc1.weight = normal(key2, module.fc1.weight.shape,
                                   std=module.dim**-0.5)
        module.fc2.weight = normal(key3, module.fc2.weight.shape,
                                   std=module.dim_ffn**-0.5)
    elif isinstance(module, T5Attention):
        # Match PyTorch initialization
        key1, key2, key3, key4 = mx.random.split(key, 4)
        module.q.weight = normal(key1, module.q.weight.shape,
                                 std=(module.dim * module.dim_attn)**-0.5)
        module.k.weight = normal(key2, module.k.weight.shape,
                                 std=module.dim**-0.5)
        module.v.weight = normal(key3, module.v.weight.shape,
                                 std=module.dim**-0.5)
        module.o.weight = normal(key4, module.o.weight.shape,
                                 std=(module.num_heads * module.dim_attn)**-0.5)
    elif isinstance(module, T5RelativeEmbedding):
        key = mx.random.split(key, 1)[0]
        module.embedding.weight = normal(
            key, module.embedding.weight.shape,
            std=(2 * module.num_buckets * module.num_heads)**-0.5)
    elif isinstance(module, nn.Linear):
        # Generic linear layer initialization
        key = mx.random.split(key, 1)[0]
        fan_in = module.weight.shape[1]
        bound = 1.0 / math.sqrt(fan_in)
        module.weight = mx.random.uniform(
            low=-bound, high=bound, shape=module.weight.shape, key=key)

    return module


def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs={},
        **kwargs):
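    """Build a T5 variant (full model, encoder-only, or decoder-only) from the
    given config, initialize its weights, and optionally return a matching
    HuggingfaceTokenizer loaded from 'google/<name>'."""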
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        _ = kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        _ = kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    model = model_cls(**kwargs)

    # Initialize weights (note: this only touches the top-level module; in
    # practice the random init is overwritten by a pretrained checkpoint)
    key = mx.random.key(0)
    model = init_mlx_weights(model, key)

    # init tokenizer
    if return_tokenizer:
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
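    """UMT5-XXL configuration (per-layer relative position embeddings, i.e.
    shared_pos=False); keyword arguments override the defaults below."""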
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.0)
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)


class T5EncoderModel:
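    """Convenience wrapper that builds the UMT5-XXL encoder, optionally loads
    an MLX-format checkpoint, tokenizes input text, and returns one embedding
    array per input trimmed to its unpadded length."""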

    def __init__(
        self,
        text_len,
        checkpoint_path=None,
        tokenizer_path=None,
    ):
        self.text_len = text_len
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False)

        if checkpoint_path:
            logging.info(f'loading {checkpoint_path}')
            # Load weights from an MLX-format checkpoint (e.g. safetensors)
            weights = mx.load(checkpoint_path)
            model.update(tree_unflatten(list(weights.items())))

        self.model = model

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path if tokenizer_path else 'google/umt5-xxl',
            seq_len=text_len,
            clean='whitespace')

    def __call__(self, texts):
        # Handle single string input
        if isinstance(texts, str):
            texts = [texts]

        # Tokenize texts
        tokenizer_output = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)

        # Handle different tokenizer output formats
        if isinstance(tokenizer_output, tuple):
            ids, mask = tokenizer_output
        else:
            # Assume dict output with 'input_ids' and 'attention_mask'
            ids = tokenizer_output['input_ids']
            mask = tokenizer_output['attention_mask']

        # Convert to MLX arrays if not already
        if not isinstance(ids, mx.array):
            ids = mx.array(ids)
        if not isinstance(mask, mx.array):
            mask = mx.array(mask)

        # Get per-sequence lengths from the attention mask
        seq_lens = mx.sum(mask > 0, axis=1)

        # Run encoder
        context = self.model(ids, mask)

        # Return variable-length outputs (one [len_i, dim] array per text)
        if seq_lens.ndim == 0:  # single value
            seq_lens_list = [seq_lens.item()]
        else:
            seq_lens_list = seq_lens.tolist()

        return [context[i, :int(seq_lens_list[i])] for i in range(len(texts))]


# Utility to convert a PyTorch checkpoint to MLX format
def convert_pytorch_checkpoint(pytorch_path, mlx_path):
    """Convert a PyTorch state dict to an MLX checkpoint (safetensors)."""
    import torch

    # Load PyTorch checkpoint
    pytorch_state = torch.load(pytorch_path, map_location='cpu')

    # Convert tensors to MLX arrays via numpy
    mlx_state = {}
    for key, value in pytorch_state.items():
        if isinstance(value, torch.Tensor):
            # Handle key mapping here if needed
            mlx_key = key
            # numpy cannot represent bfloat16, so upcast those tensors first
            if value.dtype == torch.bfloat16:
                value = value.float()
            mlx_state[mlx_key] = mx.array(value.detach().cpu().numpy())

    # Save the MLX checkpoint; mx.save only handles single arrays, so use the
    # safetensors writer for the full state dict
    mx.save_safetensors(mlx_path, mlx_state)

    return mlx_state
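

# Example conversion workflow (a sketch; the paths are hypothetical, and the
# PyTorch parameter names must be remapped to this module's names where they
# differ, e.g. inside convert_pytorch_checkpoint):
#
#   convert_pytorch_checkpoint('umt5-xxl-enc.pth', 'umt5-xxl-enc.safetensors')
#   encoder = T5EncoderModel(
#       text_len=512, checkpoint_path='umt5-xxl-enc.safetensors')
#   context = encoder('a moonlit harbor, oil painting')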