mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
fp16, abstract tokenizer a bit, format
This commit is contained in:
parent
72581e5c1a
commit
fd351850e4
@ -46,11 +46,12 @@ def replace_key(key: str) -> str:
|
||||
return key
|
||||
|
||||
|
||||
def convert(model_name, half_precision=False):
|
||||
def convert(model_name):
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
||||
if half_precision:
|
||||
weights = {k: v.astype(np.float16) for k, v in weights.items()}
|
||||
weights = {
|
||||
replace_key(k): v.numpy().astype(np.float16)
|
||||
for k, v in model.state_dict().items()
|
||||
}
|
||||
np.savez(f"{model_name}.npz", **weights)
|
||||
|
||||
|
||||
@ -65,10 +66,5 @@ if __name__ == "__main__":
|
||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||
default="t5-small",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--half-precision",
|
||||
action="store_true",
|
||||
help="Convert weights to half precision (float16).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(args.model, args.half_precision)
|
||||
convert(args.model)
|
||||
|
110
t5/t5.py
110
t5/t5.py
@ -5,7 +5,7 @@ from time import perf_counter_ns
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
from mlx.utils import tree_unflatten, tree_map
|
||||
from transformers import T5Config, T5Tokenizer
|
||||
|
||||
|
||||
@ -129,7 +129,7 @@ class MultiHeadAttention(nn.Module):
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
@ -291,9 +291,7 @@ class T5(nn.Module):
|
||||
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||
)
|
||||
if self.tie_word_embeddings:
|
||||
# Rescale output before projecting on vocab
|
||||
# See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
|
||||
y *= self.model_dim ** -0.5
|
||||
y *= self.model_dim**-0.5
|
||||
return self.lm_head(y), cache
|
||||
|
||||
def __call__(
|
||||
@ -304,16 +302,47 @@ class T5(nn.Module):
|
||||
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
||||
|
||||
|
||||
def generate(
|
||||
inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
|
||||
):
|
||||
class Tokenizer:
|
||||
def __init__(self, model_name: str, config: T5Config):
|
||||
self._decoder_start_id = config.decoder_start_token_id
|
||||
self._tokenizer = T5Tokenizer.from_pretrained(
|
||||
args.model,
|
||||
legacy=False,
|
||||
model_max_length=config.n_positions,
|
||||
)
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
def decoder_start_id(self) -> int:
|
||||
return self._decoder_start_id
|
||||
|
||||
def encode(self, s: str) -> mx.array:
|
||||
return mx.array(
|
||||
self._tokenizer(
|
||||
s,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)["input_ids"]
|
||||
)
|
||||
|
||||
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||
|
||||
|
||||
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
memory = model.encode(inputs)
|
||||
prompt = tokenizer.encode(prompt)
|
||||
decoder_inputs = mx.array([tokenizer.decoder_start_id])
|
||||
memory = model.encode(prompt)
|
||||
cache = None
|
||||
y = decoder_inputs
|
||||
while True:
|
||||
@ -322,26 +351,16 @@ def generate(
|
||||
yield y.squeeze()
|
||||
|
||||
|
||||
def load_model(model_name: str, config: T5Config):
|
||||
def load_model(model_name: str, dtype: str = "float16"):
|
||||
config = T5Config.from_pretrained(args.model)
|
||||
dtype = getattr(mx, dtype)
|
||||
model = T5(config)
|
||||
weights = mx.load(f"{model_name}.npz")
|
||||
current_weights = tree_flatten(model.parameters())
|
||||
weights_to_load = list(weights.items())
|
||||
current_weights_dict = dict(current_weights)
|
||||
current_weights_keys = set(current_weights_dict.keys())
|
||||
weights_to_load_dict = dict(weights_to_load)
|
||||
weights_to_load_keys = set(weights_to_load_dict.keys())
|
||||
print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
|
||||
print()
|
||||
print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
|
||||
for key in current_weights_keys & weights_to_load_keys:
|
||||
if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
|
||||
print("Shape mismatch for key: ", key)
|
||||
print("Expected shape: ", current_weights_dict[key].shape)
|
||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||
model.update(tree_unflatten(weights_to_load))
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model.update(weights)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
return model, Tokenizer(args.model, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -365,7 +384,7 @@ if __name__ == "__main__":
|
||||
help="Whether to decode or not. If true, will output last layer of encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
@ -377,53 +396,44 @@ if __name__ == "__main__":
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="The model data type.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
config = T5Config.from_pretrained(args.model)
|
||||
model = load_model(args.model, config)
|
||||
tokenizer = T5Tokenizer.from_pretrained(
|
||||
args.model,
|
||||
legacy=False,
|
||||
model_max_length=config.n_positions,
|
||||
)
|
||||
|
||||
prompt = tokenizer(
|
||||
args.prompt,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)["input_ids"]
|
||||
|
||||
prompt = mx.array(prompt)
|
||||
model, tokenizer = load_model(args.model)
|
||||
|
||||
if args.encode_only:
|
||||
print("[INFO] Encoding with T5...", flush=True)
|
||||
print(args.prompt, flush=True)
|
||||
encoder_output = model.encode(prompt)
|
||||
encoder_output = model.encode(tokenizer.encode(args.prompt))
|
||||
print(encoder_output, flush=True)
|
||||
exit(0)
|
||||
|
||||
print("[INFO] Generating with T5...", flush=True)
|
||||
print("Input: ", args.prompt, flush=True)
|
||||
|
||||
decoder_inputs = mx.array([config.decoder_start_token_id])
|
||||
|
||||
start = perf_counter_ns()
|
||||
|
||||
tokens = []
|
||||
for token, n_tokens in zip(
|
||||
generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
|
||||
generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens)
|
||||
):
|
||||
if token.item() == tokenizer.eos_token_id:
|
||||
if token.item() == tokenizer.eos_id:
|
||||
break
|
||||
print(
|
||||
tokenizer.convert_ids_to_tokens(token.item()).replace("▁", " "),
|
||||
tokenizer.decode([token.item()], with_sep=n_tokens > 0),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
n_tokens += 1
|
||||
end = perf_counter_ns()
|
||||
elapsed = (end - start) / 1.0e9
|
||||
print()
|
||||
|
Loading…
Reference in New Issue
Block a user