Merge dd1690df81 into e8afb59de4
Commit 651f9a5cf8: add a --revision argument so a specific Hugging Face Hub commit, branch, or tag can be checked out when loading the model for LoRA training.
@@ -73,6 +73,13 @@ def build_parser():
         help="The path to the local model directory or Hugging Face repo.",
     )
 
+    parser.add_argument(
+        "--revision",
+        default="main",
+        type=str,
+        help="Hash value of the commit to checkout from the Hugging Face repo.",
+    )
+
     # Training args
     parser.add_argument(
         "--train",
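For reference, a minimal standalone sketch (not the project's parser) of how the new flag behaves under standard argparse semantics; the sample commit hash is a placeholder:

import argparse

# The --revision option added above defaults to "main" and accepts any
# branch, tag, or commit hash understood by the Hugging Face Hub.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--revision",
    default="main",
    type=str,
    help="Hash value of the commit to checkout from the Hugging Face repo.",
)

print(parser.parse_args([]).revision)                         # -> main
print(parser.parse_args(["--revision", "abc1234"]).revision)  # -> abc1234 (placeholder)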
@@ -252,7 +259,7 @@ def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
     print("Loading pretrained model")
-    model, tokenizer = load(args.model)
+    model, tokenizer = load(args.model, args.revision)
 
     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
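One detail worth flagging: load(args.model, args.revision) passes the revision positionally, while the load() signature extended at the end of this diff adds revision after lazy, so the second positional slot is tokenizer_config. A self-contained sketch with a stand-in function (not the real load) illustrating why a keyword argument is presumably intended:

from typing import Optional


def load_sketch(
    path_or_hf_repo: str,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
    revision: Optional[str] = "main",
):
    """Stand-in with the same parameter order as load() after this diff."""
    return tokenizer_config, revision


# Positional call as in the hunk above: the revision string lands in
# tokenizer_config and revision keeps its default.
print(load_sketch("some/repo", "abc1234"))           # -> ('abc1234', 'main')
# Keyword call routes the value to the revision parameter.
print(load_sketch("some/repo", revision="abc1234"))  # -> ({}, 'abc1234')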
@@ -152,7 +152,7 @@ def compute_bits_per_weight(model):
     return model_bytes * 8 / model_params
 
 
-def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+def get_model_path(path_or_hf_repo: str, revision: Optional[str] = "main") -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
     it is downloaded from the Hugging Face Hub.
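For context, get_model_path resolves a local directory or falls back to downloading from the Hub. A simplified sketch, assuming the usual huggingface_hub API and not the project's exact code, of how a revision string is typically forwarded to snapshot_download (which accepts a branch, tag, or commit hash); the helper name fetch_model is hypothetical:

from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download


def fetch_model(path_or_hf_repo: str, revision: Optional[str] = "main") -> Path:
    """Simplified stand-in for the download behavior described above."""
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        # The revision may be a branch name, tag, or commit hash on the Hub.
        model_path = Path(
            snapshot_download(repo_id=path_or_hf_repo, revision=revision)
        )
    return model_path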
@@ -184,7 +184,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
             )
         except:
             raise ModelNotFoundError(
-                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
+                f"Model not found for path or HF repo: {path_or_hf_repo}:{revision}.\n"
                 "Please make sure you specified the local path or Hugging Face"
                 " repo id correctly.\nIf you are trying to access a private or"
                 " gated Hugging Face repo, make sure you are authenticated:\n"
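Purely as an illustration, the reworded message renders like this for a placeholder repo id and the default revision:

# Illustration only; the repo id is a placeholder.
path_or_hf_repo = "some-org/some-model"
revision = "main"
print(f"Model not found for path or HF repo: {path_or_hf_repo}:{revision}.")
# -> Model not found for path or HF repo: some-org/some-model:main.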
@@ -709,6 +709,7 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
+    revision: Optional[str] = "main",
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
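Assuming this change is applied, callers can pin a specific Hub revision when loading a model. A hedged usage sketch; the repo id is only an example and the revision value is a placeholder:

from mlx_lm import load, generate

# The revision keyword below exists only with this diff applied.
model, tokenizer = load(
    "mlx-community/Mistral-7B-Instruct-v0.2-4bit",  # example repo id
    revision="main",                                # or a specific commit hash / tag
)
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))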