improvements to manage. Default value is N and size added to deletion confirmation.

This commit is contained in:
ivanfioravanti 2024-12-21 23:34:44 +01:00
parent d4ef909d4a
commit 70e4a6e662

View File

@ -6,10 +6,12 @@ from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool:
"""Ask user for confirmation with Y/N prompt.
Returns True for Y/yes, False for N/no/empty."""
y = ("y", "yes", "1")
n = ("n", "no", "0")
all_values = y + n + ("",)
full_message = f"{message} (Y/n) "
full_message = f"{message} (y/N) "
while True:
answer = input(full_message).lower()
if answer == "":
@ -18,7 +20,7 @@ def ask_for_confirmation(message: str) -> bool:
return True
if answer in n:
return False
print(f"Invalid input. Must be one of {all_values}")
print(f"Invalid input. Must be one of: yes/no/y/N or empty for no")
def main():
@ -44,7 +46,7 @@ def main():
if args.scan:
print(
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
f'Scanning Hugging Face cache for models with pattern "{args.pattern}".'
)
hf_cache_info = scan_cache_dir()
print(
@ -86,35 +88,39 @@ def main():
if args.pattern in repo.repo_id
]
if repos:
print("\nFound the following models:")
print(
tabulate(
rows=[
[
repo.repo_id,
repo.size_on_disk_str, # Added size information
str(repo.repo_path),
]
for repo in repos
],
headers=[
"REPO ID",
"SIZE", # Added size header
"LOCAL PATH",
],
)
)
confirmed = ask_for_confirmation(f"Confirm deletion ?")
confirmed = ask_for_confirmation("\nAre you sure you want to delete these models?")
if confirmed:
for model_info in repos:
print(f"\nDeleting {model_info.repo_id}...")
for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash
):
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute()
print("Model(s) deleted.")
print("\nModel(s) deleted successfully.")
else:
print("Deletion is cancelled. Do nothing.")
print("\nDeletion cancelled - no changes made.")
else:
print(f"No models found.")
print(f'No models found matching pattern "{args.pattern}"')
if __name__ == "__main__":