llms: convert() add 'revision' argument (#506)

* llms: convert() add 'revision' argument

* Update README.md

* Update utils.py

* Update README.md

* Update llms/mlx_lm/utils.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Miller Liang
2024-03-02 22:28:26 +08:00
committed by GitHub
parent a429263905
commit 5b1043a458

View File

@@ -57,13 +57,14 @@ def _get_classes(config: dict):
return arch.Model, arch.ModelArgs return arch.Model, arch.ModelArgs
def get_model_path(path_or_hf_repo: str) -> Path: def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
""" """
Ensures the model is available locally. If the path does not exist locally, Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub. it is downloaded from the Hugging Face Hub.
Args: Args:
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
Returns: Returns:
Path: The path to the model. Path: The path to the model.
@@ -73,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
model_path = Path( model_path = Path(
snapshot_download( snapshot_download(
repo_id=path_or_hf_repo, repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[ allow_patterns=[
"*.json", "*.json",
"*.safetensors", "*.safetensors",
@@ -556,9 +558,10 @@ def convert(
q_bits: int = 4, q_bits: int = 4,
dtype: str = "float16", dtype: str = "float16",
upload_repo: str = None, upload_repo: str = None,
revision: Optional[str] = None,
): ):
print("[INFO] Loading") print("[INFO] Loading")
model_path = get_model_path(hf_path) model_path = get_model_path(hf_path, revision=revision)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True) model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters())) weights = dict(tree_flatten(model.parameters()))