llms: convert() add 'revision' argument (#506)

* llms: convert() add 'revision' argument

* Update README.md

* Update utils.py

* Update README.md

* Update llms/mlx_lm/utils.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Miller Liang 2024-03-02 22:28:26 +08:00 committed by GitHub
parent a429263905
commit 5b1043a458
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -57,13 +57,14 @@ def _get_classes(config: dict):
return arch.Model, arch.ModelArgs
def get_model_path(path_or_hf_repo: str) -> Path:
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
Args:
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
Returns:
Path: The path to the model.
@ -73,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
@ -556,9 +558,10 @@ def convert(
q_bits: int = 4,
dtype: str = "float16",
upload_repo: str = None,
revision: Optional[str] = None,
):
print("[INFO] Loading")
model_path = get_model_path(hf_path)
model_path = get_model_path(hf_path, revision=revision)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters()))