2023-12-15 04:21:36 +08:00
import argparse
2023-12-18 09:35:53 +08:00
from typing import Optional , Tuple , List
2023-12-17 23:53:49 +08:00
from time import perf_counter_ns
2023-12-15 04:21:36 +08:00
2023-12-16 05:51:01 +08:00
import numpy as np
2023-12-15 04:21:36 +08:00
import mlx . core as mx
import mlx . nn as nn
2023-12-15 04:38:41 +08:00
from mlx . utils import tree_flatten , tree_unflatten
2023-12-18 21:42:06 +08:00
from transformers import T5Config , T5Tokenizer
2023-12-17 03:53:50 +08:00
def _relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
    """
    Map relative positions onto T5 relative-attention bucket ids.

    Adapted from the Hugging Face implementation:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

    A relative position is memory_position - query_position, i.e. the signed
    token distance from the attending position to the attended-to position.
    Small distances each get a bucket of their own, larger distances share
    logarithmically growing buckets, and every distance >= max_distance lands
    in the last bucket. When bidirectional=False, positive relative positions
    are clamped to a distance of zero.

    Args:
        relative_position: integer array of relative positions
        bidirectional: whether the sign of the relative position gets its own
            half of the buckets
        num_buckets: total number of buckets
        max_distance: distance from which all positions share one bucket

    Returns:
        an array shaped like relative_position whose values lie in
        [0, num_buckets)
    """
    if bidirectional:
        # One half of the buckets serves positive offsets, the other half the rest.
        num_buckets //= 2
        sign_offset = (relative_position > 0).astype(mx.int16) * num_buckets
        distance = mx.abs(relative_position)
    else:
        sign_offset = 0
        # Positive offsets are clamped to zero, negative ones become their magnitude.
        distance = -mx.minimum(relative_position, mx.zeros_like(relative_position))

    # From here on `distance` is non-negative.
    # The first half of the buckets maps one-to-one onto small distances.
    max_exact = num_buckets // 2

    # The second half covers logarithmically wider bins up to max_distance.
    log_scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
    log_bucket = max_exact + (
        mx.log(distance.astype(mx.float32) / max_exact) * log_scale
    ).astype(mx.int16)
    log_bucket = mx.minimum(log_bucket, num_buckets - 1)

    return sign_offset + mx.where(distance < max_exact, distance, log_bucket)
class RelativePositionBias(nn.Module):
    """Learned per-head bias indexed by bucketed relative position."""

    def __init__(self, config: T5Config, bidirectional: bool):
        """
        Args:
            config: supplies relative_attention_num_buckets,
                relative_attention_max_distance and num_heads
            bidirectional: forwarded to _relative_position_bucket
        """
        # Initialise the nn.Module bookkeeping, as the attention, norm, encoder
        # and decoder modules in this file do.
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = config.relative_attention_num_buckets
        self.max_distance = config.relative_attention_max_distance
        self.n_heads = config.num_heads
        # One embedding row per bucket, one column per attention head.
        self.embeddings = nn.Embedding(
            config.relative_attention_num_buckets, config.num_heads
        )

    def __call__(self, query_length: int, key_length: int, offset: int = 0) -> mx.array:
        """
        Compute binned relative position bias.

        Args:
            query_length: exclusive end of the query positions
            key_length: number of key positions (0 .. key_length - 1)
            offset: first query position; rows are produced for positions
                offset .. query_length - 1 only

        Returns:
            array of shape (num_heads, query_length - offset, key_length)
        """
        context_position = mx.arange(offset, query_length)[:, None]
        memory_position = mx.arange(key_length)[None, :]

        # shape (query_length - offset, key_length)
        relative_position = memory_position - context_position
        relative_position_bucket = _relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )

        # shape (query_length - offset, key_length, num_heads)
        values = self.embeddings(relative_position_bucket)

        # shape (num_heads, query_length - offset, key_length)
        return values.transpose(2, 0, 1)
