mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Support for multiple EOS tokens
This commit is contained in:
parent
1727959a27
commit
f8379fb3ef
@ -254,21 +254,29 @@ class TokenizerWrapper:
|
|||||||
huggingface tokenizer.
|
huggingface tokenizer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
|
def __init__(
|
||||||
|
self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
|
||||||
|
):
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
self._detokenizer = detokenizer_class(tokenizer)
|
self._detokenizer = detokenizer_class(tokenizer)
|
||||||
|
self._eos_token_ids = eos_token_ids or [tokenizer.eos_token_id]
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
if attr == "detokenizer":
|
if attr == "detokenizer":
|
||||||
return self._detokenizer
|
return self._detokenizer
|
||||||
|
elif attr == "eos_token_ids":
|
||||||
|
return self._eos_token_ids
|
||||||
elif attr.startswith("_"):
|
elif attr.startswith("_"):
|
||||||
return self.__getattribute__(attr)
|
return self.__getattribute__(attr)
|
||||||
else:
|
else:
|
||||||
return getattr(self._tokenizer, attr)
|
return getattr(self._tokenizer, attr)
|
||||||
|
|
||||||
def __setattr__(self, attr, value):
|
def __setattr__(self, attr, value):
|
||||||
if attr == "detokenizer":
|
if attr in {"detokenizer", "eos_token_ids"}:
|
||||||
raise AttributeError("Cannot set the detokenizer.")
|
if attr == "detokenizer":
|
||||||
|
raise AttributeError("Cannot set the detokenizer.")
|
||||||
|
elif attr == "eos_token_ids":
|
||||||
|
self._eos_token_ids = value
|
||||||
elif attr.startswith("_"):
|
elif attr.startswith("_"):
|
||||||
super().__setattr__(attr, value)
|
super().__setattr__(attr, value)
|
||||||
else:
|
else:
|
||||||
@ -315,7 +323,7 @@ def _is_bpe_decoder(decoder):
|
|||||||
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(model_path, tokenizer_config_extra={}):
|
def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
|
||||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||||
detokenizer to use.
|
detokenizer to use.
|
||||||
|
|
||||||
@ -336,7 +344,11 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
|
|||||||
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
||||||
detokenizer_class = BPEStreamingDetokenizer
|
detokenizer_class = BPEStreamingDetokenizer
|
||||||
|
|
||||||
|
eos_token_id = model_config["eos_token_id"]
|
||||||
|
eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
|
||||||
|
|
||||||
return TokenizerWrapper(
|
return TokenizerWrapper(
|
||||||
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
||||||
detokenizer_class,
|
detokenizer_class,
|
||||||
|
eos_token_ids=eos_token_ids,
|
||||||
)
|
)
|
||||||
|
@ -350,7 +350,7 @@ def stream_generate(
|
|||||||
prompt_time = time.perf_counter() - tic
|
prompt_time = time.perf_counter() - tic
|
||||||
prompt_tps = prompt.size / prompt_time
|
prompt_tps = prompt.size / prompt_time
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
if token == tokenizer.eos_token_id:
|
if token in tokenizer.eos_token_ids:
|
||||||
break
|
break
|
||||||
|
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
@ -470,9 +470,6 @@ def load_model(
|
|||||||
ValueError: If the model class or args class are not found or cannot be instantiated.
|
ValueError: If the model class or args class are not found or cannot be instantiated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = load_config(model_path)
|
|
||||||
config.update(model_config)
|
|
||||||
|
|
||||||
weight_files = glob.glob(str(model_path / "model*.safetensors"))
|
weight_files = glob.glob(str(model_path / "model*.safetensors"))
|
||||||
|
|
||||||
if not weight_files:
|
if not weight_files:
|
||||||
@ -487,15 +484,15 @@ def load_model(
|
|||||||
for wf in weight_files:
|
for wf in weight_files:
|
||||||
weights.update(mx.load(wf))
|
weights.update(mx.load(wf))
|
||||||
|
|
||||||
model_class, model_args_class = get_model_classes(config=config)
|
model_class, model_args_class = get_model_classes(config=model_config)
|
||||||
|
|
||||||
model_args = model_args_class.from_dict(config)
|
model_args = model_args_class.from_dict(model_config)
|
||||||
model = model_class(model_args)
|
model = model_class(model_args)
|
||||||
|
|
||||||
if hasattr(model, "sanitize"):
|
if hasattr(model, "sanitize"):
|
||||||
weights = model.sanitize(weights)
|
weights = model.sanitize(weights)
|
||||||
|
|
||||||
if (quantization := config.get("quantization", None)) is not None:
|
if (quantization := model_config.get("quantization", None)) is not None:
|
||||||
# Handle legacy models which may not have everything quantized
|
# Handle legacy models which may not have everything quantized
|
||||||
def class_predicate(p, m):
|
def class_predicate(p, m):
|
||||||
if not hasattr(m, "to_quantized"):
|
if not hasattr(m, "to_quantized"):
|
||||||
@ -547,11 +544,15 @@ def load(
|
|||||||
"""
|
"""
|
||||||
model_path = get_model_path(path_or_hf_repo)
|
model_path = get_model_path(path_or_hf_repo)
|
||||||
|
|
||||||
model = load_model(model_path, lazy, model_config)
|
config = load_config(model_path)
|
||||||
|
config.update(model_config)
|
||||||
|
|
||||||
|
model = load_model(model_path, lazy, config)
|
||||||
if adapter_path is not None:
|
if adapter_path is not None:
|
||||||
model = load_adapters(model, adapter_path)
|
model = load_adapters(model, adapter_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
tokenizer = load_tokenizer(model_path, tokenizer_config)
|
|
||||||
|
tokenizer = load_tokenizer(model_path, tokenizer_config, model_config=config)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@ -559,9 +560,9 @@ def load(
|
|||||||
def fetch_from_hub(
|
def fetch_from_hub(
|
||||||
model_path: Path, lazy: bool = False
|
model_path: Path, lazy: bool = False
|
||||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||||
model = load_model(model_path, lazy)
|
|
||||||
config = load_config(model_path)
|
config = load_config(model_path)
|
||||||
tokenizer = load_tokenizer(model_path)
|
model = load_model(model_path, lazy, model_config=config)
|
||||||
|
tokenizer = load_tokenizer(model_path, model_config=config)
|
||||||
return model, config, tokenizer
|
return model, config, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user