mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
@@ -71,7 +71,9 @@ def _download(url: str, root: str) -> str:
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
@@ -132,7 +134,9 @@ def load_torch_model(
|
||||
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
|
||||
name_or_path = _download(_MODELS[name_or_path], download_root)
|
||||
elif not Path(name_or_path).is_file():
|
||||
raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
|
||||
raise RuntimeError(
|
||||
f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
|
||||
)
|
||||
|
||||
with open(name_or_path, "rb") as fp:
|
||||
checkpoint = torch.load(fp)
|
||||
@@ -259,7 +263,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
|
||||
assert (
|
||||
args.dtype in _VALID_DTYPES
|
||||
), f"dtype {args.dtype} not found in {_VALID_DTYPES}"
|
||||
dtype = getattr(mx, args.dtype)
|
||||
|
||||
print("[INFO] Loading")
|
||||
|
Reference in New Issue
Block a user