Add --local flag to llms/hf_llm/convert.py for reading source HF models from filesystem. (#260)

* Add --local flag for reading models from filesystem and related code for doing so
* Disable uploading to huggingface if --local flag is set

* Remove code related to .bin files and merge fetch_from_local and fetch_from_hub into one function.

* Update llms/hf_llm/convert.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* format / nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Konstantin Kerekovski 2024-01-10 22:53:01 -05:00 committed by GitHub
parent 80d18671ad
commit 047d4650c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,11 +14,13 @@ from mlx.utils import tree_flatten
from models import Model, ModelArgs from models import Model, ModelArgs
def fetch_from_hub(hf_path: str): def fetch_from_hub(model_path: str, local: bool):
model_path = snapshot_download( if not local:
repo_id=hf_path, model_path = snapshot_download(
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], repo_id=model_path,
) allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
weight_files = glob.glob(f"{model_path}/*.safetensors") weight_files = glob.glob(f"{model_path}/*.safetensors")
if len(weight_files) == 0: if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path)) raise FileNotFoundError("No safetensors found in {}".format(model_path))
@ -149,11 +151,18 @@ if __name__ == "__main__":
type=str, type=str,
default=None, default=None,
) )
parser.add_argument(
"-l",
"--local",
action="store_true",
help="Whether the hf-path points to a local filesystem.",
default=False,
)
args = parser.parse_args() args = parser.parse_args()
print("[INFO] Loading") print("[INFO] Loading")
weights, config, tokenizer = fetch_from_hub(args.hf_path) weights, config, tokenizer = fetch_from_hub(args.hf_path, args.local)
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()} weights = {k: v.astype(dtype) for k, v in weights.items()}
@ -170,5 +179,5 @@ if __name__ == "__main__":
with open(mlx_path / "config.json", "w") as fid: with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4) json.dump(config, fid, indent=4)
if args.upload_name is not None: if args.upload_name is not None and not args.local:
upload_to_hub(mlx_path, args.upload_name, args.hf_path) upload_to_hub(mlx_path, args.upload_name, args.hf_path)