diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index f4e0658d..90327436 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -57,13 +57,14 @@ def _get_classes(config: dict): return arch.Model, arch.ModelArgs -def get_model_path(path_or_hf_repo: str) -> Path: +def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: """ Ensures the model is available locally. If the path does not exist locally, it is downloaded from the Hugging Face Hub. Args: path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. + revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. Returns: Path: The path to the model. @@ -73,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path: model_path = Path( snapshot_download( repo_id=path_or_hf_repo, + revision=revision, allow_patterns=[ "*.json", "*.safetensors", @@ -556,9 +558,10 @@ def convert( q_bits: int = 4, dtype: str = "float16", upload_repo: str = None, + revision: Optional[str] = None, ): print("[INFO] Loading") - model_path = get_model_path(hf_path) + model_path = get_model_path(hf_path, revision=revision) model, config, tokenizer = fetch_from_hub(model_path, lazy=True) weights = dict(tree_flatten(model.parameters()))