mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Utilize a specific model version from HuggingFace during LoRA fine-tuning
This commit is contained in:
parent
5cae0a60e6
commit
28eebfe5bf
@ -22,6 +22,13 @@ def build_parser():
|
|||||||
default="mlx_model",
|
default="mlx_model",
|
||||||
help="The path to the local model directory or Hugging Face repo.",
|
help="The path to the local model directory or Hugging Face repo.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--revision",
|
||||||
|
default="main",
|
||||||
|
help="Specify the version of the model to use. This can be a branch name, tag, or commit hash. Defaults to 'main'.",
|
||||||
|
)
|
||||||
|
|
||||||
# Generation args
|
# Generation args
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens",
|
"--max-tokens",
|
||||||
@ -333,7 +340,7 @@ if __name__ == "__main__":
|
|||||||
tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
|
tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
|
||||||
|
|
||||||
print("Loading pretrained model")
|
print("Loading pretrained model")
|
||||||
model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
|
model, tokenizer, _ = lora_utils.load(args.model, args.revision, tokenizer_config)
|
||||||
# Freeze all layers other than LORA linears
|
# Freeze all layers other than LORA linears
|
||||||
model.freeze()
|
model.freeze()
|
||||||
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
|
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
|
||||||
|
@ -13,9 +13,10 @@ import transformers
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
def fetch_from_hub(hf_path: str):
|
def fetch_from_hub(hf_path: str, revision: str = "main"):
|
||||||
model_path = snapshot_download(
|
model_path = snapshot_download(
|
||||||
repo_id=hf_path,
|
repo_id=hf_path,
|
||||||
|
revision=revision,
|
||||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||||
)
|
)
|
||||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||||
@ -122,7 +123,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(path_or_hf_repo: str, tokenizer_config={}):
|
def load(path_or_hf_repo: str, revision: str = "main", tokenizer_config={}):
|
||||||
# If the path exists, it will try to load model form it
|
# If the path exists, it will try to load model form it
|
||||||
# otherwise download and cache from the hf_repo and cache
|
# otherwise download and cache from the hf_repo and cache
|
||||||
model_path = Path(path_or_hf_repo)
|
model_path = Path(path_or_hf_repo)
|
||||||
@ -130,6 +131,7 @@ def load(path_or_hf_repo: str, tokenizer_config={}):
|
|||||||
model_path = Path(
|
model_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=path_or_hf_repo,
|
repo_id=path_or_hf_repo,
|
||||||
|
revision=revision,
|
||||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user