From 2cd793dd6961487f78982ecff8843ee06d7b4f9a Mon Sep 17 00:00:00 2001 From: Sugato Ray Date: Thu, 14 Mar 2024 09:36:05 -0400 Subject: [PATCH] feat: add update_config functionality (#531) * feat: add `update_config` functionality - sorts the config for better readability - updates "_name_or_path" key in config with upload_repo - sets indentation of 4 spaces - allows adding other key-value pairs via kwargs - reduces code duplication - standardizes config-update across mlx-lm * feat: standardize updating config Impacts: - fuse.py - merge.py * update formatting * remove commented out code * update func: update_config to save_config - drop kwargs - rename func as save_config - incorporate review suggestions * update func: save_config - ensure only config-saving functionality - function does not return config as a dict anymore - added review suggestions * fixed formatting * update formatting instruction in contribution guide * nits --------- Co-authored-by: Awni Hannun --- CONTRIBUTING.md | 12 +++++++++++- llms/mlx_lm/fuse.py | 11 ++++++++--- llms/mlx_lm/merge.py | 11 ++++++++--- llms/mlx_lm/utils.py | 26 ++++++++++++++++++++++++-- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 233e4ea9..d2ea3932 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,7 +12,7 @@ possible. This should install hooks for running `black` and `clang-format` to ensure consistent style for C++ and python code. - You can also run the formatters manually as follows: + You can also run the formatters manually as follows on individual files: ```bash clang-format -i file.cpp @@ -21,6 +21,16 @@ possible. ```bash black file.py ``` + + or, + + ```bash + # single file + pre-commit run --files file1.py + + # specific files + pre-commit run --files file1.py file2.py + ``` or run `pre-commit run --all-files` to check all files in the repo. 
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py index 132982d3..c10b09b2 100644 --- a/llms/mlx_lm/fuse.py +++ b/llms/mlx_lm/fuse.py @@ -9,7 +9,13 @@ from mlx.utils import tree_flatten, tree_unflatten from .tuner.lora import LoRALinear from .tuner.utils import apply_lora_layers, dequantize -from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub +from .utils import ( + fetch_from_hub, + get_model_path, + save_config, + save_weights, + upload_to_hub, +) def parse_arguments() -> argparse.Namespace: @@ -87,8 +93,7 @@ def main() -> None: if args.de_quantize: config.pop("quantization", None) - with open(save_path / "config.json", "w") as fid: - json.dump(config, fid, indent=4) + save_config(config, config_path=save_path / "config.json") if args.upload_repo is not None: hf_path = args.hf_path or ( diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py index affd034c..c1abdb8a 100644 --- a/llms/mlx_lm/merge.py +++ b/llms/mlx_lm/merge.py @@ -13,7 +13,13 @@ import numpy as np import yaml from mlx.utils import tree_flatten, tree_map -from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub +from .utils import ( + fetch_from_hub, + get_model_path, + save_config, + save_weights, + upload_to_hub, +) def configure_parser() -> argparse.ArgumentParser: @@ -151,8 +157,7 @@ def merge( tokenizer.save_pretrained(mlx_path) - with open(mlx_path / "config.json", "w") as fid: - json.dump(base_config, fid, indent=4) + save_config(base_config, config_path=mlx_path / "config.json") if upload_repo is not None: upload_to_hub(mlx_path, upload_repo, base_hf_path) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index bfbe911b..137c2ddd 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -553,6 +553,29 @@ def quantize_model( return quantized_weights, quantized_config +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. 
+ + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. + """ + # Clean unused keys + config.pop("_name_or_path", None) + + # sort the config for better readability + config = dict(sorted(config.items())) + + # write the updated config to the config_path + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) + + def convert( hf_path: str, mlx_path: str = "mlx_model", @@ -588,8 +611,7 @@ def convert( tokenizer.save_pretrained(mlx_path) - with open(mlx_path / "config.json", "w") as fid: - json.dump(config, fid, indent=4) + save_config(config, config_path=mlx_path / "config.json") if upload_repo is not None: upload_to_hub(mlx_path, upload_repo, hf_path)