2023-12-15 23:16:11 +08:00
class MultiHeadAttention(nn.Module):
    """Bias-free multi-head attention with an optional key/value cache."""

    def __init__(self, config: T5Config):
        super().__init__()
        self.num_heads = config.num_heads
        self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

    def __call__(
        self,
        queries: mx.array,
        keys: mx.array,
        values: mx.array,
        mask: Optional[mx.array],
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """
        Attend from `queries` to `keys`/`values`.

        Args:
            queries: array of shape (B, L, d_model)
            keys: array of shape (B, S, d_model)
            values: array of shape (B, S, d_model)
            mask: optional additive term added to the attention scores
                before the softmax
            cache: optional (keys, values) pair returned by a previous call;
                the new keys/values are appended to it along the sequence axis

        Returns:
            (output, (keys, values)) where output has shape (B, L, d_model)
            and the second item is the per-head keys/values, including any
            cached entries, to be passed as `cache` to the next call
        """
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, _ = queries.shape
        _, S, _ = keys.shape

        # queries/values: [batch x num heads x sequence x head dim]
        # keys are stored pre-transposed: [batch x num heads x head dim x sequence]
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            # The sequence axis is the last one for keys and axis 2 for values.
            keys = mx.concatenate([key_cache, keys], axis=3)
            values = mx.concatenate([value_cache, values], axis=2)

        # Note: the scores are not scaled by 1/sqrt(head dim) here.
        scores = queries @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)

        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(values_hat), (keys, values)
2023-12-15 04:21:36 +08:00
2023-12-15 04:38:41 +08:00
2023-12-18 09:35:53 +08:00
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        # Divide by sqrt(mean(x^2) + eps), computed over the last axis.
        mean_square = x.square().mean(-1, keepdims=True)
        return x * mx.rsqrt(mean_square + self.eps)

    def __call__(self, x):
        # Normalise in float32, then cast back to the input dtype before scaling.
        normalised = self._norm(x.astype(mx.float32))
        return self.weight * normalised.astype(x.dtype)
2023-12-15 04:38:41 +08:00
class TransformerEncoderLayer(nn.Module):
    """Pre-norm encoder block: self-attention, then a ReLU feed-forward network."""

    def __init__(self, config: T5Config):
        super().__init__()
        # Fall back to 4 * d_model when the config gives no feed-forward width.
        hidden_dims = config.d_ff or config.d_model * 4
        self.attention = MultiHeadAttention(config)
        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.linear1 = nn.Linear(config.d_model, hidden_dims, bias=False)
        self.linear2 = nn.Linear(hidden_dims, config.d_model, bias=False)

    def __call__(self, x, mask):
        """Apply the block to `x`; `mask` is added to the attention scores."""
        # Self-attention sub-layer with a residual connection.
        normed = self.ln1(x)
        attended, _ = self.attention(normed, normed, normed, mask=mask)
        x = x + attended

        # Feed-forward sub-layer with a residual connection.
        hidden = self.linear1(self.ln2(x))
        hidden = mx.maximum(hidden, 0)
        return x + self.linear2(hidden)
2023-12-15 04:38:41 +08:00
class TransformerEncoder(nn.Module):
    """Stack of encoder layers sharing one bidirectional relative position bias."""

    def __init__(self, config: T5Config):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(config) for _ in range(config.num_layers)
        ]
        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)

    def __call__(self, x: mx.array):
        """Run every layer on `x` and apply the final norm."""
        seq_len = x.shape[1]
        # The bias is computed once and handed to every layer as its mask.
        pos_bias = self.relative_attention_bias(seq_len, seq_len)
        for layer in self.layers:
            x = layer(x, mask=pos_bias)
        return self.ln(x)
