diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 71476df3..eee28c9c 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -560,6 +560,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): card = ModelCard.load(hf_path) card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] + card.data.base_model = hf_path card.text = dedent( f""" # {upload_repo} @@ -666,6 +667,8 @@ def quantize_model( quantized_config = copy.deepcopy(config) nn.quantize(model, q_group_size, q_bits) quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} + # support hf model tree #957 + quantized_config["quantization_config"] = quantized_config["quantization"] quantized_weights = dict(tree_flatten(model.parameters())) return quantized_weights, quantized_config