mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 18:36:38 +08:00
llms: convert() add 'revision' argument (#506)
* llms: convert() add 'revision' argument
* Update README.md
* Update utils.py
* Update README.md
* Update llms/mlx_lm/utils.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
a429263905
commit
5b1043a458
@@ -57,13 +57,14 @@ def _get_classes(config: dict):
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Ensures the model is available locally. If the path does not exist locally,
|
||||
it is downloaded from the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
|
||||
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
|
||||
|
||||
Returns:
|
||||
Path: The path to the model.
|
||||
@@ -73,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"*.json",
|
||||
"*.safetensors",
|
||||
@@ -556,9 +558,10 @@ def convert(
|
||||
q_bits: int = 4,
|
||||
dtype: str = "float16",
|
||||
upload_repo: str = None,
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
print("[INFO] Loading")
|
||||
model_path = get_model_path(hf_path)
|
||||
model_path = get_model_path(hf_path, revision=revision)
|
||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
|
Loading…
Reference in New Issue
Block a user