llms: convert() add 'revision' argument (#506)

* llms: convert() add 'revision' argument

* Update README.md

* Update utils.py

* Update README.md

* Update llms/mlx_lm/utils.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Miller Liang 2024-03-02 22:28:26 +08:00 committed by GitHub
parent a429263905
commit 5b1043a458
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -57,13 +57,14 @@ def _get_classes(config: dict):
     return arch.Model, arch.ModelArgs


-def get_model_path(path_or_hf_repo: str) -> Path:
+def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
     it is downloaded from the Hugging Face Hub.

     Args:
         path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

     Returns:
         Path: The path to the model.
@@ -73,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
         model_path = Path(
             snapshot_download(
                 repo_id=path_or_hf_repo,
+                revision=revision,
                 allow_patterns=[
                     "*.json",
                     "*.safetensors",
@@ -556,9 +558,10 @@ def convert(
     q_bits: int = 4,
     dtype: str = "float16",
     upload_repo: str = None,
+    revision: Optional[str] = None,
 ):
     print("[INFO] Loading")
-    model_path = get_model_path(hf_path)
+    model_path = get_model_path(hf_path, revision=revision)
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
     weights = dict(tree_flatten(model.parameters()))