mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add -local flag to llms/hf_llm/convert.py for reading source HF models from filesystem. (#260)
* * Add --local flag for reading models from filesystem and related code for doing so * Disable uploading to huggingface if --local flag is set * Remove code related to .bin files and merge fetch_from_local and fetch_from_hub into one function. * Update llms/hf_llm/convert.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * format / nits --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
80d18671ad
commit
047d4650c4
@ -14,11 +14,13 @@ from mlx.utils import tree_flatten
|
||||
from models import Model, ModelArgs
|
||||
|
||||
|
||||
def fetch_from_hub(hf_path: str):
|
||||
def fetch_from_hub(model_path: str, local: bool):
|
||||
if not local:
|
||||
model_path = snapshot_download(
|
||||
repo_id=hf_path,
|
||||
repo_id=model_path,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
|
||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
@ -149,11 +151,18 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--local",
|
||||
action="store_true",
|
||||
help="Whether the hf-path points to a local filesystem.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("[INFO] Loading")
|
||||
weights, config, tokenizer = fetch_from_hub(args.hf_path)
|
||||
weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
|
||||
|
||||
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
@ -170,5 +179,5 @@ if __name__ == "__main__":
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
if args.upload_name is not None:
|
||||
if args.upload_name is not None and not args.local:
|
||||
upload_to_hub(mlx_path, args.upload_name, args.hf_path)
|
||||
|
Loading…
Reference in New Issue
Block a user