Support HuggingFace model tree (#957)

* Hub: Update quantization configuration fields

* Hub: add base_model metadata

* Hub: add quantization_config for model tree Quantized type

* Hub: update quantization_config value

* Hub: remove config print
This commit is contained in:
madroid 2024-09-04 21:19:32 +08:00 committed by GitHub
parent 83a209e200
commit bd29aec299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -560,6 +560,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.data.base_model = hf_path
card.text = dedent(
    f"""
    # {upload_repo}
@ -666,6 +667,8 @@ def quantize_model(
quantized_config = copy.deepcopy(config)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
# support hf model tree #957
quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config