diff --git a/t5/t5.py b/t5/t5.py
index cbecbb62..9670df16 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,5 +1,6 @@
 import argparse
 from dataclasses import dataclass
+from typing import Optional
 
 import numpy as np
 import mlx.core as mx
@@ -10,23 +11,25 @@ from transformers import AutoTokenizer
 
 @dataclass
 class ModelArgs:
-    d_ff: int = 2048
-    d_kv: int = 64
-    d_model: int = 512
-    dropout_rate: int = 0.1
-    layer_norm_epsilon: float = 1e-06
-    n_positions: int = 512
-    relative_attention_num_buckets: int = 32
-    relative_attention_max_distance: int = 128
-    num_heads: int = 8
-    num_layers: int = 6
-    decoder_start_token_id: int = 0
-    eos_token_id: int = 1
-    pad_token_id: int = 0
-    vocab_size: int = 32128
+    d_ff: int = 2048
+    d_kv: int = 64
+    d_model: int = 512
+    dropout_rate: float = 0.1
+    layer_norm_epsilon: float = 1e-06
+    n_positions: int = 512
+    relative_attention_num_buckets: int = 32
+    relative_attention_max_distance: int = 128
+    num_heads: int = 8
+    num_layers: int = 6
+    decoder_start_token_id: int = 0
+    eos_token_id: int = 1
+    pad_token_id: int = 0
+    vocab_size: int = 32128
 
 
-def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+def _relative_position_bucket(
+    relative_position, bidirectional=True, num_buckets=32, max_distance=128
+):
     """
     Adapted from HF Tensorflow:
     https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
@@ -66,10 +69,10 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
         / np.log(max_distance / max_exact)
         * (num_buckets - max_exact)
     ).astype(mx.int16)
-    relative_position_if_large = mx.minimum(
-        relative_position_if_large, num_buckets - 1
+    relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
+    relative_buckets += mx.where(
+        is_small, relative_position, relative_position_if_large
     )
-    relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
     return relative_buckets
 
 
@@ -80,22 +83,28 @@ class RelativePositionBias(nn.Module):
         self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
         self.embeddings = nn.Embedding(
-            config.relative_attention_num_buckets,
-            config.num_heads)
+            config.relative_attention_num_buckets, config.num_heads
+        )
 
     def __call__(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
         memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
-        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position = (
+            memory_position - context_position
+        )  # shape (query_length, key_length)
         relative_position_bucket = _relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
             bidirectional=self.bidirectional,
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        values = self.embeddings(relative_position_bucket)  # shape (query_length, key_length, num_heads)
-        values = mx.expand_dims(values.transpose(2, 0, 1), 0)  # shape (1, num_heads, query_length, key_length)
+        values = self.embeddings(
+            relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
+        values = mx.expand_dims(
+            values.transpose(2, 0, 1), 0
+        )  # shape (1, num_heads, query_length, key_length)
         return values
 
 
@@ -132,10 +141,8 @@ class MultiHeadAttention(nn.Module):
         if self.has_relative_attention_bias:
             position_bias = self.relative_attention_bias(L, S)
             scores += position_bias
-
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
         return self.out_proj(values_hat)
 
     @staticmethod
@@ -260,7 +267,7 @@ class TransformerDecoder(nn.Module):
 
 class OutputHead(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
-        self.linear = nn.Linear(config.d_model, config.vocab_size)
+        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
     def __call__(self, inputs):
         return self.linear(inputs)
@@ -281,36 +288,42 @@ class T5(nn.Module):
         cache: mx.array = None,
     ) -> tuple[mx.array, mx.array]:
         x = self.wte(inputs)
-
         y = self.encoder(x, mask=None)  # , cache)
 
         if x.shape[1] > 1 and mask is None:
             mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)
 
         decoder_inputs = self.wte(decoder_inputs)
-        y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y)  # , cache)
+        y, cache = self.decoder(
+            x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
+        )  # , cache)
         return self.lm_head(y), cache
 
 
-# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):
-#     def sample(logits):
-#         if temp == 0:
-#             return mx.argmax(logits, axis=-1)
-#         else:
-#             return mx.random.categorical(logits * (1 / temp))
+def generate(
+    inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
+):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            return mx.random.categorical(logits * (1 / temp))
 
-#     logits, cache = model(prompt)
-#     y = sample(logits[:, -1, :])
-#     yield y
+    logits, _ = model(inputs, decoder_inputs)
+    y = sample(logits[:, -1, :])
+    decoder_inputs = mx.concatenate([decoder_inputs, y[:, None]], axis=1)
+    yield y
 
-#     while True:
-#         logits, cache = model(y[:, None], cache=cache)
-#         y = sample(logits.squeeze(1))
-#         yield y
+    while True:
+        # No KV cache yet: re-run the full model over the growing decoder_inputs.
+        # logits, cache = model(y[:, None], cache=cache)
+        logits, _ = model(inputs, decoder_inputs)
+        y = sample(logits[:, -1, :])
+        decoder_inputs = mx.concatenate([decoder_inputs, y[:, None]], axis=1)
+        yield y
 
 
-def load_model():
-    model = T5(ModelArgs())
+def load_model(model_config):
+    model = T5(model_config)
     weights = mx.load("weights.npz")
     current_weights = tree_flatten(model.parameters())
     weights_to_load = list(weights.items())
@@ -356,7 +369,8 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model()
+    config = ModelArgs()
+    model, tokenizer = load_model(config)
 
     prompt = tokenizer(
         args.prompt,
@@ -369,18 +383,20 @@ if __name__ == "__main__":
 
     print("[INFO] Generating with T5...", flush=True)
    print(args.prompt, end="", flush=True)
-    print(model(prompt))
+    decoder_inputs = mx.array([[config.decoder_start_token_id]])
 
-    # tokens = []
-    # for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
-    #     tokens.append(token)
+    tokens = []
+    for token, _ in zip(
+        generate(prompt, decoder_inputs, model), range(args.max_tokens)
+    ):
+        tokens.append(token)
 
-    #     if (len(tokens) % 10) == 0:
-    #         mx.eval(tokens)
-    #         s = tokenizer.decode([t.item() for t in tokens])
-    #         print(s, end="", flush=True)
-    #         tokens = []
+        if (len(tokens) % 10) == 0:
+            mx.eval(tokens)
+            s = tokenizer.decode([t.item() for t in tokens])
+            print(s, end="", flush=True)
+            tokens = []
 
-    # mx.eval(tokens)
-    # s = tokenizer.decode([t.item() for t in tokens])
-    # print(s, flush=True)
+    mx.eval(tokens)
+    s = tokenizer.decode([t.item() for t in tokens])
+    print(s, flush=True)
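
Note on the uncached generate(): because the KV cache path is still commented out, every step re-runs the encoder and the full decoder over the growing decoder_inputs, so generation is quadratic in output length but still produces correct tokens. A minimal driver sketch for the new API, stopping at EOS instead of always running out the token budget; it assumes load_model returns (model, tokenizer) as in __main__, that weights.npz is present, and the prompt string is only an illustration:

    config = ModelArgs()
    model, tokenizer = load_model(config)

    # T5 takes the task as a plain-text prefix in the input itself.
    inputs = mx.array(
        tokenizer("translate English to German: That is good.", return_tensors="np")[
            "input_ids"
        ]
    )
    # Seed the decoder with the start token (0 for T5).
    decoder_inputs = mx.array([[config.decoder_start_token_id]])

    tokens = []
    for token, _ in zip(generate(inputs, decoder_inputs, model), range(100)):
        if token.item() == config.eos_token_id:
            break  # stop at </s> rather than generating a fixed count
        tokens.append(token.item())

    print(tokenizer.decode(tokens))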