Support Hugging Face models (#215)

* support hf direct models
This commit is contained in:
Awni Hannun
2024-01-03 15:13:26 -08:00
committed by GitHub
parent 1d09c4fecd
commit a5d6d0436c
16 changed files with 654 additions and 27 deletions

View File

@@ -71,7 +71,9 @@ def _download(url: str, root: str) -> str:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
@@ -132,7 +134,9 @@ def load_torch_model(
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).is_file():
raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
raise RuntimeError(
f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
)
with open(name_or_path, "rb") as fp:
checkpoint = torch.load(fp)
@@ -259,7 +263,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()
assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
assert (
args.dtype in _VALID_DTYPES
), f"dtype {args.dtype} not found in {_VALID_DTYPES}"
dtype = getattr(mx, args.dtype)
print("[INFO] Loading")