From 688a6e1e7824903bc288a404291e736a1415a64b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 17 Dec 2023 17:35:53 -0800
Subject: [PATCH] with cache

---
 t5/convert.py |  29 ++++---
 t5/t5.py      | 236 ++++++++++++++++++++++++++------------
 2 files changed, 139 insertions(+), 126 deletions(-)

diff --git a/t5/convert.py b/t5/convert.py
index e10e4e45..9ff5131a 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -14,10 +14,8 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    (
-        ".relative_attention_bias.",
-        ".relative_attention_bias.embeddings."
-    ),
+    ("layers.0.layer.0.SelfAttention.relative_attention_bias.",
+     "relative_attention_bias.embeddings."),
 ]
 
 ENCODER_REPLACEMENT_PATTERNS = [
@@ -33,6 +31,7 @@ DECODER_REPLACEMENT_PATTERNS = [
     (".layer.2.DenseReluDense.wo.", ".linear2."),
 ]
 
+
 def replace_key(key: str) -> str:
     for old, new in SHARED_REPLACEMENT_PATTERNS:
         key = key.replace(old, new)
@@ -45,14 +44,22 @@ def replace_key(key: str) -> str:
     return key
 
 
-def convert():
-    model = T5ForConditionalGeneration.from_pretrained(
-        "t5-small", torch_dtype="auto"
-    )
-    state_dict = model.state_dict()
-    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
+def convert(model_name):
+    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
+    weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
     np.savez("weights.npz", **weights)
 
 
 if __name__ == "__main__":
-    convert()
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Name of the T5 model.",
+        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
+        default="t5-small",
+    )
+    args = parser.parse_args()
+    convert(args.model_name)
diff --git a/t5/t5.py b/t5/t5.py
index b44f054a..38fea821 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,5 +1,5 @@
 import argparse
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple
 from time import perf_counter_ns
@@ -110,18 +111,23 @@ class RelativePositionBias(nn.Module):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         self.num_heads = config.num_heads
         self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
-        self.has_relative_attention_bias = has_relative_attention_bias
-        if has_relative_attention_bias:
-            self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, queries, keys, values, mask=None, position_bias=None):
+    def __call__(
+        self,
+        queries: mx.array,
+        keys: mx.array,
+        values: mx.array,
+        mask: Optional[mx.array],
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
@@ -133,106 +139,95 @@ class MultiHeadAttention(nn.Module):
         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
 
+        if cache is not None:
+            key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=3)
+            values = mx.concatenate([value_cache, values], axis=2)
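+            # Keys are kept pre-transposed as [B, num_heads, head_dim, S] for
+            # the matmul below, while values are [B, num_heads, S, head_dim],
+            # which is why the two caches grow along different axes above.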
+
         # Dimensions are [batch x num heads x sequence x hidden dim]
         scores = queries @ keys
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
-        if self.has_relative_attention_bias:
-            position_bias = self.relative_attention_bias(L, S)
-        if position_bias is not None:
-            scores += position_bias
 
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        out = self.out_proj(values_hat)
-        return out, position_bias
-
-    @staticmethod
-    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
-        indices = mx.arange(N)
-        mask = indices[:, None] < indices[None]
-        # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
-        # TODO: Should replace this with finfo(dtype).min
-        mask = mask.astype(dtype) * -1e9
-        return mask
+        return self.out_proj(values_hat), (keys, values)
 
 
-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
         self.weight = mx.ones((dims,))
         self.eps = eps
-        self.dims = dims
+
+    def _norm(self, x):
+        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
 
     def __call__(self, x):
-        var = x.var(axis=-1, keepdims=True)
-        x = x * mx.rsqrt(var + self.eps)
-        return x * self.weight
+        # T5's layer norm is RMS-style: no mean subtraction and no bias.
+        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        return self.weight * output
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.attention = MultiHeadAttention(
-            config, has_relative_attention_bias=has_relative_attention_bias
-        )
-        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.attention = MultiHeadAttention(config)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, mask, position_bias=None):
+    def __call__(self, x, mask):
         y = self.ln1(x)
-        y, position_bias = self.attention(
-            queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
-        )
+        y, _ = self.attention(y, y, y, mask=mask)
         x = x + y
 
         y = self.ln2(x)
         y = self.linear1(y)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        x = x + y
-
-        return x, position_bias
+        return x + y
 
 
 class TransformerEncoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
-            for i in range(config.num_layers)
+            TransformerEncoderLayer(config) for _ in range(config.num_layers)
         ]
-        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, x, mask):
-        position_bias = None
+    def __call__(self, x):
+        # The bias is shared across layers, so compute it once and pass it
+        # to every layer as the additive attention mask.
+        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
         for layer in self.layers:
-            x, position_bias = layer(x, mask, position_bias=position_bias)
-        x = self.ln(x)
-
-        return x
+            x = layer(x, mask=pos_bias)
+        return self.ln(x)
 
 
 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.self_attention = MultiHeadAttention(
-            config, has_relative_attention_bias=has_relative_attention_bias
-        )
+        self.self_attention = MultiHeadAttention(config)
         self.cross_attention = MultiHeadAttention(config)
-        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
 
-    def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
+    def __call__(
+        self,
+        x: mx.array,
+        memory: mx.array,
+        mask: Optional[mx.array],
+        memory_mask: Optional[mx.array],
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         y = self.ln1(x)
-        y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
+        y, cache = self.self_attention(y, y, y, mask, cache)
         x = x + y
 
         y = self.ln2(x)
@@ -245,27 +240,39 @@ class TransformerDecoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
 
-        return x, position_bias
+        return x, cache
 
 
 class TransformerDecoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
-            for i in range(config.num_layers)
+            TransformerDecoderLayer(config) for _ in range(config.num_layers)
         ]
-        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config)
 
-    def __call__(self, x, memory, x_mask, memory_mask):
-        position_bias = None
-        for layer in self.layers:
-            x, position_bias = layer(
-                x, memory, x_mask, memory_mask, position_bias=position_bias
-            )
+    def __call__(self, x, memory, mask, memory_mask, cache=None):
+        if cache is not None:
+            offset = cache[0][0].shape[3]
+        else:
+            offset = 0
+            cache = [None] * len(self.layers)
+
+        T = offset + x.shape[1]
+        # TODO: add an offset argument to RelativePositionBias to avoid
+        # recomputing the full T x T bias on every decoding step.
+        pos_bias = self.relative_attention_bias(T, T)
+        pos_bias = pos_bias[:, :, -x.shape[1]:, :]
+        if mask is not None:
+            mask += pos_bias
+        else:
+            mask = pos_bias
+
+        for e, layer in enumerate(self.layers):
+            x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
         x = self.ln(x)
 
-        return x
+        return x, cache
 
 
 class OutputHead(nn.Module):
@@ -283,26 +290,37 @@ class T5(nn.Module):
         self.decoder = TransformerDecoder(config)
         self.lm_head = OutputHead(config)
 
+    def encode(self, inputs: mx.array) -> mx.array:
+        return self.encoder(self.wte(inputs))
+
+    def decode(
+        self,
+        inputs: mx.array,
+        memory: mx.array,
+        cache=None,
+    ):
+        inputs = self.wte(inputs)
+        T = inputs.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+            mask = mask.astype(inputs.dtype)
+        else:
+            # A single new token attends to the whole cache, so no causal
+            # mask is needed.
+            mask = None
+
+        y, cache = self.decoder(
+            inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
+        )
+        return self.lm_head(y), cache
+
     def __call__(
         self,
         inputs: mx.array,
         decoder_inputs: mx.array,
-        mask: mx.array = None,
-        cache: mx.array = None,
-    ) -> tuple[mx.array, mx.array]:
-        x = self.wte(inputs)
-        y = self.encoder(x, mask=None)  # , cache)
-
-        decoder_inputs = self.wte(decoder_inputs)
-        decoder_n_tokens = decoder_inputs.shape[1]
-        if decoder_n_tokens > 1 and mask is None:
-            mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
-            mask = mask.astype(x.dtype)
-
-        y = self.decoder(
-            x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
-        )  # , cache)
-        return self.lm_head(y), cache
+    ) -> mx.array:
+        return self.decode(decoder_inputs, self.encode(inputs))[0]
 
 
 def generate(
@@ -314,12 +332,15 @@
     else:
         return mx.random.categorical(logits * (1 / temp))
 
+    # Run the encoder once; its output is reused at every decoding step.
+    memory = model.encode(inputs)
+    cache = None
+    y = decoder_inputs
     while True:
-        # TODO: add cache
-        logits, _ = model(inputs, decoder_inputs)
-        y = mx.expand_dims(sample(logits[:, -1, :]), 0)
-        decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1)
-        yield y
+        # Feed only the newest token; the per-layer KV cache carries the history.
+        logits, cache = model.decode(y[None], memory, cache=cache)
+        y = sample(logits[:, -1, :])
+        yield y.squeeze()
 
 
 def load_model(model_config):
@@ -340,6 +361,7 @@ def load_model(model_config):
             print("Expected shape: ", current_weights_dict[key].shape)
             print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
+    mx.eval(model.parameters())
     tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
     return model, tokenizer
 
@@ -397,39 +419,23 @@ if __name__ == "__main__":
     print("[INFO] Generating with T5...", flush=True)
     print("Input: ", args.prompt, flush=True)
 
-    decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
+    decoder_inputs = mx.array([config.decoder_start_token_id])
 
     start = perf_counter_ns()
-    n_tokens = 0
+
     tokens = []
+    printed = 0
     for token, _ in zip(
         generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
     ):
-        tokens.append(token)
-
-        if (len(tokens) % 10) == 0:
-            mx.eval(tokens)
-            eos_index = next(
-                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
-                None,
-            )
-
-            if eos_index is not None:
-                tokens = tokens[:eos_index]
-
-            s = tokenizer.decode([t.item() for t in tokens])
-            print(s, end="", flush=True)
-            n_tokens += len(tokens)
-            tokens = []
-            if eos_index is not None:
-                break
+        if token.item() == tokenizer.eos_token_id:
+            break
+        tokens.append(token.item())
+        # Decoding tokens one at a time drops the spaces between words, so
+        # decode the full sequence and stream only the newly added suffix.
+        s = tokenizer.decode(tokens)
+        print(s[printed:], end="", flush=True)
+        printed = len(s)
 
     end = perf_counter_ns()
-
-    mx.eval(tokens)
-    s = tokenizer.decode([t.item() for t in tokens])
-    n_tokens += len(tokens)
     elapsed = (end - start) / 1.0e9
-    print(s, flush=True)
-    print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")
\ No newline at end of file
+    print()
+    print(f"Time: {elapsed:.2f} seconds, tokens/s: {len(tokens) / elapsed:.2f}")
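
Usage note: a minimal sketch of driving the new encode/decode API by hand
(the prompt text, the greedy argmax step, and the 32-token cap below are
illustrative assumptions, not part of the patch; generate() above samples
from the logits instead):

    import mlx.core as mx

    model, tokenizer = load_model(config)  # config built as in __main__
    prompt = mx.array(tokenizer.encode("translate English to German: Hello."))[None]

    memory = model.encode(prompt)          # encoder runs exactly once
    y = mx.array([config.decoder_start_token_id])
    cache = None
    for _ in range(32):
        # Each call consumes one token and returns the updated per-layer KV cache.
        logits, cache = model.decode(y[None], memory, cache=cache)
        y = mx.argmax(logits[:, -1, :], axis=-1)
        if y.item() == tokenizer.eos_token_id:
            break
        print(tokenizer.decode([y.item()]), end="", flush=True)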