mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Support for multiple EOS tokens (#1141)
* Support for multiple EOS tokens * Change _eos_token_ids type from list to set * Remove model_config & add eos_token_id * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
5687d5b99b
commit
12083c4b7e
@ -254,21 +254,33 @@ class TokenizerWrapper:
|
||||
huggingface tokenizer.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
|
||||
def __init__(
|
||||
self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
|
||||
):
|
||||
self._tokenizer = tokenizer
|
||||
self._detokenizer = detokenizer_class(tokenizer)
|
||||
self._eos_token_ids = (
|
||||
set(eos_token_ids)
|
||||
if eos_token_ids is not None
|
||||
else {tokenizer.eos_token_id}
|
||||
)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr == "detokenizer":
|
||||
return self._detokenizer
|
||||
elif attr == "eos_token_ids":
|
||||
return self._eos_token_ids
|
||||
elif attr.startswith("_"):
|
||||
return self.__getattribute__(attr)
|
||||
else:
|
||||
return getattr(self._tokenizer, attr)
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
if attr == "detokenizer":
|
||||
raise AttributeError("Cannot set the detokenizer.")
|
||||
if attr in {"detokenizer", "eos_token_ids"}:
|
||||
if attr == "detokenizer":
|
||||
raise AttributeError("Cannot set the detokenizer.")
|
||||
elif attr == "eos_token_ids":
|
||||
self._eos_token_ids = set(value) if value is not None else set()
|
||||
elif attr.startswith("_"):
|
||||
super().__setattr__(attr, value)
|
||||
else:
|
||||
@ -315,7 +327,7 @@ def _is_bpe_decoder(decoder):
|
||||
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
||||
|
||||
|
||||
def load_tokenizer(model_path, tokenizer_config_extra={}):
|
||||
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
|
||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||
detokenizer to use.
|
||||
|
||||
@ -336,7 +348,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
|
||||
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
||||
detokenizer_class = BPEStreamingDetokenizer
|
||||
|
||||
if isinstance(eos_token_ids, int):
|
||||
eos_token_ids = [eos_token_ids]
|
||||
return TokenizerWrapper(
|
||||
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
||||
detokenizer_class,
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
@ -361,7 +361,7 @@ def stream_generate(
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
tic = time.perf_counter()
|
||||
if token == tokenizer.eos_token_id:
|
||||
if token in tokenizer.eos_token_ids:
|
||||
break
|
||||
|
||||
detokenizer.add_token(token)
|
||||
@ -467,11 +467,11 @@ def load_model(
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
model_config (dict, optional): Configuration parameters for the model.
|
||||
Defaults to an empty dictionary.
|
||||
model_config (dict, optional): Optional configuration parameters for the
|
||||
model. Defaults to an empty dictionary.
|
||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||
A function that returns the model class and model args class given a config.
|
||||
Defaults to the _get_classes function.
|
||||
Defaults to the ``_get_classes`` function.
|
||||
|
||||
Returns:
|
||||
nn.Module: The loaded and initialized model.
|
||||
@ -480,7 +480,6 @@ def load_model(
|
||||
FileNotFoundError: If the weight files (.safetensors) are not found.
|
||||
ValueError: If the model class or args class are not found or cannot be instantiated.
|
||||
"""
|
||||
|
||||
config = load_config(model_path)
|
||||
config.update(model_config)
|
||||
|
||||
@ -530,7 +529,7 @@ def load_model(
|
||||
mx.eval(model.parameters())
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
return model, config
|
||||
|
||||
|
||||
def load(
|
||||
@ -563,11 +562,13 @@ def load(
|
||||
"""
|
||||
model_path = get_model_path(path_or_hf_repo)
|
||||
|
||||
model = load_model(model_path, lazy, model_config)
|
||||
model, config = load_model(model_path, lazy)
|
||||
if adapter_path is not None:
|
||||
model = load_adapters(model, adapter_path)
|
||||
model.eval()
|
||||
tokenizer = load_tokenizer(model_path, tokenizer_config)
|
||||
tokenizer = load_tokenizer(
|
||||
model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@ -575,9 +576,10 @@ def load(
|
||||
def fetch_from_hub(
|
||||
model_path: Path, lazy: bool = False
|
||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||
model = load_model(model_path, lazy)
|
||||
config = load_config(model_path)
|
||||
tokenizer = load_tokenizer(model_path)
|
||||
model, config = load_model(model_path, lazy)
|
||||
tokenizer = load_tokenizer(
|
||||
model_path, eos_token_ids=config.get("eos_token_id", None)
|
||||
)
|
||||
return model, config, tokenizer
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
|
||||
return CustomQwenModel, CustomQwenConfig
|
||||
|
||||
model_path = get_model_path(HF_MODEL_PATH)
|
||||
model = load_model(model_path, get_model_classes=custom_get_classes)
|
||||
model, _ = load_model(model_path, get_model_classes=custom_get_classes)
|
||||
|
||||
self.assertIsInstance(model, CustomQwenModel)
|
||||
self.assertTrue(hasattr(model, "custom_attribute"))
|
||||
@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
|
||||
|
||||
def test_load_model_with_default_get_classes(self):
|
||||
model_path = get_model_path(HF_MODEL_PATH)
|
||||
model = load_model(model_path)
|
||||
model, _ = load_model(model_path)
|
||||
|
||||
self.assertIsInstance(model, Qwen2Model)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user