Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00.
feat: add update_config functionality (#531)
* feat: add `update_config` functionality
  - sorts the config for better readability
  - updates the "_name_or_path" key in the config with `upload_repo`
  - sets an indentation of 4 spaces
  - allows adding other key-value pairs via kwargs
  - reduces code duplication
  - standardizes config updates across mlx-lm

* feat: standardize updating the config
  Impacts:
  - fuse.py
  - merge.py

* update formatting

* remove commented-out code

* rename `update_config` to `save_config`
  - drop kwargs
  - rename the function to `save_config`
  - incorporate review suggestions

* update `save_config`
  - keep only the config-saving functionality
  - the function no longer returns the config as a dict
  - incorporate review suggestions

* fix formatting

* update the formatting instructions in the contribution guide

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
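For orientation, here is a minimal sketch of the call pattern this commit standardizes (a hedged example, assuming an `mlx-lm` install that already contains this change; `save_config` itself is defined in the `utils.py` hunk below):

```python
from pathlib import Path

# Assumes mlx-lm at a version that includes this commit.
from mlx_lm.utils import save_config

config = {
    "vocab_size": 32000,
    "_name_or_path": "some-org/some-model",  # hypothetical value; dropped by save_config
    "model_type": "llama",
}

# Replaces the hand-rolled `json.dump(config, fid, indent=4)` in each script:
# the helper sorts the keys and writes with 4-space indentation.
save_config(config, config_path=Path("config.json"))
```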
This commit is contained in:
parent 485180ae91
commit 2cd793dd69
CONTRIBUTING.md
@@ -12,7 +12,7 @@ possible.
 This should install hooks for running `black` and `clang-format` to ensure
 consistent style for C++ and python code.
 
-You can also run the formatters manually as follows:
+You can also run the formatters manually as follows on individual files:
 
 ```bash
 clang-format -i file.cpp
@@ -21,6 +21,16 @@ possible.
 ```bash
 black file.py
 ```
 
+or,
+
+```bash
+# single file
+pre-commit run --files file1.py
+
+# specific files
+pre-commit run --files file1.py file2.py
+```
+
 or run `pre-commit run --all-files` to check all files in the repo.
 
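The formatters above can also be driven from Python, which is occasionally handy in CI scripts. A minimal sketch using only the standard library (it assumes `black` and `pre-commit` are installed and on `PATH`, as the guide's setup step provides):

```python
import subprocess

# Format a single Python file with black, mirroring `black file.py` above.
subprocess.run(["black", "file.py"], check=True)

# Run the configured pre-commit hooks on specific files,
# mirroring `pre-commit run --files file1.py file2.py` above.
subprocess.run(["pre-commit", "run", "--files", "file1.py", "file2.py"], check=True)
```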
llms/mlx_lm/fuse.py
@@ -9,7 +9,13 @@ from mlx.utils import tree_flatten, tree_unflatten
 
 from .tuner.lora import LoRALinear
 from .tuner.utils import apply_lora_layers, dequantize
-from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
+from .utils import (
+    fetch_from_hub,
+    get_model_path,
+    save_config,
+    save_weights,
+    upload_to_hub,
+)
 
 
 def parse_arguments() -> argparse.Namespace:
@@ -87,8 +93,7 @@ def main() -> None:
     if args.de_quantize:
         config.pop("quantization", None)
 
-    with open(save_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
+    save_config(config, config_path=save_path / "config.json")
 
     if args.upload_repo is not None:
         hf_path = args.hf_path or (
llms/mlx_lm/merge.py
@@ -13,7 +13,13 @@ import numpy as np
 import yaml
 from mlx.utils import tree_flatten, tree_map
 
-from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
+from .utils import (
+    fetch_from_hub,
+    get_model_path,
+    save_config,
+    save_weights,
+    upload_to_hub,
+)
 
 
 def configure_parser() -> argparse.ArgumentParser:
@@ -151,8 +157,7 @@ def merge(
 
     tokenizer.save_pretrained(mlx_path)
 
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(base_config, fid, indent=4)
+    save_config(config, config_path=mlx_path / "config.json")
 
     if upload_repo is not None:
         upload_to_hub(mlx_path, upload_repo, base_hf_path)
llms/mlx_lm/utils.py
@@ -553,6 +553,29 @@ def quantize_model(
     return quantized_weights, quantized_config
 
 
+def save_config(
+    config: dict,
+    config_path: Union[str, Path],
+) -> None:
+    """Save the model configuration to the ``config_path``.
+
+    The final configuration will be sorted before saving for better readability.
+
+    Args:
+        config (dict): The model configuration.
+        config_path (Union[str, Path]): Model configuration file path.
+    """
+    # Clean unused keys
+    config.pop("_name_or_path", None)
+
+    # Sort the config for better readability
+    config = dict(sorted(config.items()))
+
+    # Write the updated config to the given path
+    with open(config_path, "w") as fid:
+        json.dump(config, fid, indent=4)
+
+
 def convert(
     hf_path: str,
     mlx_path: str = "mlx_model",
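As a quick sanity check on the semantics of `save_config` as defined above, the following hedged sketch round-trips a config through a temporary file and verifies the two documented behaviors: the `_name_or_path` key is removed and the remaining keys are written in sorted order (again assuming an `mlx-lm` install that includes this commit):

```python
import json
import tempfile
from pathlib import Path

from mlx_lm.utils import save_config  # the helper added above

config = {"vocab_size": 32000, "_name_or_path": "org/repo", "model_type": "llama"}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    save_config(config, config_path=path)
    saved = json.loads(path.read_text())

# "_name_or_path" is gone and the keys come back sorted.
assert "_name_or_path" not in saved
assert list(saved) == sorted(saved)
```

Note that `config.pop("_name_or_path", None)` mutates the dictionary passed in by the caller; the sorted copy is only a local rebinding inside the function.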
@@ -588,8 +611,7 @@ def convert(
 
     tokenizer.save_pretrained(mlx_path)
 
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
+    save_config(config, config_path=mlx_path / "config.json")
 
     if upload_repo is not None:
         upload_to_hub(mlx_path, upload_repo, hf_path)