mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fp16, abstract tokenizer a bit, format
This commit is contained in:
parent
72581e5c1a
commit
fd351850e4
@ -46,11 +46,12 @@ def replace_key(key: str) -> str:
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
def convert(model_name, half_precision=False):
|
def convert(model_name):
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||||
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
weights = {
|
||||||
if half_precision:
|
replace_key(k): v.numpy().astype(np.float16)
|
||||||
weights = {k: v.astype(np.float16) for k, v in weights.items()}
|
for k, v in model.state_dict().items()
|
||||||
|
}
|
||||||
np.savez(f"{model_name}.npz", **weights)
|
np.savez(f"{model_name}.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
@ -65,10 +66,5 @@ if __name__ == "__main__":
|
|||||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||||
default="t5-small",
|
default="t5-small",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--half-precision",
|
|
||||||
action="store_true",
|
|
||||||
help="Convert weights to half precision (float16).",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert(args.model, args.half_precision)
|
convert(args.model)
|
||||||
|
108
t5/t5.py
108
t5/t5.py
@ -5,7 +5,7 @@ from time import perf_counter_ns
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_unflatten, tree_map
|
||||||
from transformers import T5Config, T5Tokenizer
|
from transformers import T5Config, T5Tokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask.astype(scores.dtype)
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
scores = mx.softmax(scores, axis=-1)
|
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.out_proj(values_hat), (keys, values)
|
return self.out_proj(values_hat), (keys, values)
|
||||||
|
|
||||||
@ -291,8 +291,6 @@ class T5(nn.Module):
|
|||||||
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||||
)
|
)
|
||||||
if self.tie_word_embeddings:
|
if self.tie_word_embeddings:
|
||||||
# Rescale output before projecting on vocab
|
|
||||||
# See https://github.com/huggingface/transformers/blob/71d47f0ad498b7649f11d3a9cca3cd3585e4341f/src/transformers/models/t5/modeling_t5.py#L1766C9-L1769C71
|
|
||||||
y *= self.model_dim**-0.5
|
y *= self.model_dim**-0.5
|
||||||
return self.lm_head(y), cache
|
return self.lm_head(y), cache
|
||||||
|
|
||||||
@ -304,16 +302,47 @@ class T5(nn.Module):
|
|||||||
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
class Tokenizer:
|
||||||
inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
|
def __init__(self, model_name: str, config: T5Config):
|
||||||
):
|
self._decoder_start_id = config.decoder_start_token_id
|
||||||
|
self._tokenizer = T5Tokenizer.from_pretrained(
|
||||||
|
args.model,
|
||||||
|
legacy=False,
|
||||||
|
model_max_length=config.n_positions,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_id(self) -> int:
|
||||||
|
return self._tokenizer.eos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decoder_start_id(self) -> int:
|
||||||
|
return self._decoder_start_id
|
||||||
|
|
||||||
|
def encode(self, s: str) -> mx.array:
|
||||||
|
return mx.array(
|
||||||
|
self._tokenizer(
|
||||||
|
s,
|
||||||
|
return_tensors="np",
|
||||||
|
return_attention_mask=False,
|
||||||
|
)["input_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||||
|
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||||
|
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
|
||||||
def sample(logits):
|
def sample(logits):
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
return mx.argmax(logits, axis=-1)
|
return mx.argmax(logits, axis=-1)
|
||||||
else:
|
else:
|
||||||
return mx.random.categorical(logits * (1 / temp))
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
memory = model.encode(inputs)
|
prompt = tokenizer.encode(prompt)
|
||||||
|
decoder_inputs = mx.array([tokenizer.decoder_start_id])
|
||||||
|
memory = model.encode(prompt)
|
||||||
cache = None
|
cache = None
|
||||||
y = decoder_inputs
|
y = decoder_inputs
|
||||||
while True:
|
while True:
|
||||||
@ -322,26 +351,16 @@ def generate(
|
|||||||
yield y.squeeze()
|
yield y.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name: str, config: T5Config):
|
def load_model(model_name: str, dtype: str = "float16"):
|
||||||
|
config = T5Config.from_pretrained(args.model)
|
||||||
|
dtype = getattr(mx, dtype)
|
||||||
model = T5(config)
|
model = T5(config)
|
||||||
weights = mx.load(f"{model_name}.npz")
|
weights = mx.load(f"{model_name}.npz")
|
||||||
current_weights = tree_flatten(model.parameters())
|
weights = tree_unflatten(list(weights.items()))
|
||||||
weights_to_load = list(weights.items())
|
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||||
current_weights_dict = dict(current_weights)
|
model.update(weights)
|
||||||
current_weights_keys = set(current_weights_dict.keys())
|
|
||||||
weights_to_load_dict = dict(weights_to_load)
|
|
||||||
weights_to_load_keys = set(weights_to_load_dict.keys())
|
|
||||||
print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
|
|
||||||
print()
|
|
||||||
print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
|
|
||||||
for key in current_weights_keys & weights_to_load_keys:
|
|
||||||
if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
|
|
||||||
print("Shape mismatch for key: ", key)
|
|
||||||
print("Expected shape: ", current_weights_dict[key].shape)
|
|
||||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
|
||||||
model.update(tree_unflatten(weights_to_load))
|
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
return model
|
return model, Tokenizer(args.model, config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -365,7 +384,7 @@ if __name__ == "__main__":
|
|||||||
help="Whether to decode or not. If true, will output last layer of encoder.",
|
help="Whether to decode or not. If true, will output last layer of encoder.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_tokens",
|
"--max-tokens",
|
||||||
"-m",
|
"-m",
|
||||||
type=int,
|
type=int,
|
||||||
default=100,
|
default=100,
|
||||||
@ -377,53 +396,44 @@ if __name__ == "__main__":
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
help="The model data type.",
|
||||||
|
type=str,
|
||||||
|
choices=["float16", "bfloat16", "float32"],
|
||||||
|
default="float16",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
config = T5Config.from_pretrained(args.model)
|
model, tokenizer = load_model(args.model)
|
||||||
model = load_model(args.model, config)
|
|
||||||
tokenizer = T5Tokenizer.from_pretrained(
|
|
||||||
args.model,
|
|
||||||
legacy=False,
|
|
||||||
model_max_length=config.n_positions,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = tokenizer(
|
|
||||||
args.prompt,
|
|
||||||
return_tensors="np",
|
|
||||||
return_attention_mask=False,
|
|
||||||
)["input_ids"]
|
|
||||||
|
|
||||||
prompt = mx.array(prompt)
|
|
||||||
|
|
||||||
if args.encode_only:
|
if args.encode_only:
|
||||||
print("[INFO] Encoding with T5...", flush=True)
|
print("[INFO] Encoding with T5...", flush=True)
|
||||||
print(args.prompt, flush=True)
|
print(args.prompt, flush=True)
|
||||||
encoder_output = model.encode(prompt)
|
encoder_output = model.encode(tokenizer.encode(args.prompt))
|
||||||
print(encoder_output, flush=True)
|
print(encoder_output, flush=True)
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
print("[INFO] Generating with T5...", flush=True)
|
print("[INFO] Generating with T5...", flush=True)
|
||||||
print("Input: ", args.prompt, flush=True)
|
print("Input: ", args.prompt, flush=True)
|
||||||
|
|
||||||
decoder_inputs = mx.array([config.decoder_start_token_id])
|
|
||||||
|
|
||||||
start = perf_counter_ns()
|
start = perf_counter_ns()
|
||||||
|
|
||||||
tokens = []
|
|
||||||
for token, n_tokens in zip(
|
for token, n_tokens in zip(
|
||||||
generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
|
generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens)
|
||||||
):
|
):
|
||||||
if token.item() == tokenizer.eos_token_id:
|
if token.item() == tokenizer.eos_id:
|
||||||
break
|
break
|
||||||
print(
|
print(
|
||||||
tokenizer.convert_ids_to_tokens(token.item()).replace("▁", " "),
|
tokenizer.decode([token.item()], with_sep=n_tokens > 0),
|
||||||
end="",
|
end="",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
n_tokens += 1
|
||||||
end = perf_counter_ns()
|
end = perf_counter_ns()
|
||||||
elapsed = (end - start) / 1.0e9
|
elapsed = (end - start) / 1.0e9
|
||||||
print()
|
print()
|
||||||
|
Loading…
Reference in New Issue
Block a user