From fd351850e46d49b851aef5a576995766d0efdf22 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 18 Dec 2023 13:15:02 -0800
Subject: [PATCH] fp16, abstract tokenizer a bit, format

---
 t5/convert.py |  16 +++-----
 t5/t5.py      | 110 +++++++++++++++++++++++++++-----------------------
 2 files changed, 66 insertions(+), 60 deletions(-)

diff --git a/t5/convert.py b/t5/convert.py
index 54c7b76b..77d5cfd9 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -46,11 +46,12 @@ def replace_key(key: str) -> str:
     return key
 
 
-def convert(model_name, half_precision=False):
+def convert(model_name):
     model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
-    weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
-    if half_precision:
-        weights = {k: v.astype(np.float16) for k, v in weights.items()}
+    weights = {
+        replace_key(k): v.numpy().astype(np.float16)
+        for k, v in model.state_dict().items()
+    }
     np.savez(f"{model_name}.npz", **weights)
 
 
@@ -65,10 +66,5 @@ if __name__ == "__main__":
         choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
         default="t5-small",
     )
-    parser.add_argument(
-        "--half-precision",
-        action="store_true",
-        help="Convert weights to half precision (float16).",
-    )
     args = parser.parse_args()
-    convert(args.model, args.half_precision)
+    convert(args.model)
diff --git a/t5/t5.py b/t5/t5.py
index cd884c48..2d736fbe 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -5,7 +5,7 @@
 from time import perf_counter_ns
 
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_flatten, tree_unflatten
+from mlx.utils import tree_unflatten, tree_map
 from transformers import T5Config, T5Tokenizer
@@ -129,7 +129,7 @@ class MultiHeadAttention(nn.Module):
 
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
 
-        scores = mx.softmax(scores, axis=-1)
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.out_proj(values_hat), (keys, values)
@@ -291,9 +291,7 @@ class T5(nn.Module):
             inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
         )
         if self.tie_word_embeddings:
-            # Rescale output before projecting on vocab
-            # See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
-            y *= self.model_dim ** -0.5
+            y *= self.model_dim**-0.5
         return self.lm_head(y), cache
 
     def __call__(
@@ -304,16 +302,47 @@ class T5(nn.Module):
         return self.decode(decoder_inputs, self.encode(inputs))[0]
 
 
-def generate(
-    inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
-):
+class Tokenizer:
+    def __init__(self, model_name: str, config: T5Config):
+        self._decoder_start_id = config.decoder_start_token_id
+        self._tokenizer = T5Tokenizer.from_pretrained(
+            model_name,
+            legacy=False,
+            model_max_length=config.n_positions,
+        )
+
+    @property
+    def eos_id(self) -> int:
+        return self._tokenizer.eos_token_id
+
+    @property
+    def decoder_start_id(self) -> int:
+        return self._decoder_start_id
+
+    def encode(self, s: str) -> mx.array:
+        return mx.array(
+            self._tokenizer(
+                s,
+                return_tensors="np",
+                return_attention_mask=False,
+            )["input_ids"]
+        )
+
+    def decode(self, t: List[int], with_sep: bool = True) -> str:
+        tokens = self._tokenizer.convert_ids_to_tokens(t)
+        return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
+
+
+def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
     def sample(logits):
         if temp == 0:
             return mx.argmax(logits, axis=-1)
         else:
             return mx.random.categorical(logits * (1 / temp))
 
-    memory = model.encode(inputs)
+    prompt = tokenizer.encode(prompt)
+    decoder_inputs = mx.array([tokenizer.decoder_start_id])
+    memory = model.encode(prompt)
     cache = None
     y = decoder_inputs
     while True:
@@ -322,26 +351,16 @@ def generate(
         yield y.squeeze()
 
 
-def load_model(model_name: str, config: T5Config):
+def load_model(model_name: str, dtype: str = "float16"):
+    config = T5Config.from_pretrained(model_name)
+    dtype = getattr(mx, dtype)
     model = T5(config)
     weights = mx.load(f"{model_name}.npz")
-    current_weights = tree_flatten(model.parameters())
-    weights_to_load = list(weights.items())
-    current_weights_dict = dict(current_weights)
-    current_weights_keys = set(current_weights_dict.keys())
-    weights_to_load_dict = dict(weights_to_load)
-    weights_to_load_keys = set(weights_to_load_dict.keys())
-    print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
-    print()
-    print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
-    for key in current_weights_keys & weights_to_load_keys:
-        if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
-            print("Shape mismatch for key: ", key)
-            print("Expected shape: ", current_weights_dict[key].shape)
-            print("Loading shape: ", weights_to_load_dict[key].shape)
-    model.update(tree_unflatten(weights_to_load))
+    weights = tree_unflatten(list(weights.items()))
+    weights = tree_map(lambda p: p.astype(dtype), weights)
+    model.update(weights)
     mx.eval(model.parameters())
-    return model
+    return model, Tokenizer(model_name, config)
 
 
 if __name__ == "__main__":
@@ -365,7 +384,7 @@ if __name__ == "__main__":
         help="Whether to decode or not. If true, will output last layer of encoder.",
     )
     parser.add_argument(
-        "--max_tokens",
+        "--max-tokens",
         "-m",
         type=int,
         default=100,
@@ -377,53 +396,44 @@ if __name__ == "__main__":
         type=float,
         default=0.0,
     )
+    parser.add_argument(
+        "--dtype",
+        help="The model data type.",
+        type=str,
+        choices=["float16", "bfloat16", "float32"],
+        default="float16",
+    )
+
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
     args = parser.parse_args()
 
     mx.random.seed(args.seed)
 
-    config = T5Config.from_pretrained(args.model)
-    model = load_model(args.model, config)
-    tokenizer = T5Tokenizer.from_pretrained(
-        args.model,
-        legacy=False,
-        model_max_length=config.n_positions,
-    )
-
-    prompt = tokenizer(
-        args.prompt,
-        return_tensors="np",
-        return_attention_mask=False,
-    )["input_ids"]
-
-    prompt = mx.array(prompt)
-
+    model, tokenizer = load_model(args.model, args.dtype)
 
     if args.encode_only:
         print("[INFO] Encoding with T5...", flush=True)
         print(args.prompt, flush=True)
-        encoder_output = model.encode(prompt)
+        encoder_output = model.encode(tokenizer.encode(args.prompt))
         print(encoder_output, flush=True)
         exit(0)
 
     print("[INFO] Generating with T5...", flush=True)
     print("Input: ", args.prompt, flush=True)
 
-    decoder_inputs = mx.array([config.decoder_start_token_id])
-
     start = perf_counter_ns()
 
-    tokens = []
     for token, n_tokens in zip(
-        generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
+        generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens)
     ):
-        if token.item() == tokenizer.eos_token_id:
+        if token.item() == tokenizer.eos_id:
             break
         print(
-            tokenizer.convert_ids_to_tokens(token.item()).replace("▁", " "),
+            tokenizer.decode([token.item()], with_sep=n_tokens > 0),
            end="",
            flush=True,
        )
+    n_tokens += 1
     end = perf_counter_ns()
     elapsed = (end - start) / 1.0e9
     print()
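
Note on the new weight loading: load_model now rebuilds the nested parameter tree from the flat .npz dict with tree_unflatten and casts every leaf array with tree_map, instead of the old key-by-key diagnostics. A minimal sketch of that pattern in isolation, assuming an mlx version that exports tree_map from mlx.utils and a t5-small.npz already produced by convert.py:

    import mlx.core as mx
    from mlx.utils import tree_map, tree_unflatten

    # Load the flat {dotted.name: array} dict saved by convert.py.
    weights = mx.load("t5-small.npz")

    # Rebuild the nested parameter tree the model's update() expects.
    tree = tree_unflatten(list(weights.items()))

    # Cast every leaf to the target dtype (float16 here).
    tree = tree_map(lambda p: p.astype(mx.float16), tree)

    # model.update(tree); mx.eval(model.parameters()) would then apply it.

The float32 upcast around mx.softmax in MultiHeadAttention is the usual companion to half-precision weights: attention probabilities are computed in full precision so large score magnitudes don't overflow fp16, then cast back to the working dtype.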
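With the Tokenizer abstraction, callers go through its encode/decode/eos_id interface rather than a raw T5Tokenizer. A sketch of driving the refactored API directly, assuming the working directory is t5/ (so t5.py is importable) and the prompt string is just an illustration:

    # `python convert.py --model t5-small` must have written t5-small.npz first.
    from t5 import generate, load_model

    model, tokenizer = load_model("t5-small", "float16")

    prompt = "translate English to German: Hello, how are you?"
    for token, _ in zip(generate(prompt, model, tokenizer, temp=0.0), range(100)):
        if token.item() == tokenizer.eos_id:
            break
        # decode() maps the SentencePiece "▁" word-boundary marker back to spaces.
        print(tokenizer.decode([token.item()]), end="", flush=True)
    print()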