From b468091f7ff0b289352f773b9793c5613437465d Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Fri, 3 May 2024 21:20:13 +0200 Subject: [PATCH] Add model management functionality for local caches (#736) * Add model management functionality for local caches This commit introduces a set of command-line utilities for managing MLX models downloaded and saved locally in Hugging Face cache. The functionalities include scanning existing models, retrieving detailed information about a specific model, and deleting a model by its name. * Added mlx_lm.model to setup.py * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/MANAGE.md | 22 ++++++++ llms/mlx_lm/manage.py | 121 ++++++++++++++++++++++++++++++++++++++++++ llms/setup.py | 1 + 3 files changed, 144 insertions(+) create mode 100644 llms/mlx_lm/MANAGE.md create mode 100644 llms/mlx_lm/manage.py diff --git a/llms/mlx_lm/MANAGE.md b/llms/mlx_lm/MANAGE.md new file mode 100644 index 00000000..00858a0a --- /dev/null +++ b/llms/mlx_lm/MANAGE.md @@ -0,0 +1,22 @@ +# Managing Models + +You can use `mlx-lm` to manage models downloaded locally in your machine. They +are stored in the Hugging Face cache. + +Scan models: + +```shell +mlx_lm.manage --scan +``` + +Specify a `--pattern` to get info on a single or specific set of models: + +```shell +mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit +``` + +To delete a model (or multiple models): + +```shell +mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit +``` diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py new file mode 100644 index 00000000..bb5c3a09 --- /dev/null +++ b/llms/mlx_lm/manage.py @@ -0,0 +1,121 @@ +import argparse +from typing import List, Union + +from huggingface_hub import scan_cache_dir +from transformers.commands.user import tabulate + + +def ask_for_confirmation(message: str) -> bool: + y = ("y", "yes", "1") + n = ("n", "no", "0") + all_values = y + n + ("",) + full_message = f"{message} (Y/n) " + while True: + answer = input(full_message).lower() + if answer == "": + return False + if answer in y: + return True + if answer in n: + return False + print(f"Invalid input. Must be one of {all_values}") + + +def main(): + parser = argparse.ArgumentParser(description="MLX Model Cache.") + parser.add_argument( + "--scan", + action="store_true", + help="Scan Hugging Face cache for mlx models.", + ) + parser.add_argument( + "--delete", + action="store_true", + help="Delete models matching the given pattern.", + ) + parser.add_argument( + "--pattern", + type=str, + help="Model repos contain the pattern.", + default="mlx", + ) + + args = parser.parse_args() + + if args.scan: + print( + "Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".' + ) + hf_cache_info = scan_cache_dir() + print( + tabulate( + rows=[ + [ + repo.repo_id, + repo.repo_type, + "{:>12}".format(repo.size_on_disk_str), + repo.nb_files, + repo.last_accessed_str, + repo.last_modified_str, + str(repo.repo_path), + ] + for repo in sorted( + hf_cache_info.repos, key=lambda repo: repo.repo_path + ) + if args.pattern in repo.repo_id + ], + headers=[ + "REPO ID", + "REPO TYPE", + "SIZE ON DISK", + "NB FILES", + "LAST_ACCESSED", + "LAST_MODIFIED", + "LOCAL PATH", + ], + ) + ) + + if args.delete: + print(f'Deleting models matching pattern "{args.pattern}"') + hf_cache_info = scan_cache_dir() + + repos = [ + repo + for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) + if args.pattern in repo.repo_id + ] + if repos: + print( + tabulate( + rows=[ + [ + repo.repo_id, + str(repo.repo_path), + ] + for repo in repos + ], + headers=[ + "REPO ID", + "LOCAL PATH", + ], + ) + ) + + confirmed = ask_for_confirmation(f"Confirm deletion ?") + if confirmed: + for model_info in repos: + for revision in sorted( + model_info.revisions, key=lambda revision: revision.commit_hash + ): + strategy = hf_cache_info.delete_revisions(revision.commit_hash) + strategy.execute() + print("Model(s) deleted.") + else: + print("Deletion is cancelled. Do nothing.") + else: + print(f"No models found.") + + +if __name__ == "__main__": + main() diff --git a/llms/setup.py b/llms/setup.py index c4e5d075..648e1e04 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -34,6 +34,7 @@ setup( "mlx_lm.lora = mlx_lm.lora:main", "mlx_lm.merge = mlx_lm.merge:main", "mlx_lm.server = mlx_lm.server:main", + "mlx_lm.manage = mlx_lm.manage:main", ] }, )