mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
nits
This commit is contained in:
parent
facaf01b8d
commit
9b53599e6c
@ -327,7 +327,7 @@ def _is_bpe_decoder(decoder):
|
|||||||
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
|
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
|
||||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||||
detokenizer to use.
|
detokenizer to use.
|
||||||
|
|
||||||
@ -348,10 +348,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
|
|||||||
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
||||||
detokenizer_class = BPEStreamingDetokenizer
|
detokenizer_class = BPEStreamingDetokenizer
|
||||||
|
|
||||||
eos_token_ids = (
|
if isinstance(eos_token_ids, int):
|
||||||
set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
|
eos_token_ids = [eos_token_ids]
|
||||||
)
|
|
||||||
|
|
||||||
return TokenizerWrapper(
|
return TokenizerWrapper(
|
||||||
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
||||||
detokenizer_class,
|
detokenizer_class,
|
||||||
|
@ -456,11 +456,11 @@ def load_model(
|
|||||||
lazy (bool): If False eval the model parameters to make sure they are
|
lazy (bool): If False eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
model_config (dict, optional): Configuration parameters for the model.
|
model_config (dict, optional): Optional configuration parameters for the
|
||||||
Defaults to an empty dictionary.
|
model. Defaults to an empty dictionary.
|
||||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||||
A function that returns the model class and model args class given a config.
|
A function that returns the model class and model args class given a config.
|
||||||
Defaults to the _get_classes function.
|
Defaults to the ``_get_classes`` function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: The loaded and initialized model.
|
nn.Module: The loaded and initialized model.
|
||||||
@ -469,6 +469,8 @@ def load_model(
|
|||||||
FileNotFoundError: If the weight files (.safetensors) are not found.
|
FileNotFoundError: If the weight files (.safetensors) are not found.
|
||||||
ValueError: If the model class or args class are not found or cannot be instantiated.
|
ValueError: If the model class or args class are not found or cannot be instantiated.
|
||||||
"""
|
"""
|
||||||
|
config = load_config(model_path)
|
||||||
|
config.update(model_config)
|
||||||
|
|
||||||
weight_files = glob.glob(str(model_path / "model*.safetensors"))
|
weight_files = glob.glob(str(model_path / "model*.safetensors"))
|
||||||
|
|
||||||
@ -484,15 +486,15 @@ def load_model(
|
|||||||
for wf in weight_files:
|
for wf in weight_files:
|
||||||
weights.update(mx.load(wf))
|
weights.update(mx.load(wf))
|
||||||
|
|
||||||
model_class, model_args_class = get_model_classes(config=model_config)
|
model_class, model_args_class = get_model_classes(config=config)
|
||||||
|
|
||||||
model_args = model_args_class.from_dict(model_config)
|
model_args = model_args_class.from_dict(config)
|
||||||
model = model_class(model_args)
|
model = model_class(model_args)
|
||||||
|
|
||||||
if hasattr(model, "sanitize"):
|
if hasattr(model, "sanitize"):
|
||||||
weights = model.sanitize(weights)
|
weights = model.sanitize(weights)
|
||||||
|
|
||||||
if (quantization := model_config.get("quantization", None)) is not None:
|
if (quantization := config.get("quantization", None)) is not None:
|
||||||
# Handle legacy models which may not have everything quantized
|
# Handle legacy models which may not have everything quantized
|
||||||
def class_predicate(p, m):
|
def class_predicate(p, m):
|
||||||
if not hasattr(m, "to_quantized"):
|
if not hasattr(m, "to_quantized"):
|
||||||
@ -511,7 +513,7 @@ def load_model(
|
|||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model, config
|
||||||
|
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
@ -544,15 +546,12 @@ def load(
|
|||||||
"""
|
"""
|
||||||
model_path = get_model_path(path_or_hf_repo)
|
model_path = get_model_path(path_or_hf_repo)
|
||||||
|
|
||||||
config = load_config(model_path)
|
model, config = load_model(model_path, lazy)
|
||||||
config.update(model_config)
|
|
||||||
|
|
||||||
model = load_model(model_path, lazy, config)
|
|
||||||
if adapter_path is not None:
|
if adapter_path is not None:
|
||||||
model = load_adapters(model, adapter_path)
|
model = load_adapters(model, adapter_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
tokenizer = load_tokenizer(
|
tokenizer = load_tokenizer(
|
||||||
model_path, tokenizer_config, eos_token_id=config["eos_token_id"]
|
model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
|
||||||
)
|
)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
@ -561,9 +560,10 @@ def load(
|
|||||||
def fetch_from_hub(
|
def fetch_from_hub(
|
||||||
model_path: Path, lazy: bool = False
|
model_path: Path, lazy: bool = False
|
||||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||||
config = load_config(model_path)
|
model, config = load_model(model_path, lazy)
|
||||||
model = load_model(model_path, lazy, model_config=config)
|
tokenizer = load_tokenizer(
|
||||||
tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"])
|
model_path, eos_token_ids=config.get("eos_token_id", None)
|
||||||
|
)
|
||||||
return model, config, tokenizer
|
return model, config, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
|
|||||||
return CustomQwenModel, CustomQwenConfig
|
return CustomQwenModel, CustomQwenConfig
|
||||||
|
|
||||||
model_path = get_model_path(HF_MODEL_PATH)
|
model_path = get_model_path(HF_MODEL_PATH)
|
||||||
model = load_model(model_path, get_model_classes=custom_get_classes)
|
model, _ = load_model(model_path, get_model_classes=custom_get_classes)
|
||||||
|
|
||||||
self.assertIsInstance(model, CustomQwenModel)
|
self.assertIsInstance(model, CustomQwenModel)
|
||||||
self.assertTrue(hasattr(model, "custom_attribute"))
|
self.assertTrue(hasattr(model, "custom_attribute"))
|
||||||
@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
|
|||||||
|
|
||||||
def test_load_model_with_default_get_classes(self):
|
def test_load_model_with_default_get_classes(self):
|
||||||
model_path = get_model_path(HF_MODEL_PATH)
|
model_path = get_model_path(HF_MODEL_PATH)
|
||||||
model = load_model(model_path)
|
model, _ = load_model(model_path)
|
||||||
|
|
||||||
self.assertIsInstance(model, Qwen2Model)
|
self.assertIsInstance(model, Qwen2Model)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user