Support for multiple EOS tokens

This commit is contained in:
madroid 2024-12-07 11:19:20 +08:00
parent 1727959a27
commit f8379fb3ef
2 changed files with 28 additions and 15 deletions

View File

@@ -254,21 +254,29 @@ class TokenizerWrapper:
huggingface tokenizer. huggingface tokenizer.
""" """
def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): def __init__(
self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
):
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._detokenizer = detokenizer_class(tokenizer) self._detokenizer = detokenizer_class(tokenizer)
self._eos_token_ids = eos_token_ids or [tokenizer.eos_token_id]
def __getattr__(self, attr): def __getattr__(self, attr):
if attr == "detokenizer": if attr == "detokenizer":
return self._detokenizer return self._detokenizer
elif attr == "eos_token_ids":
return self._eos_token_ids
elif attr.startswith("_"): elif attr.startswith("_"):
return self.__getattribute__(attr) return self.__getattribute__(attr)
else: else:
return getattr(self._tokenizer, attr) return getattr(self._tokenizer, attr)
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
if attr == "detokenizer": if attr in {"detokenizer", "eos_token_ids"}:
raise AttributeError("Cannot set the detokenizer.") if attr == "detokenizer":
raise AttributeError("Cannot set the detokenizer.")
elif attr == "eos_token_ids":
self._eos_token_ids = value
elif attr.startswith("_"): elif attr.startswith("_"):
super().__setattr__(attr, value) super().__setattr__(attr, value)
else: else:
@@ -315,7 +323,7 @@ def _is_bpe_decoder(decoder):
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}): def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
"""Load a huggingface tokenizer and try to infer the type of streaming """Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use. detokenizer to use.
@@ -336,7 +344,11 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
elif _is_bpe_decoder(tokenizer_content["decoder"]): elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer detokenizer_class = BPEStreamingDetokenizer
eos_token_id = model_config["eos_token_id"]
eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
return TokenizerWrapper( return TokenizerWrapper(
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class, detokenizer_class,
eos_token_ids=eos_token_ids,
) )

View File

@@ -350,7 +350,7 @@ def stream_generate(
prompt_time = time.perf_counter() - tic prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time prompt_tps = prompt.size / prompt_time
tic = time.perf_counter() tic = time.perf_counter()
if token == tokenizer.eos_token_id: if token in tokenizer.eos_token_ids:
break break
detokenizer.add_token(token) detokenizer.add_token(token)
@@ -470,9 +470,6 @@ def load_model(
ValueError: If the model class or args class are not found or cannot be instantiated. ValueError: If the model class or args class are not found or cannot be instantiated.
""" """
config = load_config(model_path)
config.update(model_config)
weight_files = glob.glob(str(model_path / "model*.safetensors")) weight_files = glob.glob(str(model_path / "model*.safetensors"))
if not weight_files: if not weight_files:
@@ -487,15 +484,15 @@ def load_model(
for wf in weight_files: for wf in weight_files:
weights.update(mx.load(wf)) weights.update(mx.load(wf))
model_class, model_args_class = get_model_classes(config=config) model_class, model_args_class = get_model_classes(config=model_config)
model_args = model_args_class.from_dict(config) model_args = model_args_class.from_dict(model_config)
model = model_class(model_args) model = model_class(model_args)
if hasattr(model, "sanitize"): if hasattr(model, "sanitize"):
weights = model.sanitize(weights) weights = model.sanitize(weights)
if (quantization := config.get("quantization", None)) is not None: if (quantization := model_config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized # Handle legacy models which may not have everything quantized
def class_predicate(p, m): def class_predicate(p, m):
if not hasattr(m, "to_quantized"): if not hasattr(m, "to_quantized"):
@@ -547,11 +544,15 @@ def load(
""" """
model_path = get_model_path(path_or_hf_repo) model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path, lazy, model_config) config = load_config(model_path)
config.update(model_config)
model = load_model(model_path, lazy, config)
if adapter_path is not None: if adapter_path is not None:
model = load_adapters(model, adapter_path) model = load_adapters(model, adapter_path)
model.eval() model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
tokenizer = load_tokenizer(model_path, tokenizer_config, model_config=config)
return model, tokenizer return model, tokenizer
@@ -559,9 +560,9 @@ def load(
def fetch_from_hub( def fetch_from_hub(
model_path: Path, lazy: bool = False model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy)
config = load_config(model_path) config = load_config(model_path)
tokenizer = load_tokenizer(model_path) model = load_model(model_path, lazy, model_config=config)
tokenizer = load_tokenizer(model_path, model_config=config)
return model, config, tokenizer return model, config, tokenizer