2023-12-15 04:38:41 +08:00
2023-12-15 23:50:04 +08:00
class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder block: self-attention, cross-attention, feed-forward."""

    def __init__(self, config: T5Config):
        super().__init__()
        # Fall back to 4 * d_model when the config gives no feed-forward width.
        mlp_dims = config.d_ff or config.d_model * 4
        self.self_attention = MultiHeadAttention(config)
        self.cross_attention = MultiHeadAttention(config)
        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
        self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)

    def __call__(
        self,
        x: mx.array,
        memory: mx.array,
        mask: Optional[mx.array],
        memory_mask: Optional[mx.array],
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ):
        """
        Apply the block to the decoder states `x`.

        Args:
            x: decoder hidden states
            memory: encoder output attended to by the cross-attention
            mask: additive term for the self-attention scores, or None
            memory_mask: additive term for the cross-attention scores, or None
            cache: this layer's self-attention (keys, values) pair from a
                previous step, or None

        Returns:
            (x, cache): the updated hidden states and the self-attention
            (keys, values) pair to pass back in on the next step
        """
        # Self-attention; only this sub-layer reads and extends the cache.
        y = self.ln1(x)
        y, cache = self.self_attention(y, y, y, mask, cache)
        x = x + y

        # Cross-attention over the encoder output; its keys/values are
        # recomputed on every call and not cached.
        y = self.ln2(x)
        y, _ = self.cross_attention(y, memory, memory, memory_mask)
        x = x + y

        # Feed-forward sub-layer.
        y = self.ln3(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x, cache
2023-12-15 23:50:04 +08:00
class TransformerDecoder(nn.Module):
    """Stack of decoder layers sharing one unidirectional relative position bias."""

    def __init__(self, config: T5Config):
        super().__init__()
        self.layers = [
            TransformerDecoderLayer(config) for i in range(config.num_layers)
        ]
        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)

    def __call__(self, x, memory, mask, memory_mask, cache=None):
        """
        Run every decoder layer on `x` and apply the final norm.

        Args:
            x: decoder hidden states; axis 1 is the sequence axis
            memory: encoder output passed to each layer's cross-attention
            mask: additive self-attention mask, or None; it is not modified
            memory_mask: additive cross-attention mask, or None
            cache: list with one (keys, values) pair per layer from a previous
                call, or None on the first call; when given, its entries are
                replaced in place with the extended pairs

        Returns:
            (x, cache): the normalised hidden states and the per-layer cache list
        """
        if cache is not None:
            # Cached keys keep the sequence on their last axis, so its length
            # is the number of positions already decoded.
            offset = cache[0][0].shape[3]
        else:
            offset = 0
            cache = [None] * len(self.layers)

        T = offset + x.shape[1]
        # Bias rows are produced only for the new positions offset .. T - 1.
        pos_bias = self.relative_attention_bias(T, T, offset=offset)

        # Combine out of place so the caller's mask array is left untouched.
        mask = pos_bias if mask is None else mask + pos_bias

        for e, layer in enumerate(self.layers):
            x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
        x = self.ln(x)

        return x, cache
2023-12-15 23:50:04 +08:00
2023-12-17 03:44:15 +08:00
class OutputHead(nn.Module):
    """Bias-free linear projection from d_model to the vocabulary size."""

    def __init__(self, config: T5Config):
        # Initialise the nn.Module bookkeeping before registering sub-modules,
        # as the attention, norm, encoder and decoder modules in this file do.
        super().__init__()
        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def __call__(self, inputs):
        """Project `inputs` onto the vocabulary."""
        return self.linear(inputs)
2023-12-15 04:21:36 +08:00
class T5(nn.Module):
    """Encoder-decoder model: shared token embedding, encoder, decoder, LM head."""

    def __init__(self, config: T5Config):
        # Initialise the nn.Module bookkeeping before registering sub-modules,
        # as the attention, norm, encoder and decoder modules in this file do.
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = TransformerEncoder(config)
        self.decoder = TransformerDecoder(config)
        self.lm_head = OutputHead(config)
        self.tie_word_embeddings = config.tie_word_embeddings
        self.model_dim = config.d_model

    def encode(self, inputs: mx.array):
        """Embed the token ids `inputs` and run the encoder on them."""
        return self.encoder(self.wte(inputs))

    def decode(
        self,
        inputs: mx.array,
        memory: mx.array,
        cache=None,
    ):
        """
        Run the decoder on the token ids `inputs`.

        Args:
            inputs: decoder token ids; axis 1 is the sequence axis
            memory: encoder output, as returned by `encode`
            cache: per-layer cache returned by a previous `decode` call, or None

        Returns:
            (logits, cache): vocabulary logits for every input position and
            the cache to pass to the next call
        """
        inputs = self.wte(inputs)
        T = inputs.shape[1]
        # A causal mask is only built when several positions are decoded at once.
        if T > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
            mask = mask.astype(inputs.dtype)
        else:
            mask = None

        y, cache = self.decoder(
            inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
        )
        if self.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
            y = y * self.model_dim**-0.5
        return self.lm_head(y), cache

    def __call__(
        self,
        inputs: mx.array,
        decoder_inputs: mx.array,
    ):
        """Return the logits for `decoder_inputs` given the encoder `inputs`."""
        return self.decode(decoder_inputs, self.encode(inputs))[0]
2023-12-15 04:21:36 +08:00
2023-12-17 03:53:50 +08:00
def generate(
    inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
):
    """
    Endlessly yield tokens sampled from `model`, one per iteration.

    The generator never stops on its own; the caller decides when to stop
    (e.g. on an end-of-sequence token or a token budget).

    Args:
        inputs: encoder token ids, passed to `model.encode`
        decoder_inputs: 1-D array of initial decoder token ids
        model: the T5 model to sample from
        temp: sampling temperature; 0 or None selects greedy (argmax) decoding

    Yields:
        the next sampled token id as a squeezed array
    """
    # The annotation allows None, which would otherwise fail at `1 / temp`;
    # treat it the same as a temperature of 0.
    greedy = temp is None or temp == 0

    def sample(logits):
        if greedy:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    # The encoder runs once; its output is reused for every decoding step.
    memory = model.encode(inputs)
    cache = None
    y = decoder_inputs
    while True:
        # y[None] adds the batch axis; after the first step y is a single token
        # and the cache carries the earlier positions.
        logits, cache = model.decode(y[None], memory, cache=cache)
        y = sample(logits[:, -1, :])
        yield y.squeeze()
2023-12-15 04:21:36 +08:00
2023-12-18 21:42:06 +08:00
def load_model(model_name: str, config: T5Config):
    """
    Build a T5 model from `config` and load weights from "<model_name>.npz".

    Parameters missing from the file, file entries the model does not have,
    and shape mismatches are only reported on stdout; loading proceeds anyway.

    Returns:
        the model with its parameters evaluated
    """
    model = T5(config)
    weights = mx.load(f"{model_name}.npz")
    weights_to_load = list(weights.items())

    expected = dict(tree_flatten(model.parameters()))
    provided = dict(weights_to_load)
    expected_keys = set(expected.keys())
    provided_keys = set(provided.keys())

    print("Missing weights:", sorted(expected_keys - provided_keys))
    print()
    print("Weights ignored:", sorted(provided_keys - expected_keys))

    # Report every parameter present on both sides whose shape disagrees.
    for key in expected_keys & provided_keys:
        if provided[key].shape == expected[key].shape:
            continue
        print("Shape mismatch for key:", key)
        print("Expected shape:", expected[key].shape)
        print("Loading shape:", provided[key].shape)

    model.update(tree_unflatten(weights_to_load))
    mx.eval(model.parameters())
    return model
2023-12-15 04:21:36 +08:00
# Command-line entry point: encode a prompt with T5 and either print the
# encoder output or stream the generated tokens with throughput statistics.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="T5 Inference script")
    parser.add_argument(
        "--model",
        type=str,
        help="Name of the T5 model.",
        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
        default="t5-small",
    )
    parser.add_argument(
        "--prompt",
        help="The text to process, including the task prefix.",
        default="translate English to German: That is good.",
    )
    parser.add_argument(
        "--encode-only",
        action="store_true",
        default=False,
        help="Whether to decode or not. If true, will output last layer of encoder.",
    )
    parser.add_argument(
        "--max_tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    args = parser.parse_args()

    mx.random.seed(args.seed)

    config = T5Config.from_pretrained(args.model)
    model = load_model(args.model, config)
    tokenizer = T5Tokenizer.from_pretrained(
        args.model,
        legacy=False,
        model_max_length=config.n_positions,
    )

    prompt = tokenizer(
        args.prompt,
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"]
    prompt = mx.array(prompt)

    if args.encode_only:
        print("[INFO] Encoding with T5...", flush=True)
        print(args.prompt, flush=True)
        encoder_output = model.encode(prompt)
        print(encoder_output, flush=True)
        # Equivalent to exit(0) without relying on the `site`-provided builtin.
        raise SystemExit(0)

    print("[INFO] Generating with T5...", flush=True)
    print("Input:", args.prompt, flush=True)

    decoder_inputs = mx.array([config.decoder_start_token_id])

    start = perf_counter_ns()

    # Count printed tokens explicitly so the statistic is defined even when
    # nothing is generated and is exact when the token budget is exhausted.
    n_tokens = 0
    # range() comes first so zip stops before asking the generator for a
    # token beyond the budget.
    for _, token in zip(
        range(args.max_tokens), generate(prompt, decoder_inputs, model, args.temp)
    ):
        token_id = token.item()
        if token_id == tokenizer.eos_token_id:
            break
        # SentencePiece marks word starts with "▁"; render it as a space.
        print(
            tokenizer.convert_ids_to_tokens(token_id).replace("▁", " "),
            end="",
            flush=True,
        )
        n_tokens += 1

    end = perf_counter_ns()
    elapsed = (end - start) / 1.0e9
    print()
    print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")