# Copyright © 2024 Apple Inc.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

# Rename Hugging Face T5 parameter names to this module's attribute names
# (applied by ``T5Encoder.sanitize`` below).
_SHARED_REPLACEMENT_PATTERNS = [
    (".block.", ".layers."),
    (".k.", ".key_proj."),
    (".o.", ".out_proj."),
    (".q.", ".query_proj."),
    (".v.", ".value_proj."),
    ("shared.", "wte."),
    ("lm_head.", "lm_head.linear."),
    (".layer.0.layer_norm.", ".ln1."),
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    (
        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
        "relative_attention_bias.embeddings.",
    ),
]

_ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.", ".dense."),
]


@dataclass
class T5Config:
    vocab_size: int
    num_layers: int
    num_heads: int
    relative_attention_num_buckets: int
    d_kv: int
    d_model: int
    feed_forward_proj: str
    tie_word_embeddings: bool

    d_ff: Optional[int] = None
    num_decoder_layers: Optional[int] = None
    relative_attention_max_distance: int = 128
    layer_norm_epsilon: float = 1e-6

    @classmethod
    def from_dict(cls, config):
        return cls(
            vocab_size=config["vocab_size"],
            num_layers=config["num_layers"],
            num_heads=config["num_heads"],
            relative_attention_num_buckets=config["relative_attention_num_buckets"],
            d_kv=config["d_kv"],
            d_model=config["d_model"],
            feed_forward_proj=config["feed_forward_proj"],
            tie_word_embeddings=config["tie_word_embeddings"],
            d_ff=config.get("d_ff", 4 * config["d_model"]),
            num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
            relative_attention_max_distance=config.get(
                "relative_attention_max_distance", 128
            ),
            layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
        )
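
# Illustrative only: ``from_dict`` accepts a Hugging Face style T5 config dict;
# the values below roughly correspond to a base-sized model and are not tied to
# any particular checkpoint.
#
#   config = T5Config.from_dict(
#       {
#           "vocab_size": 32128,
#           "num_layers": 12,
#           "num_heads": 12,
#           "relative_attention_num_buckets": 32,
#           "d_kv": 64,
#           "d_model": 768,
#           "feed_forward_proj": "gated-gelu",
#           "tie_word_embeddings": False,
#       }
#   )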


class RelativePositionBias(nn.Module):
    def __init__(self, config: T5Config, bidirectional: bool):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = config.relative_attention_num_buckets
        self.max_distance = config.relative_attention_max_distance
        self.n_heads = config.num_heads
        self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
        num_buckets = num_buckets // 2 if bidirectional else num_buckets
        max_exact = num_buckets // 2

        abspos = rpos.abs()
        is_small = abspos < max_exact

        scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
        buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
        buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)

        buckets = mx.where(is_small, abspos, buckets_large)
        if bidirectional:
            buckets = buckets + (rpos > 0) * num_buckets
        else:
            buckets = buckets * (rpos < 0)

        return buckets
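
    # Worked example (assuming bidirectional=True, num_buckets=32,
    # max_distance=128): each direction gets 16 buckets, distances 0-7 map to
    # their own bucket exactly, distances 8-127 are binned logarithmically into
    # buckets 8-15, anything beyond 127 saturates at bucket 15, and positive
    # relative positions are shifted up by 16 into their own range.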

    def __call__(self, query_length: int, key_length: int, offset: int = 0):
        """Compute binned relative position bias"""
        context_position = mx.arange(offset, query_length)[:, None]
        memory_position = mx.arange(key_length)[None, :]

        # shape (query_length, key_length)
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )

        # shape (query_length, key_length, num_heads)
        values = self.embeddings(relative_position_bucket)

        # shape (num_heads, query_length, key_length)
        return values.transpose(2, 0, 1)
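
# The (num_heads, query_length, key_length) bias above is consumed as an
# additive attention mask: TransformerEncoder below computes it once per
# forward pass and passes it as ``mask`` to every layer's attention.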


class MultiHeadAttention(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        inner_dim = config.d_kv * config.num_heads
        self.num_heads = config.num_heads
        self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)

    def __call__(
        self,
        queries: mx.array,
        keys: mx.array,
        values: mx.array,
        mask: Optional[mx.array],
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            # keys and values are (B, num_heads, S, d_kv); append new positions
            # along the sequence axis.
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)

        # T5 does not rescale attention scores by 1/sqrt(d_kv), hence scale=1.0.
        values_hat = mx.fast.scaled_dot_product_attention(
            queries,
            keys,
            values,
            scale=1.0,
            mask=mask.astype(queries.dtype) if mask is not None else None,
        )
        values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)
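
# Minimal sketch (not exercised by the encoder in this file): the returned
# ``(keys, values)`` pair can be fed back as ``cache`` for incremental decoding,
# so only the newest positions are projected on each step. Here ``attn`` is a
# MultiHeadAttention instance and ``x`` / ``x_new`` are hypothetical inputs.
#
#   y, kv = attn(x, x, x, mask=None)
#   y_step, kv = attn(x_new, x_new, x_new, mask=None, cache=kv)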


class DenseActivation(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        mlp_dims = config.d_ff or config.d_model * 4
        self.gated = config.feed_forward_proj.startswith("gated")
        if self.gated:
            self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
            self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
        else:
            self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
        self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
        activation = config.feed_forward_proj.removeprefix("gated-")
        if activation == "relu":
            self.act = nn.relu
        elif activation == "gelu":
            self.act = nn.gelu
        elif activation == "silu":
            self.act = nn.silu
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def __call__(self, x):
        if self.gated:
            hidden_act = self.act(self.wi_0(x))
            hidden_linear = self.wi_1(x)
            x = hidden_act * hidden_linear
        else:
            x = self.act(self.wi(x))
        return self.wo(x)
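
# Note: ``feed_forward_proj`` strings follow the Hugging Face convention, e.g.
# "relu" for the original T5 feed-forward and "gated-gelu" for the gated
# (GEGLU-style) variant used by FLAN-T5 checkpoints. ``str.removeprefix``
# requires Python 3.9 or newer.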


class TransformerEncoderLayer(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dense = DenseActivation(config)

    def __call__(self, x, mask):
        y = self.ln1(x)
        y, _ = self.attention(y, y, y, mask=mask)
        x = x + y

        y = self.ln2(x)
        y = self.dense(y)
        return x + y


class TransformerEncoder(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(config) for _ in range(config.num_layers)
        ]
        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)

    def __call__(self, x: mx.array):
        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
        pos_bias = pos_bias.astype(x.dtype)
        for layer in self.layers:
            x = layer(x, mask=pos_bias)
        return self.ln(x)


class T5Encoder(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = TransformerEncoder(config)

    def sanitize(self, weights):
        """Rename weights from a Hugging Face T5 checkpoint to this module's names."""
        new_weights = {}
        for k, w in weights.items():
            for old, new in _SHARED_REPLACEMENT_PATTERNS:
                k = k.replace(old, new)
            if k.startswith("encoder."):
                for old, new in _ENCODER_REPLACEMENT_PATTERNS:
                    k = k.replace(old, new)
            new_weights[k] = w
        return new_weights

    def __call__(self, inputs: mx.array):
        return self.encoder(self.wte(inputs))
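
# Minimal usage sketch (illustrative; file names are placeholders):
#
#   import json
#
#   with open("config.json") as f:
#       config = T5Config.from_dict(json.load(f))
#   model = T5Encoder(config)
#   weights = model.sanitize(mx.load("t5_encoder.safetensors"))
#   model.load_weights(list(weights.items()))
#   hidden = model(mx.array([[0, 1, 2]]))  # shape (1, seq_len, d_model)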