fp16, abstract tokenizer a bit, format

This commit is contained in:
Awni Hannun 2023-12-18 13:15:02 -08:00
parent 72581e5c1a
commit fd351850e4
2 changed files with 66 additions and 60 deletions

View File

@ -46,11 +46,12 @@ def replace_key(key: str) -> str:
return key return key
def convert(model_name, half_precision=False): def convert(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()} weights = {
if half_precision: replace_key(k): v.numpy().astype(np.float16)
weights = {k: v.astype(np.float16) for k, v in weights.items()} for k, v in model.state_dict().items()
}
np.savez(f"{model_name}.npz", **weights) np.savez(f"{model_name}.npz", **weights)
@ -65,10 +66,5 @@ if __name__ == "__main__":
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"], choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
default="t5-small", default="t5-small",
) )
parser.add_argument(
"--half-precision",
action="store_true",
help="Convert weights to half precision (float16).",
)
args = parser.parse_args() args = parser.parse_args()
convert(args.model, args.half_precision) convert(args.model)

108
t5/t5.py
View File

@ -5,7 +5,7 @@ from time import perf_counter_ns
import numpy as np import numpy as np
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten from mlx.utils import tree_unflatten, tree_map
from transformers import T5Config, T5Tokenizer from transformers import T5Config, T5Tokenizer
@ -129,7 +129,7 @@ class MultiHeadAttention(nn.Module):
if mask is not None: if mask is not None:
scores = scores + mask.astype(scores.dtype) scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1) scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values) return self.out_proj(values_hat), (keys, values)
@ -291,8 +291,6 @@ class T5(nn.Module):
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
) )
if self.tie_word_embeddings: if self.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
y *= self.model_dim**-0.5 y *= self.model_dim**-0.5
return self.lm_head(y), cache return self.lm_head(y), cache
@ -304,16 +302,47 @@ class T5(nn.Module):
return self.decode(decoder_inputs, self.encode(inputs))[0] return self.decode(decoder_inputs, self.encode(inputs))[0]
def generate( class Tokenizer:
inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0 def __init__(self, model_name: str, config: T5Config):
): self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = T5Tokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=config.n_positions,
)
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
def sample(logits): def sample(logits):
if temp == 0: if temp == 0:
return mx.argmax(logits, axis=-1) return mx.argmax(logits, axis=-1)
else: else:
return mx.random.categorical(logits * (1 / temp)) return mx.random.categorical(logits * (1 / temp))
memory = model.encode(inputs) prompt = tokenizer.encode(prompt)
decoder_inputs = mx.array([tokenizer.decoder_start_id])
memory = model.encode(prompt)
cache = None cache = None
y = decoder_inputs y = decoder_inputs
while True: while True:
@ -322,26 +351,16 @@ def generate(
yield y.squeeze() yield y.squeeze()
def load_model(model_name: str, config: T5Config): def load_model(model_name: str, dtype: str = "float16"):
config = T5Config.from_pretrained(args.model)
dtype = getattr(mx, dtype)
model = T5(config) model = T5(config)
weights = mx.load(f"{model_name}.npz") weights = mx.load(f"{model_name}.npz")
current_weights = tree_flatten(model.parameters()) weights = tree_unflatten(list(weights.items()))
weights_to_load = list(weights.items()) weights = tree_map(lambda p: p.astype(dtype), weights)
current_weights_dict = dict(current_weights) model.update(weights)
current_weights_keys = set(current_weights_dict.keys())
weights_to_load_dict = dict(weights_to_load)
weights_to_load_keys = set(weights_to_load_dict.keys())
print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
print()
print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
for key in current_weights_keys & weights_to_load_keys:
if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
print("Shape mismatch for key: ", key)
print("Expected shape: ", current_weights_dict[key].shape)
print("Loading shape: ", weights_to_load_dict[key].shape)
model.update(tree_unflatten(weights_to_load))
mx.eval(model.parameters()) mx.eval(model.parameters())
return model return model, Tokenizer(args.model, config)
if __name__ == "__main__": if __name__ == "__main__":
@ -365,7 +384,7 @@ if __name__ == "__main__":
help="Whether to decode or not. If true, will output last layer of encoder.", help="Whether to decode or not. If true, will output last layer of encoder.",
) )
parser.add_argument( parser.add_argument(
"--max_tokens", "--max-tokens",
"-m", "-m",
type=int, type=int,
default=100, default=100,
@ -377,53 +396,44 @@ if __name__ == "__main__":
type=float, type=float,
default=0.0, default=0.0,
) )
parser.add_argument(
"--dtype",
help="The model data type.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args() args = parser.parse_args()
mx.random.seed(args.seed) mx.random.seed(args.seed)
config = T5Config.from_pretrained(args.model) model, tokenizer = load_model(args.model)
model = load_model(args.model, config)
tokenizer = T5Tokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=config.n_positions,
)
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
prompt = mx.array(prompt)
if args.encode_only: if args.encode_only:
print("[INFO] Encoding with T5...", flush=True) print("[INFO] Encoding with T5...", flush=True)
print(args.prompt, flush=True) print(args.prompt, flush=True)
encoder_output = model.encode(prompt) encoder_output = model.encode(tokenizer.encode(args.prompt))
print(encoder_output, flush=True) print(encoder_output, flush=True)
exit(0) exit(0)
print("[INFO] Generating with T5...", flush=True) print("[INFO] Generating with T5...", flush=True)
print("Input: ", args.prompt, flush=True) print("Input: ", args.prompt, flush=True)
decoder_inputs = mx.array([config.decoder_start_token_id])
start = perf_counter_ns() start = perf_counter_ns()
tokens = []
for token, n_tokens in zip( for token, n_tokens in zip(
generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens) generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens)
): ):
if token.item() == tokenizer.eos_token_id: if token.item() == tokenizer.eos_id:
break break
print( print(
tokenizer.convert_ids_to_tokens(token.item()).replace("", " "), tokenizer.decode([token.item()], with_sep=n_tokens > 0),
end="", end="",
flush=True, flush=True,
) )
n_tokens += 1
end = perf_counter_ns() end = perf_counter_ns()
elapsed = (end - start) / 1.0e9 elapsed = (end - start) / 1.0e9
print() print()