This commit is contained in:
Awni Hannun 2024-12-09 08:51:22 -08:00
parent facaf01b8d
commit 9b53599e6c
3 changed files with 20 additions and 22 deletions

View File

@ -327,7 +327,7 @@ def _is_bpe_decoder(decoder):
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None): def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
"""Load a huggingface tokenizer and try to infer the type of streaming """Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use. detokenizer to use.
@ -348,10 +348,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
elif _is_bpe_decoder(tokenizer_content["decoder"]): elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer detokenizer_class = BPEStreamingDetokenizer
eos_token_ids = ( if isinstance(eos_token_ids, int):
set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id} eos_token_ids = [eos_token_ids]
)
return TokenizerWrapper( return TokenizerWrapper(
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class, detokenizer_class,

View File

@ -456,11 +456,11 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False`` when needed. Default: ``False``
model_config (dict, optional): Configuration parameters for the model. model_config (dict, optional): Optional configuration parameters for the
Defaults to an empty dictionary. model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config. A function that returns the model class and model args class given a config.
Defaults to the _get_classes function. Defaults to the ``_get_classes`` function.
Returns: Returns:
nn.Module: The loaded and initialized model. nn.Module: The loaded and initialized model.
@ -469,6 +469,8 @@ def load_model(
FileNotFoundError: If the weight files (.safetensors) are not found. FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated. ValueError: If the model class or args class are not found or cannot be instantiated.
""" """
config = load_config(model_path)
config.update(model_config)
weight_files = glob.glob(str(model_path / "model*.safetensors")) weight_files = glob.glob(str(model_path / "model*.safetensors"))
@ -484,15 +486,15 @@ def load_model(
for wf in weight_files: for wf in weight_files:
weights.update(mx.load(wf)) weights.update(mx.load(wf))
model_class, model_args_class = get_model_classes(config=model_config) model_class, model_args_class = get_model_classes(config=config)
model_args = model_args_class.from_dict(model_config) model_args = model_args_class.from_dict(config)
model = model_class(model_args) model = model_class(model_args)
if hasattr(model, "sanitize"): if hasattr(model, "sanitize"):
weights = model.sanitize(weights) weights = model.sanitize(weights)
if (quantization := model_config.get("quantization", None)) is not None: if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized # Handle legacy models which may not have everything quantized
def class_predicate(p, m): def class_predicate(p, m):
if not hasattr(m, "to_quantized"): if not hasattr(m, "to_quantized"):
@ -511,7 +513,7 @@ def load_model(
mx.eval(model.parameters()) mx.eval(model.parameters())
model.eval() model.eval()
return model return model, config
def load( def load(
@ -544,15 +546,12 @@ def load(
""" """
model_path = get_model_path(path_or_hf_repo) model_path = get_model_path(path_or_hf_repo)
config = load_config(model_path) model, config = load_model(model_path, lazy)
config.update(model_config)
model = load_model(model_path, lazy, config)
if adapter_path is not None: if adapter_path is not None:
model = load_adapters(model, adapter_path) model = load_adapters(model, adapter_path)
model.eval() model.eval()
tokenizer = load_tokenizer( tokenizer = load_tokenizer(
model_path, tokenizer_config, eos_token_id=config["eos_token_id"] model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
) )
return model, tokenizer return model, tokenizer
@ -561,9 +560,10 @@ def load(
def fetch_from_hub( def fetch_from_hub(
model_path: Path, lazy: bool = False model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
config = load_config(model_path) model, config = load_model(model_path, lazy)
model = load_model(model_path, lazy, model_config=config) tokenizer = load_tokenizer(
tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"]) model_path, eos_token_ids=config.get("eos_token_id", None)
)
return model, config, tokenizer return model, config, tokenizer

View File

@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
return CustomQwenModel, CustomQwenConfig return CustomQwenModel, CustomQwenConfig
model_path = get_model_path(HF_MODEL_PATH) model_path = get_model_path(HF_MODEL_PATH)
model = load_model(model_path, get_model_classes=custom_get_classes) model, _ = load_model(model_path, get_model_classes=custom_get_classes)
self.assertIsInstance(model, CustomQwenModel) self.assertIsInstance(model, CustomQwenModel)
self.assertTrue(hasattr(model, "custom_attribute")) self.assertTrue(hasattr(model, "custom_attribute"))
@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
def test_load_model_with_default_get_classes(self): def test_load_model_with_default_get_classes(self):
model_path = get_model_path(HF_MODEL_PATH) model_path = get_model_path(HF_MODEL_PATH)
model = load_model(model_path) model, _ = load_model(model_path)
self.assertIsInstance(model, Qwen2Model) self.assertIsInstance(model, Qwen2Model)