Stream output

This commit is contained in:
Juarez Bochi 2023-12-18 08:09:56 -05:00
parent 689eda9937
commit 09e851499a
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@@ -1,14 +1,13 @@
import argparse import argparse
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, List from typing import Optional, Tuple, List
from typing import Optional
from time import perf_counter_ns from time import perf_counter_ns
import numpy as np import numpy as np
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten from mlx.utils import tree_flatten, tree_unflatten
from transformers import AutoTokenizer from transformers import T5Tokenizer
@dataclass @dataclass
@@ -354,7 +353,7 @@ def load_model(model_config):
print("Loading shape: ", weights_to_load_dict[key].shape) print("Loading shape: ", weights_to_load_dict[key].shape)
model.update(tree_unflatten(weights_to_load)) model.update(tree_unflatten(weights_to_load))
mx.eval(model.parameters()) mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True) tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True)
return model, tokenizer return model, tokenizer
@@ -421,10 +420,11 @@ if __name__ == "__main__":
): ):
if token.item() == tokenizer.eos_token_id: if token.item() == tokenizer.eos_token_id:
break break
tokens.append(token.item()) print(
# For some reason using the following line doesn't give spaces tokenizer.convert_ids_to_tokens(token.item()).replace("▁", " "),
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True) end="",
print(tokenizer.decode(tokens), end="", flush=True) flush=True,
)
end = perf_counter_ns() end = perf_counter_ns()
elapsed = (end - start) / 1.0e9 elapsed = (end - start) / 1.0e9