mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
feat: add update_config functionality (#531)
* feat: add `update_config` finctionality - sorts the config for better readability - updates "_name_or_path" key in config with upload_repo - sets indentation of 4 spaces - allows adding other key-value pairs via kwargs - reduces code duplication - standardizes config-update across mlx-lm * feat: standardize updating config Impactes: - fuse.py - merge.py * update formatting * remove commented out code * update func: update_config to save_config - drop kwards - rename func as save_config - incorporate review suggestions * update func: save_config - ensure only config-saving functionality - function oes not return config as a dict anymore - added review suggestions * fixed formatting * update formatting instruction in contribution guide * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
485180ae91
commit
2cd793dd69
@ -12,7 +12,7 @@ possible.
|
||||
This should install hooks for running `black` and `clang-format` to ensure
|
||||
consistent style for C++ and python code.
|
||||
|
||||
You can also run the formatters manually as follows:
|
||||
You can also run the formatters manually as follows on individual files:
|
||||
|
||||
```bash
|
||||
clang-format -i file.cpp
|
||||
@ -21,6 +21,16 @@ possible.
|
||||
```bash
|
||||
black file.py
|
||||
```
|
||||
|
||||
or,
|
||||
|
||||
```bash
|
||||
# single file
|
||||
pre-commit run --files file1.py
|
||||
|
||||
# specific files
|
||||
pre-commit run --files file1.py file2.py
|
||||
```
|
||||
|
||||
or run `pre-commit run --all-files` to check all files in the repo.
|
||||
|
||||
|
@ -9,7 +9,13 @@ from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
from .tuner.lora import LoRALinear
|
||||
from .tuner.utils import apply_lora_layers, dequantize
|
||||
from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
save_config,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
@ -87,8 +93,7 @@ def main() -> None:
|
||||
if args.de_quantize:
|
||||
config.pop("quantization", None)
|
||||
|
||||
with open(save_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
save_config(config, config_path=save_path / "config.json")
|
||||
|
||||
if args.upload_repo is not None:
|
||||
hf_path = args.hf_path or (
|
||||
|
@ -13,7 +13,13 @@ import numpy as np
|
||||
import yaml
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
save_config,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
@ -151,8 +157,7 @@ def merge(
|
||||
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(base_config, fid, indent=4)
|
||||
save_config(config, config_path=mlx_path / "config.json")
|
||||
|
||||
if upload_repo is not None:
|
||||
upload_to_hub(mlx_path, upload_repo, base_hf_path)
|
||||
|
@ -553,6 +553,29 @@ def quantize_model(
|
||||
return quantized_weights, quantized_config
|
||||
|
||||
|
||||
def save_config(
|
||||
config: dict,
|
||||
config_path: Union[str, Path],
|
||||
) -> None:
|
||||
"""Save the model configuration to the ``config_path``.
|
||||
|
||||
The final configuration will be sorted before saving for better readability.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
config_path (Union[str, Path]): Model configuration file path.
|
||||
"""
|
||||
# Clean unused keys
|
||||
config.pop("_name_or_path", None)
|
||||
|
||||
# sort the config for better readability
|
||||
config = dict(sorted(config.items()))
|
||||
|
||||
# write the updated config to the config_path (if provided)
|
||||
with open(config_path, "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
|
||||
def convert(
|
||||
hf_path: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
@ -588,8 +611,7 @@ def convert(
|
||||
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
|
||||
with open(mlx_path / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
save_config(config, config_path=mlx_path / "config.json")
|
||||
|
||||
if upload_repo is not None:
|
||||
upload_to_hub(mlx_path, upload_repo, hf_path)
|
||||
|
Loading…
Reference in New Issue
Block a user