This commit is contained in:
Awni Hannun 2024-12-09 08:51:22 -08:00
parent facaf01b8d
commit 9b53599e6c
3 changed files with 20 additions and 22 deletions

View File

@ -327,7 +327,7 @@ def _is_bpe_decoder(decoder):
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@ -348,10 +348,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer
eos_token_ids = (
set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
)
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
return TokenizerWrapper(
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class,

View File

@ -456,11 +456,11 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
model_config (dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config.
Defaults to the _get_classes function.
Defaults to the ``_get_classes`` function.
Returns:
nn.Module: The loaded and initialized model.
@ -469,6 +469,8 @@ def load_model(
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
config = load_config(model_path)
config.update(model_config)
weight_files = glob.glob(str(model_path / "model*.safetensors"))
@ -484,15 +486,15 @@ def load_model(
for wf in weight_files:
weights.update(mx.load(wf))
model_class, model_args_class = get_model_classes(config=model_config)
model_class, model_args_class = get_model_classes(config=config)
model_args = model_args_class.from_dict(model_config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
if (quantization := model_config.get("quantization", None)) is not None:
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
def class_predicate(p, m):
if not hasattr(m, "to_quantized"):
@ -511,7 +513,7 @@ def load_model(
mx.eval(model.parameters())
model.eval()
return model
return model, config
def load(
@ -544,15 +546,12 @@ def load(
"""
model_path = get_model_path(path_or_hf_repo)
config = load_config(model_path)
config.update(model_config)
model = load_model(model_path, lazy, config)
model, config = load_model(model_path, lazy)
if adapter_path is not None:
model = load_adapters(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(
model_path, tokenizer_config, eos_token_id=config["eos_token_id"]
model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
)
return model, tokenizer
@ -561,9 +560,10 @@ def load(
def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
config = load_config(model_path)
model = load_model(model_path, lazy, model_config=config)
tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"])
model, config = load_model(model_path, lazy)
tokenizer = load_tokenizer(
model_path, eos_token_ids=config.get("eos_token_id", None)
)
return model, config, tokenizer

View File

@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
return CustomQwenModel, CustomQwenConfig
model_path = get_model_path(HF_MODEL_PATH)
model = load_model(model_path, get_model_classes=custom_get_classes)
model, _ = load_model(model_path, get_model_classes=custom_get_classes)
self.assertIsInstance(model, CustomQwenModel)
self.assertTrue(hasattr(model, "custom_attribute"))
@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
def test_load_model_with_default_get_classes(self):
model_path = get_model_path(HF_MODEL_PATH)
model = load_model(model_path)
model, _ = load_model(model_path)
self.assertIsInstance(model, Qwen2Model)