mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add -local flag to llms/hf_llm/convert.py for reading source HF models from filesystem. (#260)
* * Add --local flag for reading models from filesystem and related code for doing so * Disable uploading to huggingface if --local flag is set * Remove code related to .bin files and merge fetch_from_local and fetch_from_hub into one function. * Update llms/hf_llm/convert.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * format / nits --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
80d18671ad
commit
047d4650c4
@ -14,11 +14,13 @@ from mlx.utils import tree_flatten
|
|||||||
from models import Model, ModelArgs
|
from models import Model, ModelArgs
|
||||||
|
|
||||||
|
|
||||||
def fetch_from_hub(hf_path: str):
|
def fetch_from_hub(model_path: str, local: bool):
|
||||||
|
if not local:
|
||||||
model_path = snapshot_download(
|
model_path = snapshot_download(
|
||||||
repo_id=hf_path,
|
repo_id=model_path,
|
||||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||||
)
|
)
|
||||||
|
|
||||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||||
if len(weight_files) == 0:
|
if len(weight_files) == 0:
|
||||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||||
@ -149,11 +151,18 @@ if __name__ == "__main__":
|
|||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-l",
|
||||||
|
"--local",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether the hf-path points to a local filesystem.",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print("[INFO] Loading")
|
print("[INFO] Loading")
|
||||||
weights, config, tokenizer = fetch_from_hub(args.hf_path)
|
weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
|
||||||
|
|
||||||
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
|
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
|
||||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||||
@ -170,5 +179,5 @@ if __name__ == "__main__":
|
|||||||
with open(mlx_path / "config.json", "w") as fid:
|
with open(mlx_path / "config.json", "w") as fid:
|
||||||
json.dump(config, fid, indent=4)
|
json.dump(config, fid, indent=4)
|
||||||
|
|
||||||
if args.upload_name is not None:
|
if args.upload_name is not None and not args.local:
|
||||||
upload_to_hub(mlx_path, args.upload_name, args.hf_path)
|
upload_to_hub(mlx_path, args.upload_name, args.hf_path)
|
||||||
|
Loading…
Reference in New Issue
Block a user