mlx-examples/t5/t5.py

420 lines
15 KiB
Python
Raw Normal View History

2023-12-15 04:21:36 +08:00
import argparse
from dataclasses import dataclass
2023-12-17 03:53:50 +08:00
from typing import Optional
2023-12-15 04:21:36 +08:00
2023-12-16 05:51:01 +08:00
import numpy as np
2023-12-15 04:21:36 +08:00
import mlx.core as mx
import mlx.nn as nn
2023-12-15 04:38:41 +08:00
from mlx.utils import tree_flatten, tree_unflatten
from transformers import AutoTokenizer
2023-12-15 04:21:36 +08:00
@dataclass
class ModelArgs:
    """Hyperparameters for the T5 architecture; defaults match t5-small."""

    d_ff: int = 2048
    d_kv: int = 64
    d_model: int = 512
    # Was annotated `int`, but the value (and its use) is a float.
    dropout_rate: float = 0.1
    layer_norm_epsilon: float = 1e-06
    n_positions: int = 512
    relative_attention_num_buckets: int = 32
    relative_attention_max_distance: int = 128
    num_heads: int = 8
    num_layers: int = 6
    decoder_start_token_id: int = 0
    eos_token_id: int = 1
    pad_token_id: int = 0
    vocab_size: int = 32128
def _relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
    """
    Adapted from HF Tensorflow:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on
    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer
    Returns:
        a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = 0
    if bidirectional:
        # Split the buckets: half for looking backward, half for forward.
        num_buckets //= 2
        relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
        relative_position = mx.abs(relative_position)
    else:
        # Causal case: only non-positive relative positions are valid.
        relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
    # now relative_position is in the range [0, inf)
    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        mx.log(relative_position.astype(mx.float32) / max_exact)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(mx.int16)
    # Clamp so positions beyond max_distance all share the final bucket.
    relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
    relative_buckets += mx.where(
        is_small, relative_position, relative_position_if_large
    )
    return relative_buckets
class RelativePositionBias(nn.Module):
    """Learned T5 relative-position bias, one scalar per (bucket, head).

    Produces an additive bias of shape (1, num_heads, query_length,
    key_length) that is added to raw attention scores.
    """

    def __init__(self, config: ModelArgs, is_decoder: bool = False):
        # Was missing: every other nn.Module subclass in this file calls
        # super().__init__() so the Module machinery is initialized.
        super().__init__()
        self.bidirectional = not is_decoder
        self.num_buckets = config.relative_attention_num_buckets
        self.max_distance = config.relative_attention_max_distance
        self.n_heads = config.num_heads
        self.embeddings = nn.Embedding(
            config.relative_attention_num_buckets, config.num_heads
        )

    def __call__(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
        memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = _relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        values = self.embeddings(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = mx.expand_dims(
            values.transpose(2, 0, 1), 0
        )  # shape (1, num_heads, query_length, key_length)
        return values
class MultiHeadAttention(nn.Module):
    """Multi-head attention in the T5 style: no projection biases and,
    deliberately, no 1/sqrt(d_k) score scaling (T5 folds the scale into the
    pretrained weights). The first layer of a stack owns the learned
    relative-position bias and hands it to subsequent layers via the second
    element of the return tuple.
    """

    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
        super().__init__()
        self.num_heads = config.num_heads
        self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.has_relative_attention_bias = has_relative_attention_bias
        if has_relative_attention_bias:
            self.relative_attention_bias = RelativePositionBias(config)

    def __call__(self, queries, keys, values, mask=None, position_bias=None):
        """Return (attention_output, position_bias).

        position_bias is recomputed when this layer owns the relative
        attention bias, otherwise the caller-supplied bias (possibly None)
        is used and passed through.
        """
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
        # Dimensions are [batch x num heads x sequence x hidden dim]
        scores = queries @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        if self.has_relative_attention_bias:
            position_bias = self.relative_attention_bias(L, S)
        if position_bias is not None:
            scores += position_bias
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(values_hat), position_bias

    @staticmethod
    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
        # TODO: Should replace this with finfo(dtype).min
        mask = mask.astype(dtype) * -1e9
        return mask
2023-12-15 04:21:36 +08:00
2023-12-15 04:38:41 +08:00
class LayerNorm(nn.Module):
    """T5-style layer norm (a.k.a. RMSNorm): scale by the root mean square
    of the activations, with no mean subtraction and no bias term.
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def __call__(self, x):
        # HF's T5LayerNorm normalizes by the mean of SQUARES, not the
        # variance. Using x.var() (as before) subtracts the mean first and
        # diverges from what the pretrained weights were trained with.
        ms = mx.mean(x * x, axis=-1, keepdims=True)
        x = x * mx.rsqrt(ms + self.eps)
        return x * self.weight
