mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-11 03:36:42 +08:00
llms: convert() add 'revision' argument (#506)
* llms: convert() add 'revision' argument * Update README.md * Update utils.py * Update README.md * Update llms/mlx_lm/utils.py --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
a429263905
commit
5b1043a458
@ -57,13 +57,14 @@ def _get_classes(config: dict):
|
|||||||
return arch.Model, arch.ModelArgs
|
return arch.Model, arch.ModelArgs
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(path_or_hf_repo: str) -> Path:
|
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Ensures the model is available locally. If the path does not exist locally,
|
Ensures the model is available locally. If the path does not exist locally,
|
||||||
it is downloaded from the Hugging Face Hub.
|
it is downloaded from the Hugging Face Hub.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
|
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
|
||||||
|
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path: The path to the model.
|
Path: The path to the model.
|
||||||
@ -73,6 +74,7 @@ def get_model_path(path_or_hf_repo: str) -> Path:
|
|||||||
model_path = Path(
|
model_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=path_or_hf_repo,
|
repo_id=path_or_hf_repo,
|
||||||
|
revision=revision,
|
||||||
allow_patterns=[
|
allow_patterns=[
|
||||||
"*.json",
|
"*.json",
|
||||||
"*.safetensors",
|
"*.safetensors",
|
||||||
@ -556,9 +558,10 @@ def convert(
|
|||||||
q_bits: int = 4,
|
q_bits: int = 4,
|
||||||
dtype: str = "float16",
|
dtype: str = "float16",
|
||||||
upload_repo: str = None,
|
upload_repo: str = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
):
|
):
|
||||||
print("[INFO] Loading")
|
print("[INFO] Loading")
|
||||||
model_path = get_model_path(hf_path)
|
model_path = get_model_path(hf_path, revision=revision)
|
||||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||||
|
|
||||||
weights = dict(tree_flatten(model.parameters()))
|
weights = dict(tree_flatten(model.parameters()))
|
||||||
|
Loading…
Reference in New Issue
Block a user