Stream output

This commit is contained in:
Juarez Bochi 2023-12-18 08:09:56 -05:00
parent 689eda9937
commit 09e851499a
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -1,14 +1,13 @@
import argparse
from dataclasses import dataclass
from typing import Optional, Tuple, List
from typing import Optional
from time import perf_counter_ns
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
from transformers import AutoTokenizer
from transformers import T5Tokenizer
@dataclass
@ -354,7 +353,7 @@ def load_model(model_config):
print("Loading shape: ", weights_to_load_dict[key].shape)
model.update(tree_unflatten(weights_to_load))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True)
return model, tokenizer
@ -421,10 +420,11 @@ if __name__ == "__main__":
):
if token.item() == tokenizer.eos_token_id:
break
tokens.append(token.item())
# For some reason using the following line doesn't give spaces
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
print(tokenizer.decode(tokens), end="", flush=True)
print(
tokenizer.convert_ids_to_tokens(token.item()).replace("", " "),
end="",
flush=True,
)
end = perf_counter_ns()
elapsed = (end - start) / 1.0e9