2023-12-15 04:38:41 +08:00
class TransformerEncoderLayer(nn.Module):
    """Pre-norm encoder block: self-attention, then a ReLU feed-forward
    network, each wrapped in a residual connection.
    """

    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
        super().__init__()
        hidden = config.d_ff or config.d_model * 4
        self.attention = MultiHeadAttention(
            config, has_relative_attention_bias=has_relative_attention_bias
        )
        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.linear1 = nn.Linear(config.d_model, hidden, bias=False)
        self.linear2 = nn.Linear(hidden, config.d_model, bias=False)

    def __call__(self, x, mask, position_bias=None):
        # Self-attention sub-block (pre-norm + residual).
        normed = self.ln1(x)
        attn_out, position_bias = self.attention(
            queries=normed,
            keys=normed,
            values=normed,
            mask=mask,
            position_bias=position_bias,
        )
        x = x + attn_out
        # Feed-forward sub-block (pre-norm + residual, ReLU activation).
        mlp_out = self.linear2(mx.maximum(self.linear1(self.ln2(x)), 0))
        return x + mlp_out, position_bias
2023-12-15 04:38:41 +08:00
class TransformerEncoder(nn.Module):
    """Stack of encoder layers followed by a final layer norm. Only the
    first layer owns the relative-position bias; it is threaded through the
    remaining layers.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(config, has_relative_attention_bias=(idx == 0))
            for idx in range(config.num_layers)
        ]
        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def __call__(self, x, mask):
        bias = None
        for block in self.layers:
            x, bias = block(x, mask, position_bias=bias)
        return self.ln(x)
2023-12-15 23:50:04 +08:00
class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder block: masked self-attention, cross-attention over
    the encoder memory, then a ReLU feed-forward network, each with a
    residual connection.
    """

    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
        super().__init__()
        mlp_dims = config.d_ff or config.d_model * 4
        self.self_attention = MultiHeadAttention(
            config, has_relative_attention_bias=has_relative_attention_bias
        )
        self.cross_attention = MultiHeadAttention(config)
        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
        self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)

    def __call__(self, x, memory, x_mask, memory_mask):
        y = self.ln1(x)
        # MultiHeadAttention returns (output, position_bias): it must be
        # unpacked. The original added the whole tuple to x.
        y, _ = self.self_attention(y, y, y, x_mask)
        x = x + y
        y = self.ln2(x)
        # Query with the normed states; the original passed the un-normed
        # x here, silently discarding ln2's output.
        y, _ = self.cross_attention(y, memory, memory, memory_mask)
        x = x + y
        y = self.ln3(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y
        return x
class TransformerDecoder(nn.Module):
    """Stack of decoder layers followed by a final layer norm."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.layers = [
            TransformerDecoderLayer(config, has_relative_attention_bias=(idx == 0))
            for idx in range(config.num_layers)
        ]
        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def __call__(self, x, memory, x_mask, memory_mask):
        for block in self.layers:
            x = block(x, memory, x_mask, memory_mask)
        return self.ln(x)
2023-12-17 03:44:15 +08:00
class OutputHead(nn.Module):
    """Projects decoder hidden states to vocabulary logits (no weight tying
    with the input embedding here).
    """

    def __init__(self, config: ModelArgs) -> None:
        # Was missing: every other nn.Module subclass in this file calls
        # super().__init__() so the Module machinery is initialized.
        super().__init__()
        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def __call__(self, inputs):
        return self.linear(inputs)
2023-12-15 04:21:36 +08:00
class T5(nn.Module):
    """Encoder-decoder T5 model sharing one token embedding table."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = TransformerEncoder(config)
        self.decoder = TransformerDecoder(config)
        self.lm_head = OutputHead(config)

    def __call__(
        self,
        inputs: mx.array,
        decoder_inputs: mx.array,
        mask: mx.array = None,
        cache: mx.array = None,
    ) -> tuple[mx.array, mx.array]:
        """Return (logits, cache). `cache` is a placeholder: no KV cache is
        implemented yet, so it is passed back unchanged.
        """
        x = self.wte(inputs)
        y = self.encoder(x, mask=None)  # , cache)

        decoder_inputs = self.wte(decoder_inputs)
        decoder_n_tokens = decoder_inputs.shape[1]
        if decoder_n_tokens > 1 and mask is None:
            mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
            # astype only when a mask was actually built; the original
            # called mask.astype unconditionally and crashed on None for
            # single-token decoding.
            mask = mask.astype(x.dtype)
        # TransformerDecoder returns a single array; the original tried to
        # unpack it into (y, cache).
        y = self.decoder(
            x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
        )  # , cache)
        return self.lm_head(y), cache
2023-12-15 04:21:36 +08:00
2023-12-17 03:53:50 +08:00
def generate(
    inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
):
    """Yield one sampled token (shape (B,)) per step.

    Greedy when temp == 0, otherwise temperature sampling. There is no KV
    cache yet, so the full decoder sequence is re-run every step.
    """

    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            return mx.random.categorical(logits * (1 / temp))

    logits, _ = model(inputs, decoder_inputs)
    y = sample(logits[:, -1, :])
    yield y

    while True:
        # Feed back the previously sampled token before predicting the
        # next one; the original recomputed with stale decoder_inputs so
        # the first token was never appended.  mx.concat(a, b, dim=1) is
        # not an MLX API -- it is mx.concatenate([...], axis=...), and the
        # sampled (B,) token needs a sequence axis.
        decoder_inputs = mx.concatenate([decoder_inputs, y[:, None]], axis=1)
        logits, _ = model(inputs, decoder_inputs)
        # Take the last position's logits; squeeze(1) only made sense for
        # a cached single-token forward pass.
        y = sample(logits[:, -1, :])
        yield y
2023-12-15 04:21:36 +08:00
2023-12-17 03:53:50 +08:00
def load_model(model_config):
    """Build a T5, load weights.npz into it, report any key/shape
    mismatches, and return (model, tokenizer).
    """
    model = T5(model_config)
    loaded = dict(mx.load("weights.npz").items())
    current = dict(tree_flatten(model.parameters()))

    print("Missing weights: ", sorted(set(current) - set(loaded)))
    print()
    print("Weights ignored: ", sorted(set(loaded) - set(current)))
    for key in set(current) & set(loaded):
        if loaded[key].shape != current[key].shape:
            print("Shape mismatch for key: ", key)
            print("Expected shape: ", current[key].shape)
            print("Loading shape: ", loaded[key].shape)

    model.update(tree_unflatten(list(loaded.items())))
    tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
    return model, tokenizer
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="T5 Inference script")
    parser.add_argument(
        "--prompt",
        help="",
        default="translate English to German: That is good.",
    )
    parser.add_argument(
        "--encode-only",
        action="store_true",
        default=False,
        help="Whether to decode or not",
    )
    parser.add_argument(
        "--max_tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    args = parser.parse_args()

    mx.random.seed(args.seed)

    config = ModelArgs()
    model, tokenizer = load_model(config)

    prompt = tokenizer(
        args.prompt,
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"]
    prompt = mx.array(prompt)

    if args.encode_only:
        print("[INFO] Encoding with T5...", flush=True)
        print(args.prompt, flush=True)
        embeddings = model.wte(prompt)
        encoder_output = model.encoder(embeddings, mask=None)
        print(encoder_output, flush=True)
        exit(0)

    print("[INFO] Generating with T5...", flush=True)
    print(args.prompt, end="", flush=True)

    decoder_inputs = mx.array([[config.decoder_start_token_id]])

    tokens = []
    # Pass args.temp through -- it was parsed but never used before, so
    # --temp silently had no effect.
    for token, _ in zip(
        generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
    ):
        tokens.append(token)

        # Decode and print in chunks of 10 tokens to keep output streaming.
        if (len(tokens) % 10) == 0:
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s, end="", flush=True)
            tokens = []

    mx.eval(tokens)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s, flush=True)