mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Stream output
This commit is contained in:
parent
689eda9937
commit
09e851499a
14
t5/t5.py
14
t5/t5.py
@ -1,14 +1,13 @@
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import Optional
|
||||
from time import perf_counter_ns
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import T5Tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -354,7 +353,7 @@ def load_model(model_config):
|
||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||
model.update(tree_unflatten(weights_to_load))
|
||||
mx.eval(model.parameters())
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@ -421,10 +420,11 @@ if __name__ == "__main__":
|
||||
):
|
||||
if token.item() == tokenizer.eos_token_id:
|
||||
break
|
||||
tokens.append(token.item())
|
||||
# For some reason using the following line doesn't give spaces
|
||||
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
|
||||
print(tokenizer.decode(tokens), end="", flush=True)
|
||||
print(
|
||||
tokenizer.convert_ids_to_tokens(token.item()).replace("▁", " "),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
end = perf_counter_ns()
|
||||
elapsed = (end - start) / 1.0e9
|
||||
|
Loading…
Reference in New Issue
Block a user