Improvements to mlx_lm.manage (#1178)

* improvements to manage. Default value is N and size added to deletion confirmation.

* Fixing case for no case

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Ivan Fioravanti 2025-01-01 16:25:57 +01:00 committed by GitHub
parent d4ef909d4a
commit 3a58c36109
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,19 +6,18 @@ from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool: def ask_for_confirmation(message: str) -> bool:
"""Ask user for confirmation with Y/N prompt.
Returns True for Y/yes, False for N/no/empty."""
y = ("y", "yes", "1") y = ("y", "yes", "1")
n = ("n", "no", "0") n = ("n", "no", "0", "")
all_values = y + n + ("",) full_message = f"{message} (y/n) "
full_message = f"{message} (Y/n) "
while True: while True:
answer = input(full_message).lower() answer = input(full_message).lower()
if answer == "":
return False
if answer in y: if answer in y:
return True return True
if answer in n: if answer in n:
return False return False
print(f"Invalid input. Must be one of {all_values}") print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
def main(): def main():
@ -43,9 +42,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if args.scan: if args.scan:
print( print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
)
hf_cache_info = scan_cache_dir() hf_cache_info = scan_cache_dir()
print( print(
tabulate( tabulate(
@ -86,35 +83,41 @@ def main():
if args.pattern in repo.repo_id if args.pattern in repo.repo_id
] ]
if repos: if repos:
print("\nFound the following models:")
print( print(
tabulate( tabulate(
rows=[ rows=[
[ [
repo.repo_id, repo.repo_id,
repo.size_on_disk_str, # Added size information
str(repo.repo_path), str(repo.repo_path),
] ]
for repo in repos for repo in repos
], ],
headers=[ headers=[
"REPO ID", "REPO ID",
"SIZE", # Added size header
"LOCAL PATH", "LOCAL PATH",
], ],
) )
) )
confirmed = ask_for_confirmation(f"Confirm deletion ?") confirmed = ask_for_confirmation(
"\nAre you sure you want to delete these models?"
)
if confirmed: if confirmed:
for model_info in repos: for model_info in repos:
print(f"\nDeleting {model_info.repo_id}...")
for revision in sorted( for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash model_info.revisions, key=lambda revision: revision.commit_hash
): ):
strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute() strategy.execute()
print("Model(s) deleted.") print("\nModel(s) deleted successfully.")
else: else:
print("Deletion is cancelled. Do nothing.") print("\nDeletion cancelled - no changes made.")
else: else:
print(f"No models found.") print(f'No models found matching pattern "{args.pattern}"')
if __name__ == "__main__": if __name__ == "__main__":