From 5b1043a4584ce2e547ec3b97453b0923764b33e2 Mon Sep 17 00:00:00 2001 From: Miller Liang <841985944@qq.com> Date: Sat, 2 Mar 2024 22:28:26 +0800 Subject: [PATCH] llms: convert() add 'revision' argument (#506) * llms: convert() add 'revision' argument * Update README.md * Update utils.py * Update README.md * Update llms/mlx_lm/utils.py --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index f4e0658d..90327436 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -57,13 +57,14 @@ def _get_classes(config: dict): return arch.Model, arch.ModelArgs -def get_model_path(path_or_hf_repo: str) -> Path: +def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: """ Ensures the model is available locally. If the path does not exist locally, it is downloaded from the Hugging Face Hub. Args: path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. + revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. Returns: Path: The path to the model. @@ -73,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path: model_path = Path( snapshot_download( repo_id=path_or_hf_repo, + revision=revision, allow_patterns=[ "*.json", "*.safetensors", @@ -556,9 +558,10 @@ def convert( q_bits: int = 4, dtype: str = "float16", upload_repo: str = None, + revision: Optional[str] = None, ): print("[INFO] Loading") - model_path = get_model_path(hf_path) + model_path = get_model_path(hf_path, revision=revision) model, config, tokenizer = fetch_from_hub(model_path, lazy=True) weights = dict(tree_flatten(model.parameters()))