diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 233e4ea9..d2ea3932 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -12,7 +12,7 @@ possible.
    This should install hooks for running `black` and `clang-format` to ensure
    consistent style for C++ and python code.
 
-   You can also run the formatters manually as follows:
+   You can also run the formatters manually on individual files as follows:
 
    ```bash
    clang-format -i file.cpp
@@ -21,6 +21,16 @@ possible.
    ```bash
    black file.py
    ```
+
+   or, using `pre-commit`:
+
+   ```bash
+   # single file
+   pre-commit run --files file1.py
+
+   # multiple files
+   pre-commit run --files file1.py file2.py
+   ```
 
    or run `pre-commit run --all-files` to check all files in the repo.
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index 132982d3..c10b09b2 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -9,7 +9,13 @@ from mlx.utils import tree_flatten, tree_unflatten
 
 from .tuner.lora import LoRALinear
 from .tuner.utils import apply_lora_layers, dequantize
-from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
+from .utils import (
+    fetch_from_hub,
+    get_model_path,
+    save_config,
+    save_weights,
+    upload_to_hub,
+)
 
 
 def parse_arguments() -> argparse.Namespace:
@@ -87,8 +93,7 @@ def main() -> None:
     if args.de_quantize:
         config.pop("quantization", None)
 
-    with open(save_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
+    save_config(config, config_path=save_path / "config.json")
 
     if args.upload_repo is not None:
         hf_path = args.hf_path or (
diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py
index affd034c..c1abdb8a 100644
--- a/llms/mlx_lm/merge.py
+++ b/llms/mlx_lm/merge.py
@@ -13,7 +13,13 @@ import numpy as np
 import yaml
 from mlx.utils import tree_flatten, tree_map
 
-from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
+from .utils import (
+    fetch_from_hub,
+    get_model_path,
+    save_config,
+    save_weights,
+    upload_to_hub,
+)
 
 
 def configure_parser() -> argparse.ArgumentParser:
@@ -151,8 +157,7 @@ def merge(
 
     tokenizer.save_pretrained(mlx_path)
 
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(base_config, fid, indent=4)
+    save_config(base_config, config_path=mlx_path / "config.json")
 
     if upload_repo is not None:
         upload_to_hub(mlx_path, upload_repo, base_hf_path)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index bfbe911b..137c2ddd 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -553,6 +553,29 @@ def quantize_model(
     return quantized_weights, quantized_config
 
 
+def save_config(
+    config: dict,
+    config_path: Union[str, Path],
+) -> None:
+    """Save the model configuration to ``config_path``.
+
+    The configuration is sorted by key before saving for better readability.
+
+    Args:
+        config (dict): The model configuration.
+        config_path (Union[str, Path]): Model configuration file path.
+    """
+    # Clean unused keys
+    config.pop("_name_or_path", None)
+
+    # Sort the config for better readability
+    config = dict(sorted(config.items()))
+
+    # Write the updated config to config_path
+    with open(config_path, "w") as fid:
+        json.dump(config, fid, indent=4)
+
+
 def convert(
     hf_path: str,
     mlx_path: str = "mlx_model",
@@ -588,8 +611,7 @@ def convert(
 
     tokenizer.save_pretrained(mlx_path)
 
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
+    save_config(config, config_path=mlx_path / "config.json")
 
     if upload_repo is not None:
         upload_to_hub(mlx_path, upload_repo, hf_path)
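For reference, a minimal sketch of how the new `save_config` helper behaves at the call sites above. The import path assumes the `llms/mlx_lm` package layout in this diff; the config contents are illustrative only, not taken from a real model:

```python
from pathlib import Path

from mlx_lm.utils import save_config

# Illustrative config dict; the real call sites pass the config
# fetched from the Hugging Face hub.
config = {
    "vocab_size": 32000,
    "model_type": "llama",
    "_name_or_path": "some/hf-repo",  # dropped by save_config before writing
}

save_config(config, config_path=Path("mlx_model") / "config.json")

# mlx_model/config.json now contains only the remaining keys, sorted
# alphabetically ("model_type" before "vocab_size") and indented with
# 4 spaces. Note that the pop of "_name_or_path" mutates the caller's
# dict, while the sorted copy stays local to save_config.
```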