mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-04 07:44:34 +08:00
Support for multiple EOS tokens
This commit is contained in:
@@ -350,7 +350,7 @@ def stream_generate(
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
tic = time.perf_counter()
|
||||
if token == tokenizer.eos_token_id:
|
||||
if token in tokenizer.eos_token_ids:
|
||||
break
|
||||
|
||||
detokenizer.add_token(token)
|
||||
@@ -470,9 +470,6 @@ def load_model(
|
||||
ValueError: If the model class or args class are not found or cannot be instantiated.
|
||||
"""
|
||||
|
||||
config = load_config(model_path)
|
||||
config.update(model_config)
|
||||
|
||||
weight_files = glob.glob(str(model_path / "model*.safetensors"))
|
||||
|
||||
if not weight_files:
|
||||
@@ -487,15 +484,15 @@ def load_model(
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
model_class, model_args_class = get_model_classes(config=config)
|
||||
model_class, model_args_class = get_model_classes(config=model_config)
|
||||
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model_args = model_args_class.from_dict(model_config)
|
||||
model = model_class(model_args)
|
||||
|
||||
if hasattr(model, "sanitize"):
|
||||
weights = model.sanitize(weights)
|
||||
|
||||
if (quantization := config.get("quantization", None)) is not None:
|
||||
if (quantization := model_config.get("quantization", None)) is not None:
|
||||
# Handle legacy models which may not have everything quantized
|
||||
def class_predicate(p, m):
|
||||
if not hasattr(m, "to_quantized"):
|
||||
@@ -547,11 +544,15 @@ def load(
|
||||
"""
|
||||
model_path = get_model_path(path_or_hf_repo)
|
||||
|
||||
model = load_model(model_path, lazy, model_config)
|
||||
config = load_config(model_path)
|
||||
config.update(model_config)
|
||||
|
||||
model = load_model(model_path, lazy, config)
|
||||
if adapter_path is not None:
|
||||
model = load_adapters(model, adapter_path)
|
||||
model.eval()
|
||||
tokenizer = load_tokenizer(model_path, tokenizer_config)
|
||||
|
||||
tokenizer = load_tokenizer(model_path, tokenizer_config, model_config=config)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -559,9 +560,9 @@ def load(
|
||||
def fetch_from_hub(
|
||||
model_path: Path, lazy: bool = False
|
||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||
model = load_model(model_path, lazy)
|
||||
config = load_config(model_path)
|
||||
tokenizer = load_tokenizer(model_path)
|
||||
model = load_model(model_path, lazy, model_config=config)
|
||||
tokenizer = load_tokenizer(model_path, model_config=config)
|
||||
return model, config, tokenizer
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user