Support for multiple EOS tokens (#1141)

* Support for multiple EOS tokens

* Change _eos_token_ids type from list to set

* Remove model_config & add eos_token_id

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
madroid
2024-12-10 00:53:58 +08:00
committed by GitHub
parent 5687d5b99b
commit 12083c4b7e
3 changed files with 34 additions and 17 deletions

View File

@@ -361,7 +361,7 @@ def stream_generate(
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
if token in tokenizer.eos_token_ids:
break
detokenizer.add_token(token)
@@ -467,11 +467,11 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
model_config (dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config.
Defaults to the _get_classes function.
Defaults to the ``_get_classes`` function.
Returns:
nn.Module: The loaded and initialized model.
@@ -480,7 +480,6 @@ def load_model(
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
config = load_config(model_path)
config.update(model_config)
@@ -530,7 +529,7 @@ def load_model(
mx.eval(model.parameters())
model.eval()
return model
return model, config
def load(
@@ -563,11 +562,13 @@ def load(
"""
model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path, lazy, model_config)
model, config = load_model(model_path, lazy)
if adapter_path is not None:
model = load_adapters(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
tokenizer = load_tokenizer(
model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
)
return model, tokenizer
@@ -575,9 +576,10 @@ def load(
def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy)
config = load_config(model_path)
tokenizer = load_tokenizer(model_path)
model, config = load_model(model_path, lazy)
tokenizer = load_tokenizer(
model_path, eos_token_ids=config.get("eos_token_id", None)
)
return model, config, tokenizer