mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:08:37 +08:00
Utilize a specific model version from HuggingFace during LoRA fine-tuning
This commit is contained in:
parent
5cae0a60e6
commit
28eebfe5bf
@ -22,6 +22,13 @@ def build_parser():
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
default="main",
|
||||
help="Specify the version of the model to use. This can be a branch name, tag, or commit hash. Defaults to 'main'.",
|
||||
)
|
||||
|
||||
# Generation args
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
@ -333,7 +340,7 @@ if __name__ == "__main__":
|
||||
tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
|
||||
|
||||
print("Loading pretrained model")
|
||||
model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
|
||||
model, tokenizer, _ = lora_utils.load(args.model, args.revision, tokenizer_config)
|
||||
# Freeze all layers other than LORA linears
|
||||
model.freeze()
|
||||
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
|
||||
|
@ -13,9 +13,10 @@ import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def fetch_from_hub(hf_path: str):
|
||||
def fetch_from_hub(hf_path: str, revision: str = "main"):
|
||||
model_path = snapshot_download(
|
||||
repo_id=hf_path,
|
||||
revision=revision,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||
@ -122,7 +123,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
|
||||
)
|
||||
|
||||
|
||||
def load(path_or_hf_repo: str, tokenizer_config={}):
|
||||
def load(path_or_hf_repo: str, revision: str = "main", tokenizer_config={}):
|
||||
# If the path exists, it will try to load the model from it
|
||||
# otherwise download and cache from the hf_repo and cache
|
||||
model_path = Path(path_or_hf_repo)
|
||||
@ -130,6 +131,7 @@ def load(path_or_hf_repo: str, tokenizer_config={}):
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user