mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30 19:06:37 +08:00)

commit 688a6e1e78 ("with cache")
parent 29bfb93455
t5/convert.py

@@ -14,10 +14,8 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    (
-        ".relative_attention_bias.",
-        ".relative_attention_bias.embeddings."
-    ),
+    ("layers.0.layer.0.SelfAttention.relative_attention_bias.",
+     "relative_attention_bias.embeddings."),
 ]
 
 ENCODER_REPLACEMENT_PATTERNS = [
@@ -33,6 +31,7 @@ DECODER_REPLACEMENT_PATTERNS = [
     (".layer.2.DenseReluDense.wo.", ".linear2."),
 ]
 
+
 def replace_key(key: str) -> str:
     for old, new in SHARED_REPLACEMENT_PATTERNS:
         key = key.replace(old, new)
@@ -45,14 +44,22 @@ def replace_key(key: str) -> str:
     return key
 
 
-def convert():
-    model = T5ForConditionalGeneration.from_pretrained(
-        "t5-small", torch_dtype="auto"
-    )
-    state_dict = model.state_dict()
-    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
+def convert(model_name):
+    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
+    weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
    np.savez("weights.npz", **weights)
 
 
 if __name__ == "__main__":
-    convert()
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Name of the T5 model.",
+        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
+        default="t5-small",
+    )
+    args = parser.parse_args()
+    convert(args.model_name)
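Note: the new __main__ block makes the conversion model-agnostic, e.g. python convert.py --model_name t5-base (assuming the script lives at t5/convert.py). A minimal, self-contained sketch of the replace_key mechanism above; the pattern list here is illustrative, not the commit's full list:

    PATTERNS = [
        (".final_layer_norm.", ".ln."),
        ("layers.0.layer.0.SelfAttention.relative_attention_bias.",
         "relative_attention_bias.embeddings."),
    ]

    def replace_key(key: str) -> str:
        # Apply each (old, new) substring substitution in order.
        for old, new in PATTERNS:
            key = key.replace(old, new)
        return key

    print(replace_key("encoder.final_layer_norm.weight"))  # encoder.ln.weight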
t5/t5.py
@@ -1,5 +1,6 @@
 import argparse
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple, List
 from time import perf_counter_ns
@@ -110,18 +111,23 @@ class RelativePositionBias(nn.Module):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         self.num_heads = config.num_heads
         self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
-        self.has_relative_attention_bias = has_relative_attention_bias
-        if has_relative_attention_bias:
-            self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, queries, keys, values, mask=None, position_bias=None):
+    def __call__(
+        self,
+        queries: mx.array,
+        keys: mx.array,
+        values: mx.array,
+        mask: mx.array,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
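Note: the rewritten __call__ threads an optional (keys, values) cache through attention in place of the old position_bias plumbing. The reshapes in the next hunk keep keys as [batch, heads, head_dim, seq] so that queries @ keys needs no transpose at the use site; a toy-shape sketch of that layout:

    import mlx.core as mx

    # Toy sizes: B batch, L query length, H heads, D model dim (head_dim = D // H).
    B, L, H, D = 1, 3, 2, 8
    x = mx.zeros((B, L, D))
    queries = x.reshape(B, L, H, -1).transpose(0, 2, 1, 3)  # [B, H, L, head_dim]
    keys = x.reshape(B, L, H, -1).transpose(0, 2, 3, 1)     # [B, H, head_dim, L]
    scores = queries @ keys                                 # [B, H, L, L]
    print(scores.shape)                                     # (1, 2, 3, 3)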
@@ -133,106 +139,95 @@ class MultiHeadAttention(nn.Module):
         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
 
+        if cache is not None:
+            key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=3)
+            values = mx.concatenate([value_cache, values], axis=2)
+
         # Dimensions are [batch x num heads x sequence x hidden dim]
         scores = queries @ keys
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
 
-        if self.has_relative_attention_bias:
-            position_bias = self.relative_attention_bias(L, S)
-        if position_bias is not None:
-            scores += position_bias
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        out = self.out_proj(values_hat)
-        return out, position_bias
-
-    @staticmethod
-    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
-        indices = mx.arange(N)
-        mask = indices[:, None] < indices[None]
-        # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
-        # TODO: Should replace this with finfo(dtype).min
-        mask = mask.astype(dtype) * -1e9
-        return mask
+        return self.out_proj(values_hat), (keys, values)
 
 
-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
         self.weight = mx.ones((dims,))
         self.eps = eps
-        self.dims = dims
+
+    def _norm(self, x):
+        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
 
     def __call__(self, x):
-        var = x.var(axis=-1, keepdims=True)
-        x = x * mx.rsqrt(var + self.eps)
-        return x * self.weight
+        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        return self.weight * output
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.attention = MultiHeadAttention(
-            config, has_relative_attention_bias=has_relative_attention_bias
-        )
-        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.attention = MultiHeadAttention(config)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, mask, position_bias=None):
+    def __call__(self, x, mask):
         y = self.ln1(x)
-        y, position_bias = self.attention(
-            queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
-        )
+        y, _ = self.attention(y, y, y, mask=mask)
         x = x + y
 
         y = self.ln2(x)
         y = self.linear1(y)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        x = x + y
-
-        return x, position_bias
+        return x + y
 
 
 class TransformerEncoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
-            for i in range(config.num_layers)
+            TransformerEncoderLayer(config) for i in range(config.num_layers)
         ]
-        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, x, mask):
-        position_bias = None
+    def __call__(self, x):
+        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
         for layer in self.layers:
-            x, position_bias = layer(x, mask, position_bias=position_bias)
-        x = self.ln(x)
-
-        return x
+            x = layer(x, mask=pos_bias)
+        return self.ln(x)
 
 
 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.self_attention = MultiHeadAttention(
-            config, has_relative_attention_bias=has_relative_attention_bias
-        )
+        self.self_attention = MultiHeadAttention(config)
         self.cross_attention = MultiHeadAttention(config)
-        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
+    def __call__(
+        self,
+        x: mx.array,
+        memory: mx.array,
+        mask: mx.array,
+        memory_mask: mx.array,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
         y = self.ln1(x)
-        y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
+        y, cache = self.self_attention(y, y, y, mask, cache)
         x = x + y
 
         y = self.ln2(x)
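Note: the cache layout follows from those transposes: cached keys grow along axis 3 and cached values along axis 2. A sketch with assumed toy shapes:

    import mlx.core as mx

    B, H, D = 1, 2, 4
    key_cache, value_cache = mx.zeros((B, H, D, 5)), mx.zeros((B, H, 5, D))
    new_key, new_value = mx.zeros((B, H, D, 1)), mx.zeros((B, H, 1, D))
    keys = mx.concatenate([key_cache, new_key], axis=3)        # sequence axis for keys
    values = mx.concatenate([value_cache, new_value], axis=2)  # sequence axis for values
    print(keys.shape, values.shape)  # (1, 2, 4, 6) (1, 2, 6, 4)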
@@ -245,27 +240,39 @@ class TransformerDecoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
 
-        return x, position_bias
+        return x, cache
 
 
 class TransformerDecoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
-            for i in range(config.num_layers)
+            TransformerDecoderLayer(config) for i in range(config.num_layers)
         ]
-        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, x, memory, x_mask, memory_mask):
-        position_bias = None
-        for layer in self.layers:
-            x, position_bias = layer(
-                x, memory, x_mask, memory_mask, position_bias=position_bias
-            )
+    def __call__(self, x, memory, mask, memory_mask, cache=None):
+        if cache is not None:
+            offset = cache[0][0].shape[3]
+        else:
+            offset = 0
+            cache = [None] * len(self.layers)
+
+        T = offset + x.shape[1]
+        # TODO, add offset to RelativePositionBias class to avoid wasted work
+        pos_bias = self.relative_attention_bias(T, T)
+        pos_bias = pos_bias[:, :, -x.shape[1]:, :]
+        if mask is not None:
+            mask += pos_bias
+        else:
+            mask = pos_bias
+
+        for e, layer in enumerate(self.layers):
+            x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
         x = self.ln(x)
-
-        return x
+        return x, cache
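Note: with a cache present, offset is read from the cached key length (axis 3), the relative position bias is built for the full length T, and only the rows belonging to the new queries are kept before being folded into the mask. A toy sketch of that slicing, with assumed shapes:

    import mlx.core as mx

    offset, n_new, heads = 5, 1, 8
    T = offset + n_new
    pos_bias = mx.zeros((1, heads, T, T))  # stand-in for relative_attention_bias(T, T)
    pos_bias = pos_bias[:, :, -n_new:, :]  # keep bias rows for the new tokens only
    print(pos_bias.shape)                  # (1, 8, 1, 6)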
@@ -283,26 +290,37 @@ class T5(nn.Module):
         self.decoder = TransformerDecoder(config)
         self.lm_head = OutputHead(config)
 
+    def encode(
+        self,
+        inputs: mx.array
+    ) -> mx.array:
+        return self.encoder(self.wte(inputs))
+
+    def decode(
+        self,
+        inputs: mx.array,
+        memory: mx.array,
+        cache = None,
+    ):
+        inputs = self.wte(inputs)
+        T = inputs.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+            mask = mask.astype(inputs.dtype)
+        else:
+            mask = None
+
+        y, cache = self.decoder(
+            inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
+        )
+        return self.lm_head(y), cache
+
     def __call__(
         self,
         inputs: mx.array,
         decoder_inputs: mx.array,
-        mask: mx.array = None,
-        cache: mx.array = None,
-    ) -> tuple[mx.array, mx.array]:
-        x = self.wte(inputs)
-        y = self.encoder(x, mask=None)  # , cache)
-
-        decoder_inputs = self.wte(decoder_inputs)
-        decoder_n_tokens = decoder_inputs.shape[1]
-        if decoder_n_tokens > 1 and mask is None:
-            mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
-            mask = mask.astype(x.dtype)
-
-        y = self.decoder(
-            x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
-        )  # , cache)
-        return self.lm_head(y), cache
+    ) -> mx.array:
+        return decode(decoder_inputs, encode(inputs))[0]
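Note: splitting __call__ into encode and decode lets the encoder run once while the decoder is stepped repeatedly; the causal mask is only built for the multi-token prefill (T > 1), since a single cached step may attend to everything already in the cache. The mask helper used above is the one built into mlx.nn:

    import mlx.nn as nn

    T = 4
    mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
    print(mask.shape)  # (4, 4): large negative values above the diagonal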
@@ -314,12 +332,15 @@ def generate(
     else:
         return mx.random.categorical(logits * (1 / temp))
 
+    memory = model.encode(inputs)
+    cache = None
+    y = decoder_inputs
     while True:
-        # TODO: add cache
-        logits, _ = model(inputs, decoder_inputs)
-        y = mx.expand_dims(sample(logits[:, -1, :]), 0)
-        decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1)
-        yield y
+        logits, cache = model.decode(y[None], memory, cache=cache)
+        # logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
+        y = sample(logits[:, -1, :])
+        #decoder_inputs = mx.concatenate([decoder_inputs, y])
+        yield y.squeeze()
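Note: the rewritten generator encodes the prompt once, then feeds only the newest token back through decode, carrying the cache between steps. A stub sketch of the loop's shape handling (the model call is faked; names are assumptions):

    import mlx.core as mx

    def decode_step(tokens, memory, cache):
        # Stand-in for model.decode: logits for each input position + updated cache.
        return mx.zeros((1, tokens.shape[1], 32)), cache

    y, cache = mx.array([0]), None  # decoder start token
    for _ in range(3):
        logits, cache = decode_step(y[None], None, cache)
        y = mx.argmax(logits[:, -1, :], axis=-1)  # greedy stand-in for sample()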
@@ -340,6 +361,7 @@ def load_model(model_config):
             print("Expected shape: ", current_weights_dict[key].shape)
             print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
+    mx.eval(model.parameters())
     tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
     return model, tokenizer
@@ -397,39 +419,23 @@ if __name__ == "__main__":
     print("[INFO] Generating with T5...", flush=True)
     print("Input: ", args.prompt, flush=True)
 
-    decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
+    decoder_inputs = mx.array([config.decoder_start_token_id])
 
     start = perf_counter_ns()
-    n_tokens = 0
-
     tokens = []
-    for token, _ in zip(
+    for token, n_tokens in zip(
         generate(prompt, decoder_inputs, model, args.temp),
         range(args.max_tokens)
     ):
-        tokens.append(token)
-
-        if (len(tokens) % 10) == 0:
-            mx.eval(tokens)
-            eos_index = next(
-                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
-                None,
-            )
-
-            if eos_index is not None:
-                tokens = tokens[:eos_index]
-
-            s = tokenizer.decode([t.item() for t in tokens])
-            print(s, end="", flush=True)
-            n_tokens += len(tokens)
-            tokens = []
-            if eos_index is not None:
-                break
+        if token.item() == tokenizer.eos_token_id:
+            break
+        tokens.append(token.item())
+        # For some reason using the following line doesn't give spaces
+        # print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
+    print(tokenizer.decode(tokens), end="", flush=True)
 
     end = perf_counter_ns()
 
-    mx.eval(tokens)
-    s = tokenizer.decode([t.item() for t in tokens])
-    n_tokens += len(tokens)
     elapsed = (end - start) / 1.0e9
-    print(s, flush=True)
-    print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")
+    print()
+    print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")
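Note: the "doesn't give spaces" comment is a SentencePiece artifact: word-initial pieces carry a leading marker that only becomes a space when a sequence of ids is decoded together, so decoding tokens one at a time drops the word boundaries. A hedged sketch (downloads the t5-small tokenizer via transformers):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("t5-small")
    ids = tok("translate English to German: hello").input_ids
    print(tok.decode(ids, skip_special_tokens=True))  # spaces preserved
    print("".join(tok.decode([i]) for i in ids))      # words run together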