diff --git a/t5/convert.py b/t5/convert.py
index 9ff5131a..35374724 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -14,8 +14,10 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    ("layers.0.layer.0.SelfAttention.relative_attention_bias.",
-     "relative_attention_bias.embeddings."),
+    (
+        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+        "relative_attention_bias.embeddings.",
+    ),
 ]
 
 ENCODER_REPLACEMENT_PATTERNS = [
diff --git a/t5/t5.py b/t5/t5.py
index 7196249b..dcdcfbfe 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -58,7 +58,9 @@ def _relative_position_bucket(
         relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
         relative_position = mx.abs(relative_position)
     else:
-        relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
+        relative_position = -mx.minimum(
+            relative_position, mx.zeros_like(relative_position)
+        )
     # now relative_position is in the range [0, inf)
 
     # half of the buckets are for exact increments in positions
@@ -66,10 +68,9 @@
     is_small = relative_position < max_exact
 
     # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+    scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
     relative_position_if_large = max_exact + (
-        mx.log(relative_position.astype(mx.float32) / max_exact)
-        / np.log(max_distance / max_exact)
-        * (num_buckets - max_exact)
+        mx.log(relative_position.astype(mx.float32) / max_exact) * scale
     ).astype(mx.int16)
     relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
     relative_buckets += mx.where(
@@ -80,7 +81,7 @@
 
 class RelativePositionBias(nn.Module):
     def __init__(self, config: ModelArgs, bidirectional: bool):
-        self.bidirectional = False #bidirectional
+        self.bidirectional = False  # bidirectional
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
@@ -93,11 +94,8 @@ class RelativePositionBias(nn.Module):
         context_position = mx.arange(offset, query_length)[:, None]
         memory_position = mx.arange(key_length)[None, :]
 
-        # shape (query_length, key_length)
-        relative_position = (
-            memory_position - context_position
-        )
+        relative_position = memory_position - context_position
 
         relative_position_bucket = _relative_position_bucket(
             relative_position,
             bidirectional=self.bidirectional,
@@ -106,9 +104,7 @@
         )
 
         # shape (query_length, key_length, num_heads)
-        values = self.embeddings(
-            relative_position_bucket
-        )
+        values = self.embeddings(relative_position_bucket)
 
         # shape (num_heads, query_length, key_length)
         return values.transpose(2, 0, 1)
@@ -130,8 +126,7 @@
         values: mx.array,
         mask: mx.array,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
-
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
@@ -191,7 +186,7 @@
         y = self.linear1(y)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        return x + y 
+        return x + y
 
 
 class TransformerEncoder(nn.Module):
@@ -203,7 +198,7 @@
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
 
-    def __call__(self, x):
+    def __call__(self, x: mx.array):
         pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
         for layer in self.layers:
             x = layer(x, mask=pos_bias)
@@ -228,8 +223,8 @@
         memory: mx.array,
         mask: mx.array,
         memory_mask: mx.array,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
+        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
+    ):
         y = self.ln1(x)
         y, cache = self.self_attention(y, y, y, mask, cache)
         x = x + y
@@ -278,7 +273,7 @@
 
 
 class OutputHead(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: ModelArgs):
         self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
     def __call__(self, inputs):
@@ -292,17 +287,14 @@
         self.decoder = TransformerDecoder(config)
         self.lm_head = OutputHead(config)
 
-    def encode(
-        self,
-        inputs: mx.array
-    ) -> mx.array:
+    def encode(self, inputs: mx.array):
         return self.encoder(self.wte(inputs))
 
     def decode(
         self,
         inputs: mx.array,
         memory: mx.array,
-        cache = None,
+        cache=None,
     ):
         inputs = self.wte(inputs)
         T = inputs.shape[1]
@@ -321,7 +313,7 @@
         self,
         inputs: mx.array,
         decoder_inputs: mx.array,
-    ) -> mx.array:
+    ):
         return decode(decoder_inputs, encode(inputs))[0]
 
 
@@ -375,7 +367,7 @@
     )
     parser.add_argument(
         "--encode-only",
-        action='store_true',
+        action="store_true",
        default=False,
         help="Whether to decode or not",
     )
@@ -425,14 +417,13 @@
     tokens = []
     for token, n_tokens in zip(
-        generate(prompt, decoder_inputs, model, args.temp),
-        range(args.max_tokens)
+        generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
     ):
         if token.item() == tokenizer.eos_token_id:
             break
         tokens.append(token.item())
 
-# For some reason using the following line doesn't give spaces
-# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
+    # For some reason using the following line doesn't give spaces
+    # print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
     print(tokenizer.decode(tokens), end="", flush=True)
 
     end = perf_counter_ns()
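A quick sanity check on the hoisted `scale` factor in `_relative_position_bucket` (illustrative only, not part of the patch; it uses plain numpy instead of mlx and made-up bucket sizes): the precomputed constant must keep the division by `np.log(max_distance / max_exact)`, otherwise the bucket boundaries shift.

import numpy as np

# Hypothetical bucket sizes, only for this check.
num_buckets, max_exact, max_distance = 32, 16, 128
r = np.arange(max_exact, max_distance, dtype=np.float32)

# Original per-element expression.
lhs = np.log(r / max_exact) / np.log(max_distance / max_exact) * (num_buckets - max_exact)

# Refactored expression with the constant hoisted out.
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
rhs = np.log(r / max_exact) * scale

assert np.allclose(lhs, rhs)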