Add --local flag to llms/hf_llm/convert.py for reading source HF models from filesystem. (#260)

* Add --local flag for reading models from filesystem and related code for doing so
* Disable uploading to huggingface if --local flag is set

* Remove code related to .bin files and merge fetch_from_local and fetch_from_hub into one function.

* Update llms/hf_llm/convert.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* format / nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Konstantin Kerekovski 2024-01-10 22:53:01 -05:00 committed by GitHub
parent 80d18671ad
commit 047d4650c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,11 +14,13 @@ from mlx.utils import tree_flatten
from models import Model, ModelArgs
def fetch_from_hub(hf_path: str):
def fetch_from_hub(model_path: str, local: bool):
if not local:
model_path = snapshot_download(
repo_id=hf_path,
repo_id=model_path,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
weight_files = glob.glob(f"{model_path}/*.safetensors")
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
@ -149,11 +151,18 @@ if __name__ == "__main__":
type=str,
default=None,
)
parser.add_argument(
"-l",
"--local",
action="store_true",
help="Whether the hf-path points to a local filesystem.",
default=False,
)
args = parser.parse_args()
print("[INFO] Loading")
weights, config, tokenizer = fetch_from_hub(args.hf_path)
weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
@ -170,5 +179,5 @@ if __name__ == "__main__":
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
if args.upload_name is not None:
if args.upload_name is not None and not args.local:
upload_to_hub(mlx_path, args.upload_name, args.hf_path)