From 70e4a6e6628a29b8852ce964899946328d6d7caa Mon Sep 17 00:00:00 2001 From: ivanfioravanti Date: Sat, 21 Dec 2024 23:34:44 +0100 Subject: [PATCH] improvements to manage. Default value is N and size added to deletion confirmation. --- llms/mlx_lm/manage.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py index bb5c3a09..dd741d41 100644 --- a/llms/mlx_lm/manage.py +++ b/llms/mlx_lm/manage.py @@ -6,10 +6,12 @@ from transformers.commands.user import tabulate def ask_for_confirmation(message: str) -> bool: + """Ask user for confirmation with Y/N prompt. + Returns True for Y/yes, False for N/no/empty.""" y = ("y", "yes", "1") n = ("n", "no", "0") all_values = y + n + ("",) - full_message = f"{message} (Y/n) " + full_message = f"{message} (y/N) " while True: answer = input(full_message).lower() if answer == "": @@ -18,7 +20,7 @@ def ask_for_confirmation(message: str) -> bool: return True if answer in n: return False - print(f"Invalid input. Must be one of {all_values}") + print(f"Invalid input. Must be one of: yes/no/y/N or empty for no") def main(): @@ -44,7 +46,7 @@ def main(): if args.scan: print( - "Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".' + f'Scanning Hugging Face cache for models with pattern "{args.pattern}".' ) hf_cache_info = scan_cache_dir() print( @@ -86,35 +88,39 @@ def main(): if args.pattern in repo.repo_id ] if repos: + print("\nFound the following models:") print( tabulate( rows=[ [ repo.repo_id, + repo.size_on_disk_str, # Added size information str(repo.repo_path), ] for repo in repos ], headers=[ "REPO ID", + "SIZE", # Added size header "LOCAL PATH", ], ) ) - confirmed = ask_for_confirmation(f"Confirm deletion ?") + confirmed = ask_for_confirmation("\nAre you sure you want to delete these models?") if confirmed: for model_info in repos: + print(f"\nDeleting {model_info.repo_id}...") for revision in sorted( model_info.revisions, key=lambda revision: revision.commit_hash ): strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy.execute() - print("Model(s) deleted.") + print("\nModel(s) deleted successfully.") else: - print("Deletion is cancelled. Do nothing.") + print("\nDeletion cancelled - no changes made.") else: - print(f"No models found.") + print(f'No models found matching pattern "{args.pattern}"') if __name__ == "__main__":