diff --git a/llms/hf_llm/convert.py b/llms/hf_llm/convert.py
index f93d01c3..4a899ab8 100644
--- a/llms/hf_llm/convert.py
+++ b/llms/hf_llm/convert.py
@@ -14,11 +14,13 @@ from mlx.utils import tree_flatten
 from models import Model, ModelArgs
 
 
-def fetch_from_hub(hf_path: str):
-    model_path = snapshot_download(
-        repo_id=hf_path,
-        allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
-    )
+def fetch_from_hub(model_path: str, local: bool):
+    if not local:
+        model_path = snapshot_download(
+            repo_id=model_path,
+            allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
+        )
+
     weight_files = glob.glob(f"{model_path}/*.safetensors")
     if len(weight_files) == 0:
         raise FileNotFoundError("No safetensors found in {}".format(model_path))
@@ -149,11 +151,18 @@ if __name__ == "__main__":
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "-l",
+        "--local",
+        action="store_true",
+        help="Whether hf-path points to a local filesystem path.",
+        default=False,
+    )
 
     args = parser.parse_args()
 
     print("[INFO] Loading")
-    weights, config, tokenizer = fetch_from_hub(args.hf_path)
+    weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
 
     dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
@@ -170,5 +179,5 @@ if __name__ == "__main__":
     with open(mlx_path / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
 
-    if args.upload_name is not None:
+    if args.upload_name is not None and not args.local:
         upload_to_hub(mlx_path, args.upload_name, args.hf_path)
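
Usage note (not part of the patch): a minimal sketch of the two call paths the new local flag enables. The repo id and the checkpoint directory below are hypothetical placeholders, and the sketch assumes convert.py is importable from the working directory.

    from convert import fetch_from_hub

    # local=False: the argument is treated as a Hugging Face repo id and
    # snapshot_download fetches the *.json, *.safetensors, and
    # tokenizer.model files, as before.
    weights, config, tokenizer = fetch_from_hub("some-org/some-model", local=False)

    # local=True: snapshot_download is skipped and the same files are read
    # directly from an existing local checkpoint directory.
    weights, config, tokenizer = fetch_from_hub("/path/to/checkpoint", local=True)

The equivalent command-line invocation would be along the lines of "python convert.py --hf-path /path/to/checkpoint --local". Note that with --local set, the final upload_to_hub step is skipped even if --upload-name is given, since a local path is not a valid repo id to upload against.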