diff --git a/t5/t5.py b/t5/t5.py index c137e0c1..8fd2936a 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -1,14 +1,13 @@ import argparse from dataclasses import dataclass from typing import Optional, Tuple, List -from typing import Optional from time import perf_counter_ns import numpy as np import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_flatten, tree_unflatten -from transformers import AutoTokenizer +from transformers import T5Tokenizer @dataclass @@ -354,7 +353,7 @@ def load_model(model_config): print("Loading shape: ", weights_to_load_dict[key].shape) model.update(tree_unflatten(weights_to_load)) mx.eval(model.parameters()) - tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True) + tokenizer = T5Tokenizer.from_pretrained("t5-small", trust_remote_code=True) return model, tokenizer @@ -421,10 +420,11 @@ if __name__ == "__main__": ): if token.item() == tokenizer.eos_token_id: break - tokens.append(token.item()) - # For some reason using the following line doesn't give spaces - # print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True) - print(tokenizer.decode(tokens), end="", flush=True) + print( + tokenizer.convert_ids_to_tokens(token.item()).replace("▁", " "), + end="", + flush=True, + ) end = perf_counter_ns() elapsed = (end - start) / 1.0e9