diff --git a/t5/.gitignore b/t5/.gitignore new file mode 100644 index 00000000..ded9aef9 --- /dev/null +++ b/t5/.gitignore @@ -0,0 +1 @@ +*.npz diff --git a/t5/README.md b/t5/README.md new file mode 100644 index 00000000..a0cc861b --- /dev/null +++ b/t5/README.md @@ -0,0 +1,53 @@ +# T5 + +The T5 models are encoder-decoder models pre-trained on a mixture of +unsupervised and supervised tasks.[^1] These models work well on a variety of +tasks by prepending task-specific prefixes to the input, e.g.: +`translate English to German: …`, `summarize: ….`, etc. + +This example also supports the FLAN-T5 models variants.[^2] + +## Setup + +Download and convert the model: + +```sh +python convert.py --model +``` + +This will make the `.npz` file which MLX can read. + +The `` can be any of the following: + +| Model Name | Model Size | +| ---------- | ---------- +| t5-small | 60 million | +| t5-base | 220 million | +| t5-large | 770 million | +| t5-3b | 3 billion | +| t5-11b | 11 billion | + +The FLAN variants can be specified with `google/flan-t5-small`, +`google/flan-t5-base`, etc. See the [Hugging Face +page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a +complete list of models. + +## Generate + +Generate text with: + +```sh +python t5.py --model t5-small --prompt "translate English to German: A tasty apple" +``` + +This should give the output: `Ein leckerer Apfel` + +To see a list of options run: + +```sh +python t5.py --help +``` + +[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683) + or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5). +[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416). diff --git a/t5/convert.py b/t5/convert.py new file mode 100644 index 00000000..71b009da --- /dev/null +++ b/t5/convert.py @@ -0,0 +1,68 @@ +from transformers import T5ForConditionalGeneration +import numpy as np + + +SHARED_REPLACEMENT_PATTERNS = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), +] + +ENCODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), +] + +DECODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".self_attention."), + (".layer.1.EncDecAttention.", ".cross_attention."), + (".layer.2.DenseReluDense.", ".dense."), +] + + +def replace_key(key: str) -> str: + for old, new in SHARED_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + if key.startswith("encoder."): + for old, new in ENCODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + elif key.startswith("decoder."): + for old, new in DECODER_REPLACEMENT_PATTERNS: + key = key.replace(old, new) + return key + + +def convert(model_name): + model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") + weights = { + replace_key(k): v.numpy().astype(np.float16) + for k, v in model.state_dict().items() + } + file_name = model_name.replace("/", "-") + np.savez(f"{file_name}.npz", **weights) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Convert T5 weights to MLX") + parser.add_argument( + "--model", + type=str, + help="Name of the T5 model.", + default="t5-small", + ) + args = parser.parse_args() + convert(args.model) diff --git a/t5/hf_t5.py b/t5/hf_t5.py new file mode 100644 index 00000000..ddd99610 --- /dev/null +++ b/t5/hf_t5.py @@ -0,0 +1,54 @@ +from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer + +import argparse + + +def embed(t5_model: str): + batch = [ + "translate English to German: That is good.", + "This is an example of T5 working on MLX.", + ] + + tokenizer = AutoTokenizer.from_pretrained(t5_model) + torch_model = T5EncoderModel.from_pretrained(t5_model) + torch_tokens = tokenizer(batch, return_tensors="pt", padding=True) + torch_forward = torch_model(**torch_tokens, output_hidden_states=True) + torch_output = torch_forward.last_hidden_state.detach().numpy() + + print("\n TF BERT:") + for input_str, embedding in list(zip(batch, torch_output)): + print("Input:", input_str) + print(embedding) + print() + + +def generate(t5_model: str): + prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast." + tokenizer = AutoTokenizer.from_pretrained(t5_model) + torch_model = T5ForConditionalGeneration.from_pretrained(t5_model) + torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids + outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512) + print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run the T5 model using Hugging Face Transformers." + ) + parser.add_argument( + "--encode-only", + action="store_true", + help="Only run the encoder and print the embeddings.", + default=False, + ) + parser.add_argument( + "--model", + default="t5-small", + help="The huggingface name of the T5 model to save.", + ) + args = parser.parse_args() + if args.encode_only: + embed(args.model) + else: + generate(args.model) + diff --git a/t5/requirements.txt b/t5/requirements.txt new file mode 100644 index 00000000..4a37303a --- /dev/null +++ b/t5/requirements.txt @@ -0,0 +1,3 @@ +mlx +numpy +transformers diff --git a/t5/t5.py b/t5/t5.py new file mode 100644 index 00000000..6dc5835d --- /dev/null +++ b/t5/t5.py @@ -0,0 +1,469 @@ +import argparse +from typing import Optional, Tuple, List +from time import perf_counter_ns + +import numpy as np +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten, tree_map +from transformers import T5Config, T5Tokenizer + + +def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 +): + """ + Adapted from HF Tensorflow: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets + relative_position = mx.abs(relative_position) + else: + relative_position = -mx.minimum( + relative_position, mx.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + scale = (num_buckets - max_exact) / np.log(max_distance / max_exact) + relative_position_if_large = max_exact + ( + mx.log(relative_position.astype(mx.float32) / max_exact) * scale + ).astype(mx.int16) + relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += mx.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + +class RelativePositionBias(nn.Module): + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional + self.num_buckets = config.relative_attention_num_buckets + self.max_distance = config.relative_attention_max_distance + self.n_heads = config.num_heads + self.embeddings = nn.Embedding( + config.relative_attention_num_buckets, config.num_heads + ) + + def __call__(self, query_length: int, key_length: int, offset: int = 0): + """Compute binned relative position bias""" + context_position = mx.arange(offset, query_length)[:, None] + memory_position = mx.arange(key_length)[None, :] + + # shape (query_length, key_length) + relative_position = memory_position - context_position + relative_position_bucket = _relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + + # shape (query_length, key_length, num_heads) + values = self.embeddings(relative_position_bucket) + + # shape (num_heads, query_length, key_length) + return values.transpose(2, 0, 1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + inner_dim = config.d_kv * config.num_heads + self.num_heads = config.num_heads + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) + + def __call__( + self, + queries: mx.array, + keys: mx.array, + values: mx.array, + mask: Optional[mx.array], + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> [mx.array, Tuple[mx.array, mx.array]]: + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, _ = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + keys = mx.concatenate([key_cache, keys], axis=3) + values = mx.concatenate([value_cache, values], axis=2) + + # Dimensions are [batch x num heads x sequence x hidden dim] + queries = queries + scores = queries @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.out_proj(values_hat), (keys, values) + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _norm(self, x): + return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) + + def __call__(self, x): + t = x.dtype + output = self._norm(x).astype(t) + return self.weight * output + + +class DenseActivation(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + mlp_dims = config.d_ff or config.d_model * 4 + self.gated = config.feed_forward_proj.startswith("gated") + if self.gated: + self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) + else: + self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) + activation = config.feed_forward_proj.removeprefix("gated-") + if activation == "relu": + self.act = nn.relu + elif activation == "gelu": + self.act = nn.gelu + elif activation == "silu": + self.act = nn.silu + else: + raise ValueError(f"Unknown activation: {activation}") + + def __call__(self, x): + if self.gated: + hidden_act = self.act(self.wi_0(x)) + hidden_linear = self.wi_1(x) + x = hidden_act * hidden_linear + else: + x = self.act(self.wi(x)) + return self.wo(x) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.attention = MultiHeadAttention(config) + self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__(self, x, mask): + y = self.ln1(x) + y, _ = self.attention(y, y, y, mask=mask) + x = x + y + + y = self.ln2(x) + y = self.dense(y) + return x + y + + +class TransformerEncoder(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.layers = [ + TransformerEncoderLayer(config) for i in range(config.num_layers) + ] + self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) + + def __call__(self, x: mx.array): + pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1]) + for layer in self.layers: + x = layer(x, mask=pos_bias) + return self.ln(x) + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.self_attention = MultiHeadAttention(config) + self.cross_attention = MultiHeadAttention(config) + self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__( + self, + x: mx.array, + memory: mx.array, + mask: mx.array, + memory_mask: mx.array, + cache: Optional[List[Tuple[mx.array, mx.array]]] = None, + ): + y = self.ln1(x) + y, cache = self.self_attention(y, y, y, mask, cache) + x = x + y + + y = self.ln2(x) + y, _ = self.cross_attention(y, memory, memory, memory_mask) + x = x + y + + y = self.ln3(x) + y = self.dense(y) + x = x + y + + return x, cache + + +class TransformerDecoder(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.layers = [ + TransformerDecoderLayer(config) for i in range(config.num_layers) + ] + self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) + + def __call__(self, x, memory, mask, memory_mask, cache=None): + if cache is not None: + offset = cache[0][0].shape[3] + else: + offset = 0 + cache = [None] * len(self.layers) + + T = offset + x.shape[1] + pos_bias = self.relative_attention_bias(T, T, offset=offset) + if mask is not None: + mask += pos_bias + else: + mask = pos_bias + + for e, layer in enumerate(self.layers): + x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e]) + x = self.ln(x) + + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: T5Config): + self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) + + def __call__(self, inputs): + return self.linear(inputs) + + +class T5(nn.Module): + def __init__(self, config: T5Config): + self.wte = nn.Embedding(config.vocab_size, config.d_model) + self.encoder = TransformerEncoder(config) + self.decoder = TransformerDecoder(config) + self.tie_word_embeddings = config.tie_word_embeddings + if not self.tie_word_embeddings: + self.lm_head = OutputHead(config) + self.model_dim = config.d_model + + def encode(self, inputs: mx.array): + return self.encoder(self.wte(inputs)) + + def decode( + self, + inputs: mx.array, + memory: mx.array, + cache=None, + ): + inputs = self.wte(inputs) + T = inputs.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) + mask = mask.astype(inputs.dtype) + else: + mask = None + + y, cache = self.decoder( + inputs, memory=memory, mask=mask, memory_mask=None, cache=cache + ) + if not self.tie_word_embeddings: + y *= self.model_dim**-0.5 + y = self.lm_head(y) + else: + y = y @ self.wte.weight.T + return y, cache + + def __call__( + self, + inputs: mx.array, + decoder_inputs: mx.array, + ): + return self.decode(decoder_inputs, self.encode(inputs))[0] + + +class Tokenizer: + def __init__(self, model_name: str, config: T5Config): + self._decoder_start_id = config.decoder_start_token_id + self._tokenizer = T5Tokenizer.from_pretrained( + args.model, + legacy=False, + model_max_length=config.n_positions, + ) + + @property + def eos_id(self) -> int: + return self._tokenizer.eos_token_id + + @property + def decoder_start_id(self) -> int: + return self._decoder_start_id + + def encode(self, s: str) -> mx.array: + return mx.array( + self._tokenizer( + s, + return_tensors="np", + return_attention_mask=False, + )["input_ids"] + ) + + def decode(self, t: List[int], with_sep: bool = True) -> str: + tokens = self._tokenizer.convert_ids_to_tokens(t) + return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) + + +def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + prompt = tokenizer.encode(prompt) + decoder_inputs = mx.array([tokenizer.decoder_start_id]) + memory = model.encode(prompt) + cache = None + y = decoder_inputs + while True: + logits, cache = model.decode(y[None], memory, cache=cache) + y = sample(logits[:, -1, :]) + yield y.squeeze() + + +def load_model(model_name: str, dtype: str = "float16"): + config = T5Config.from_pretrained(args.model) + dtype = getattr(mx, dtype) + model = T5(config) + file_name = model_name.replace("/", "-") + weights = mx.load(f"{file_name}.npz") + weights = tree_unflatten(list(weights.items())) + weights = tree_map(lambda p: p.astype(dtype), weights) + model.update(weights) + mx.eval(model.parameters()) + return model, Tokenizer(args.model, config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="T5 Inference script") + parser.add_argument( + "--model", + type=str, + help="Name of the T5 model.", + default="t5-small", + ) + parser.add_argument( + "--prompt", + help="", + default="translate English to German: That is good.", + ) + parser.add_argument( + "--encode-only", + action="store_true", + default=False, + help="Whether to decode or not. If true, will output last layer of encoder.", + ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument( + "--dtype", + help="The model data type.", + type=str, + choices=["float16", "bfloat16", "float32"], + default="float32", + ) + + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model(args.model, args.dtype) + + if args.encode_only: + print("[INFO] Encoding with T5...", flush=True) + print(args.prompt, flush=True) + encoder_output = model.encode(tokenizer.encode(args.prompt)) + print(encoder_output, flush=True) + exit(0) + + print("[INFO] Generating with T5...", flush=True) + print("Input: ", args.prompt, flush=True) + + start = perf_counter_ns() + for token, n_tokens in zip( + generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens) + ): + if token.item() == tokenizer.eos_id: + break + print( + tokenizer.decode([token.item()], with_sep=n_tokens > 0), + end="", + flush=True, + ) + + n_tokens += 1 + end = perf_counter_ns() + elapsed = (end - start) / 1.0e9 + print() + print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")