From 3a58c361096e5be7a927e7719c5ef66bace9a8ab Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 1 Jan 2025 16:25:57 +0100 Subject: [PATCH] Improvements to mlx_lm.manage (#1178) * improvements to manage. Default value is N and size added to deletion confirmation. * Fixing case for no case * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/manage.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py index bb5c3a09..9827f3dc 100644 --- a/llms/mlx_lm/manage.py +++ b/llms/mlx_lm/manage.py @@ -6,19 +6,18 @@ from transformers.commands.user import tabulate def ask_for_confirmation(message: str) -> bool: + """Ask user for confirmation with Y/N prompt. + Returns True for Y/yes, False for N/no/empty.""" y = ("y", "yes", "1") - n = ("n", "no", "0") - all_values = y + n + ("",) - full_message = f"{message} (Y/n) " + n = ("n", "no", "0", "") + full_message = f"{message} (y/n) " while True: answer = input(full_message).lower() - if answer == "": - return False if answer in y: return True if answer in n: return False - print(f"Invalid input. Must be one of {all_values}") + print(f"Invalid input. Must be one of: yes/no/y/n or empty for no") def main(): @@ -43,9 +42,7 @@ def main(): args = parser.parse_args() if args.scan: - print( - "Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".' - ) + print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".') hf_cache_info = scan_cache_dir() print( tabulate( @@ -86,35 +83,41 @@ def main(): if args.pattern in repo.repo_id ] if repos: + print("\nFound the following models:") print( tabulate( rows=[ [ repo.repo_id, + repo.size_on_disk_str, # Added size information str(repo.repo_path), ] for repo in repos ], headers=[ "REPO ID", + "SIZE", # Added size header "LOCAL PATH", ], ) ) - confirmed = ask_for_confirmation(f"Confirm deletion ?") + confirmed = ask_for_confirmation( + "\nAre you sure you want to delete these models?" + ) if confirmed: for model_info in repos: + print(f"\nDeleting {model_info.repo_id}...") for revision in sorted( model_info.revisions, key=lambda revision: revision.commit_hash ): strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy.execute() - print("Model(s) deleted.") + print("\nModel(s) deleted successfully.") else: - print("Deletion is cancelled. Do nothing.") + print("\nDeletion cancelled - no changes made.") else: - print(f"No models found.") + print(f'No models found matching pattern "{args.pattern}"') if __name__ == "__main__":