Support for multiple EOS tokens (#1141)

* Support for multiple EOS tokens

* Change _eos_token_ids type from list to set

* Remove model_config & add eos_token_id

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
madroid 2024-12-10 00:53:58 +08:00 committed by GitHub
parent 5687d5b99b
commit 12083c4b7e
3 changed files with 34 additions and 17 deletions


@@ -254,21 +254,33 @@ class TokenizerWrapper:
     huggingface tokenizer.
     """
 
-    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
+    def __init__(
+        self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
+    ):
         self._tokenizer = tokenizer
         self._detokenizer = detokenizer_class(tokenizer)
+        self._eos_token_ids = (
+            set(eos_token_ids)
+            if eos_token_ids is not None
+            else {tokenizer.eos_token_id}
+        )
 
     def __getattr__(self, attr):
         if attr == "detokenizer":
             return self._detokenizer
+        elif attr == "eos_token_ids":
+            return self._eos_token_ids
         elif attr.startswith("_"):
             return self.__getattribute__(attr)
         else:
             return getattr(self._tokenizer, attr)
 
     def __setattr__(self, attr, value):
-        if attr == "detokenizer":
-            raise AttributeError("Cannot set the detokenizer.")
+        if attr in {"detokenizer", "eos_token_ids"}:
+            if attr == "detokenizer":
+                raise AttributeError("Cannot set the detokenizer.")
+            elif attr == "eos_token_ids":
+                self._eos_token_ids = set(value) if value is not None else set()
         elif attr.startswith("_"):
             super().__setattr__(attr, value)
         else:
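For illustration, a minimal sketch of the new wrapper behavior. `DummyTokenizer` is a hypothetical stand-in for a Hugging Face tokenizer; it supplies just the no-op `decode` that the default `NaiveStreamingDetokenizer` needs in this sketch:

    # Hypothetical stand-in for a Hugging Face tokenizer; not part of this diff.
    class DummyTokenizer:
        eos_token_id = 2

        def decode(self, ids):
            return ""

    # Default: fall back to the tokenizer's single EOS id.
    tok = TokenizerWrapper(DummyTokenizer())
    assert tok.eos_token_ids == {2}

    # Multiple EOS ids are stored as a set, so the per-token stop check is O(1).
    tok = TokenizerWrapper(DummyTokenizer(), eos_token_ids=[2, 128008, 128009])
    assert 128009 in tok.eos_token_ids

    # Reassignment goes through __setattr__, which normalizes any iterable to a set.
    tok.eos_token_ids = (2,)
    assert tok.eos_token_ids == {2}

Keeping the ids in a set also deduplicates repeated entries from the config.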
@@ -315,7 +327,7 @@ def _is_bpe_decoder(decoder):
     return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
 
 
-def load_tokenizer(model_path, tokenizer_config_extra={}):
+def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
     """Load a huggingface tokenizer and try to infer the type of streaming
     detokenizer to use.
 
@@ -336,7 +348,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
         elif _is_bpe_decoder(tokenizer_content["decoder"]):
             detokenizer_class = BPEStreamingDetokenizer
 
+    if isinstance(eos_token_ids, int):
+        eos_token_ids = [eos_token_ids]
+
     return TokenizerWrapper(
         AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
         detokenizer_class,
+        eos_token_ids=eos_token_ids,
     )
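Callers may also pass `eos_token_ids` directly to `load_tokenizer`; a bare int is wrapped into a one-element list before the wrapper turns it into a set. A usage sketch, assuming the usual `mlx_lm` import paths (the repo id and token ids below are only examples):

    from mlx_lm.utils import get_model_path, load_tokenizer

    model_path = get_model_path("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

    # Explicit list of stop ids (Llama-3.1-style example values).
    tokenizer = load_tokenizer(model_path, eos_token_ids=[128001, 128008, 128009])
    print(tokenizer.eos_token_ids)  # {128001, 128008, 128009}

    # A bare int is accepted and normalized.
    tokenizer = load_tokenizer(model_path, eos_token_ids=128001)
    print(tokenizer.eos_token_ids)  # {128001}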


@@ -361,7 +361,7 @@ def stream_generate(
         prompt_time = time.perf_counter() - tic
         prompt_tps = prompt.size / prompt_time
         tic = time.perf_counter()
-        if token == tokenizer.eos_token_id:
+        if token in tokenizer.eos_token_ids:
             break
 
         detokenizer.add_token(token)
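The only change here is the stop condition: equality against a single `eos_token_id` becomes set membership, so models that declare several terminators stop on whichever id appears first. Schematically (token source and prompt handling elided):

    # Sketch of the loop shape only; generate_step() and its arguments are elided.
    for token, _ in generate_step(prompt, model):
        if token in tokenizer.eos_token_ids:  # any configured EOS id ends generation
            break
        detokenizer.add_token(token)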
@@ -467,11 +467,11 @@ def load_model(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
-        model_config (dict, optional): Configuration parameters for the model.
-            Defaults to an empty dictionary.
+        model_config (dict, optional): Optional configuration parameters for the
+            model. Defaults to an empty dictionary.
         get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
             A function that returns the model class and model args class given a config.
-            Defaults to the _get_classes function.
+            Defaults to the ``_get_classes`` function.
 
     Returns:
         nn.Module: The loaded and initialized model.
@@ -480,7 +480,6 @@ def load_model(
         FileNotFoundError: If the weight files (.safetensors) are not found.
         ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-
     config = load_config(model_path)
     config.update(model_config)
@@ -530,7 +529,7 @@ def load_model(
         mx.eval(model.parameters())
 
     model.eval()
-    return model
+    return model, config
 
 
 def load(
@@ -563,11 +562,13 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model = load_model(model_path, lazy, model_config)
+    model, config = load_model(model_path, lazy)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
 
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = load_tokenizer(
+        model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
+    )
 
     return model, tokenizer
@@ -575,9 +576,10 @@
 
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model = load_model(model_path, lazy)
-    config = load_config(model_path)
-    tokenizer = load_tokenizer(model_path)
+    model, config = load_model(model_path, lazy)
+    tokenizer = load_tokenizer(
+        model_path, eos_token_ids=config.get("eos_token_id", None)
+    )
 
     return model, config, tokenizer
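Taken together, `load` now threads the config's `eos_token_id` (an int or a list, depending on the model) into the tokenizer with no caller involvement. A hedged end-to-end sketch; the repo id is only an example:

    from mlx_lm import generate, load

    model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
    print(tokenizer.eos_token_ids)  # e.g. {128001, 128008, 128009}

    # generate()/stream_generate() now stop on any of the ids above.
    text = generate(model, tokenizer, prompt="Write a haiku about sets.", max_tokens=64)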


@@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
             return CustomQwenModel, CustomQwenConfig
 
         model_path = get_model_path(HF_MODEL_PATH)
-        model = load_model(model_path, get_model_classes=custom_get_classes)
+        model, _ = load_model(model_path, get_model_classes=custom_get_classes)
 
         self.assertIsInstance(model, CustomQwenModel)
         self.assertTrue(hasattr(model, "custom_attribute"))
@@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
 
     def test_load_model_with_default_get_classes(self):
         model_path = get_model_path(HF_MODEL_PATH)
-        model = load_model(model_path)
+        model, _ = load_model(model_path)
 
         self.assertIsInstance(model, Qwen2Model)