Support for multiple EOS tokens (#1141)

* Support for multiple EOS tokens

* Change _eos_token_ids type from list to set

* Remove model_config & add eos_token_id

* nits

Co-authored-by: Awni Hannun <awni@apple.com>
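The change threads an optional collection of end-of-sequence token ids through the tokenizer stack so generation can stop on any of several tokens (e.g. models that emit more than one stop token). A minimal self-contained sketch of the idea, using stub names (StubTokenizer, EosTracker) that are illustrative only, not part of the patch:

    # Keep every stop id in a set, defaulting to the tokenizer's single
    # eos_token_id -- the same logic TokenizerWrapper.__init__ gains below.
    class StubTokenizer:
        eos_token_id = 2  # stand-in for a real Hugging Face tokenizer's EOS id

    class EosTracker:
        def __init__(self, tokenizer, eos_token_ids=None):
            # Normalize to a set: O(1) membership tests, duplicates collapse.
            self._eos_token_ids = (
                set(eos_token_ids)
                if eos_token_ids is not None
                else {tokenizer.eos_token_id}
            )

        def is_stop(self, token):
            return token in self._eos_token_ids

    tracker = EosTracker(StubTokenizer(), eos_token_ids=[128001, 128009])
    assert tracker.is_stop(128009) and not tracker.is_stop(15)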
llms/mlx_lm/tokenizer_utils.py

@@ -254,21 +254,33 @@ class TokenizerWrapper:
     huggingface tokenizer.
     """
 
-    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
+    def __init__(
+        self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
+    ):
         self._tokenizer = tokenizer
         self._detokenizer = detokenizer_class(tokenizer)
+        self._eos_token_ids = (
+            set(eos_token_ids)
+            if eos_token_ids is not None
+            else {tokenizer.eos_token_id}
+        )
 
     def __getattr__(self, attr):
         if attr == "detokenizer":
             return self._detokenizer
+        elif attr == "eos_token_ids":
+            return self._eos_token_ids
         elif attr.startswith("_"):
             return self.__getattribute__(attr)
         else:
             return getattr(self._tokenizer, attr)
 
     def __setattr__(self, attr, value):
-        if attr == "detokenizer":
-            raise AttributeError("Cannot set the detokenizer.")
+        if attr in {"detokenizer", "eos_token_ids"}:
+            if attr == "detokenizer":
+                raise AttributeError("Cannot set the detokenizer.")
+            elif attr == "eos_token_ids":
+                self._eos_token_ids = set(value) if value is not None else set()
         elif attr.startswith("_"):
             super().__setattr__(attr, value)
         else:
@@ -315,7 +327,7 @@ def _is_bpe_decoder(decoder):
     return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
 
 
-def load_tokenizer(model_path, tokenizer_config_extra={}):
+def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
     """Load a huggingface tokenizer and try to infer the type of streaming
     detokenizer to use.
 
@@ -336,7 +348,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
         elif _is_bpe_decoder(tokenizer_content["decoder"]):
             detokenizer_class = BPEStreamingDetokenizer
 
+    if isinstance(eos_token_ids, int):
+        eos_token_ids = [eos_token_ids]
     return TokenizerWrapper(
         AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
         detokenizer_class,
+        eos_token_ids=eos_token_ids,
     )
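With the new keyword in place, callers can pass either a single id or a list; a bare int is wrapped into a one-element list before TokenizerWrapper turns it into a set. A hypothetical call site (the model path is a placeholder, and mlx_lm.tokenizer_utils as the import path is an assumption; running this needs a locally available converted model):

    from pathlib import Path
    from mlx_lm.tokenizer_utils import load_tokenizer  # assumed module path

    model_path = Path("path/to/converted/model")  # placeholder

    # A bare int is normalized to a one-element list by load_tokenizer.
    tokenizer = load_tokenizer(model_path, eos_token_ids=2)

    # Several stop ids end up as a set on the wrapper.
    tokenizer = load_tokenizer(model_path, eos_token_ids=[128001, 128009])
    print(tokenizer.eos_token_ids)  # {128001, 128009}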
llms/mlx_lm/utils.py

@@ -361,7 +361,7 @@ def stream_generate(
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
                 tic = time.perf_counter()
-            if token == tokenizer.eos_token_id:
+            if token in tokenizer.eos_token_ids:
                 break
 
             detokenizer.add_token(token)
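The stop check in the decode loop becomes set membership instead of equality with a single id. A self-contained sketch of the same control flow (generate_until_eos is an illustrative helper, not part of the patch):

    # Mirrors the stream_generate change: break on any id in the stop set.
    def generate_until_eos(step_tokens, eos_token_ids):
        out = []
        for token in step_tokens:
            if token in eos_token_ids:  # was: token == eos_token_id
                break
            out.append(token)
        return out

    assert generate_until_eos([5, 7, 128009, 9], {128001, 128009}) == [5, 7]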
@@ -467,11 +467,11 @@ def load_model(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
-        model_config (dict, optional): Configuration parameters for the model.
-            Defaults to an empty dictionary.
+        model_config (dict, optional): Optional configuration parameters for the
+            model. Defaults to an empty dictionary.
         get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
             A function that returns the model class and model args class given a config.
-            Defaults to the _get_classes function.
+            Defaults to the ``_get_classes`` function.
 
     Returns:
         nn.Module: The loaded and initialized model.
@@ -480,7 +480,6 @@ def load_model(
         FileNotFoundError: If the weight files (.safetensors) are not found.
         ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-
     config = load_config(model_path)
     config.update(model_config)
 
@@ -530,7 +529,7 @@ def load_model(
         mx.eval(model.parameters())
 
     model.eval()
-    return model
+    return model, config
 
 
 def load(
@@ -563,11 +562,13 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)
 
-    model = load_model(model_path, lazy, model_config)
+    model, config = load_model(model_path, lazy)
    if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = load_tokenizer(
+        model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
+    )
 
     return model, tokenizer
 
@@ -575,9 +576,10 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model = load_model(model_path, lazy)
-    config = load_config(model_path)
-    tokenizer = load_tokenizer(model_path)
+    model, config = load_model(model_path, lazy)
+    tokenizer = load_tokenizer(
+        model_path, eos_token_ids=config.get("eos_token_id", None)
+    )
     return model, config, tokenizer
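load() and fetch_from_hub() now read the stop ids straight from the model config, where Hugging Face stores "eos_token_id" as either a bare int or a list. A small sketch of that normalization path (normalize_eos is an illustrative helper; the real work happens inside load_tokenizer and TokenizerWrapper, as shown above):

    # config.get("eos_token_id", None) may be None, an int, or a list;
    # downstream code accepts all three.
    def normalize_eos(value):
        if value is None:
            return None
        return [value] if isinstance(value, int) else list(value)

    assert normalize_eos(2) == [2]
    assert normalize_eos([128001, 128009]) == [128001, 128009]
    assert normalize_eos(None) is None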