Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-30 02:53:41 +08:00
Stream output

commit 09e851499a
parent 689eda9937
t5/t5.py: 14 changed lines
@@ -1,14 +1,13 @@
 import argparse
 from dataclasses import dataclass
 from typing import Optional, Tuple, List
-from typing import Optional
 from time import perf_counter_ns
 
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_unflatten
-from transformers import AutoTokenizer
+from transformers import T5Tokenizer
 
 
 @dataclass
@@ -354,7 +353,7 @@ def load_model(model_config):
         print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
     mx.eval(model.parameters())
-    tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
+    tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True)
     return model, tokenizer
 
 
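Note: the hunk above swaps AutoTokenizer for the SentencePiece-based T5Tokenizer when loading "t5-small". As a rough standalone sketch of how such a tokenizer gets used (not part of this commit; it assumes transformers and sentencepiece are installed, and the prompt string is only an example):

# hypothetical usage sketch, outside the repo code
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
# encode a prompt to token ids; in an MLX example these would typically be
# wrapped in mx.array(...) before being fed to the model
ids = tokenizer("translate English to German: Hello.").input_ids
print(ids)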
@ -421,10 +420,11 @@ if __name__ == "__main__":
|
|||||||
):
|
):
|
||||||
if token.item() == tokenizer.eos_token_id:
|
if token.item() == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
tokens.append(token.item())
|
print(
|
||||||
# For some reason using the following line doesn't give spaces
|
tokenizer.convert_ids_to_tokens(token.item()).replace("▁", " "),
|
||||||
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
|
end="",
|
||||||
print(tokenizer.decode(tokens), end="", flush=True)
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
end = perf_counter_ns()
|
end = perf_counter_ns()
|
||||||
elapsed = (end - start) / 1.0e9
|
elapsed = (end - start) / 1.0e9
|
||||||
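Note: the new loop streams output token by token. Decoding ids one at a time loses the spaces between words (the comment in the removed lines points this out), so the new code prints the raw SentencePiece piece from convert_ids_to_tokens() and replaces the "▁" word-boundary marker with a space. A standalone sketch comparing the two behaviours (not from the repo; it assumes transformers and sentencepiece are installed and uses "Hello world" purely as an example string):

# hypothetical comparison sketch
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
ids = tokenizer("Hello world", add_special_tokens=False).input_ids

# per-token streaming, as in the new code path: print each piece as it arrives
for tid in ids:
    print(tokenizer.convert_ids_to_tokens(tid).replace("▁", " "), end="", flush=True)
print()

# decoding an accumulated list at once, as the removed code path did
print(tokenizer.decode(ids))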