feat: add update_config functionality (#531)

* feat: add `update_config` functionality

- sorts the config for better readability
- updates "_name_or_path" key in config with upload_repo
- sets indentation of 4 spaces
- allows adding other key-value pairs via kwargs
- reduces code duplication
- standardizes config-update across mlx-lm

* feat: standardize updating config

Impacts:
- fuse.py
- merge.py

* update formatting

* remove commented out code

* update func: update_config to save_config

- drop kwargs
- rename func as save_config
- incorporate review suggestions

* update func: save_config

- ensure only config-saving functionality
- function does not return config as a dict anymore
- added review suggestions

* fixed formatting

* update formatting instruction in contribution guide

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Sugato Ray 2024-03-14 09:36:05 -04:00 committed by GitHub
parent 485180ae91
commit 2cd793dd69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 9 deletions

View File

@ -12,7 +12,7 @@ possible.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
You can also run the formatters manually as follows on individual files:
```bash
clang-format -i file.cpp
@ -22,6 +22,16 @@ possible.
black file.py
```
or,
```bash
# single file
pre-commit run --files file1.py
# specific files
pre-commit run --files file1.py file2.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@ -9,7 +9,13 @@ from mlx.utils import tree_flatten, tree_unflatten
from .tuner.lora import LoRALinear
from .tuner.utils import apply_lora_layers, dequantize
from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
from .utils import (
fetch_from_hub,
get_model_path,
save_config,
save_weights,
upload_to_hub,
)
def parse_arguments() -> argparse.Namespace:
@ -87,8 +93,7 @@ def main() -> None:
if args.de_quantize:
config.pop("quantization", None)
with open(save_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
save_config(config, config_path=save_path / "config.json")
if args.upload_repo is not None:
hf_path = args.hf_path or (

View File

@ -13,7 +13,13 @@ import numpy as np
import yaml
from mlx.utils import tree_flatten, tree_map
from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
from .utils import (
fetch_from_hub,
get_model_path,
save_config,
save_weights,
upload_to_hub,
)
def configure_parser() -> argparse.ArgumentParser:
@ -151,8 +157,7 @@ def merge(
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(base_config, fid, indent=4)
save_config(config, config_path=mlx_path / "config.json")
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, base_hf_path)

View File

@ -553,6 +553,29 @@ def quantize_model(
return quantized_weights, quantized_config
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save a model configuration as JSON to ``config_path``.

    The configuration is sorted by key before saving for better
    readability, and the transient ``"_name_or_path"`` key is dropped.
    The caller's ``config`` dict is left unmodified.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Destination file path for the
            JSON configuration.
    """
    # Work on a copy so the caller's dict is not mutated, dropping the
    # "_name_or_path" key, which should not be persisted.
    cleaned = {k: v for k, v in config.items() if k != "_name_or_path"}

    # Sort keys for stable, human-readable output.
    cleaned = dict(sorted(cleaned.items()))

    with open(config_path, "w") as fid:
        json.dump(cleaned, fid, indent=4)
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
@ -588,8 +611,7 @@ def convert(
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
save_config(config, config_path=mlx_path / "config.json")
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)