Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-20 10:20:46 +08:00)
revert revision changes and retain qwen2 support

commit ec06c04f4f
parent ba6c7d3aba
@@ -73,13 +73,6 @@ def build_parser():
         help="The path to the local model directory or Hugging Face repo.",
     )
-
-    parser.add_argument(
-        "--revision",
-        default="main",
-        type=str,
-        help="Hash value of the commit to checkout from the Hugging Face repo.",
-    )

     # Training args
     parser.add_argument(
         "--train",
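
With the "--revision" flag reverted out of build_parser(), a specific Hub commit can no longer be selected on the command line. A minimal workaround sketch, assuming the standard huggingface_hub API and a hypothetical repo id: resolve the revision to a local directory first, then pass that directory as the model path.

# Sketch only (not part of this commit); repo id and revision are placeholders.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="my-org/my-qwen2-model",   # hypothetical repo id
    revision="0123abcd",               # branch, tag, or commit hash to pin
)
# Pass local_dir to the trainer, e.g. --model <local_dir>.
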
@@ -259,7 +252,7 @@ def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)

     print("Loading pretrained model")
-    model, tokenizer = load(args.model, args.revision)
+    model, tokenizer = load(args.model)

     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
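
After this hunk, run() calls load() with only the model path, matching the reverted signature further below. A hedged usage sketch; the path is a placeholder and the public re-export from mlx_lm is assumed.

# Usage sketch only; "path/to/model" is a placeholder.
from mlx_lm import load

model, tokenizer = load("path/to/model")   # no revision argument after this revert
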
@@ -303,4 +296,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()
@@ -153,7 +153,7 @@ def compute_bits_per_weight(model):
     return model_bytes * 8 / model_params


-def get_model_path(path_or_hf_repo: str, revision: Optional[str] = "main") -> Path:
+def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
     it is downloaded from the Hugging Face Hub.
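
The default revision reverts from "main" to None here; with huggingface_hub, revision=None resolves to the repo's default branch. A simplified sketch of the resolution this signature implies, not the file's exact body:

# Illustrative sketch of the local-path-or-download behaviour described in the docstring.
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download


def get_model_path_sketch(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        # revision=None lets huggingface_hub pick the repo's default branch.
        model_path = Path(snapshot_download(path_or_hf_repo, revision=revision))
    return model_path
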
@@ -185,7 +185,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = "main") -> Pa
         )
     except:
         raise ModelNotFoundError(
-            f"Model not found for path or HF repo: {path_or_hf_repo}:{revision}.\n"
+            f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
             "Please make sure you specified the local path or Hugging Face"
             " repo id correctly.\nIf you are trying to access a private or"
             " gated Hugging Face repo, make sure you are authenticated:\n"
@@ -710,7 +710,6 @@ def load(
     model_config={},
     adapter_path: Optional[str] = None,
     lazy: bool = False,
-    revision: Optional[str] = "main",
 ) -> Tuple[nn.Module, TokenizerWrapper]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
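
With the revision parameter removed, load() takes a path or repo id plus the remaining keyword arguments. A hedged example of the reverted signature in use; both paths are placeholders:

# Example sketch; paths are placeholders.
from mlx_lm import load

model, tokenizer = load(
    "path/to/model",
    adapter_path="path/to/adapters",   # optional fine-tuned adapter weights
    lazy=True,                         # defer evaluating the weights until first use
)
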
@@ -1028,4 +1027,4 @@ def convert(
     save_config(config, config_path=mlx_path / "config.json")

     if upload_repo is not None:
-        upload_to_hub(mlx_path, upload_repo, hf_path)
+        upload_to_hub(mlx_path, upload_repo, hf_path)
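
upload_to_hub() runs only when upload_repo is provided. A hedged sketch of the convert-and-upload flow this hunk touches, assuming the convert() shown above is importable from mlx_lm.utils and using placeholder repo ids:

# Sketch only; repo ids are placeholders and the import path is an assumption.
from mlx_lm.utils import convert

convert(
    hf_path="my-org/my-qwen2-model",    # hypothetical source Hugging Face repo
    mlx_path="mlx_model",               # local output directory for converted weights
    quantize=True,
    upload_repo="my-org/my-qwen2-mlx",  # when not None, triggers upload_to_hub()
)
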