# mlx-examples/lora/convert.py
# Copyright © 2023-2024 Apple Inc.

import argparse
import copy

import mlx.core as mx
import mlx.nn as nn
import models
import utils
from mlx.utils import tree_flatten

def quantize(weights, config, args):
    """Quantize model weights and record the quantization settings.

    Args:
        weights: Mapping of parameter name -> array, as fetched from the hub.
        config: Model configuration dict; left untouched — a deep copy with a
            "quantization" section added is returned instead.
        args: Parsed CLI namespace providing ``q_group_size`` and ``q_bits``.

    Returns:
        Tuple of (quantized weights dict, updated config dict).
    """
    updated_config = copy.deepcopy(config)

    # Materialize the model so nn.quantize can swap in quantized layers.
    model = models.Model(models.ModelArgs.from_dict(config))
    model.load_weights(list(weights.items()))

    # Replace supported layers with quantized counterparts in place.
    nn.quantize(model, args.q_group_size, args.q_bits)

    # Record the settings so the model can be re-loaded correctly later.
    updated_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }
    return dict(tree_flatten(model.parameters())), updated_config
if __name__ == "__main__":
    # Command-line entry point: fetch a Hugging Face model, optionally
    # quantize it, save it in MLX format, and optionally upload the result.
    cli = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )
    cli.add_argument(
        "--hf-path",
        type=str,
        help="Path to the Hugging Face model.",
    )
    cli.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )
    cli.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        help="Generate a quantized model.",
    )
    cli.add_argument(
        "--q-group-size",
        type=int,
        default=64,
        help="Group size for quantization.",
    )
    cli.add_argument(
        "--q-bits",
        type=int,
        default=4,
        help="Bits per weight for quantization.",
    )
    cli.add_argument(
        "--dtype",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
        help="Type to save the parameters, ignored if -q is given.",
    )
    cli.add_argument(
        "--upload-name",
        type=str,
        default=None,
        help="The name of model to upload to Hugging Face MLX Community",
    )
    args = cli.parse_args()

    print("[INFO] Loading")
    weights, config, tokenizer = utils.fetch_from_hub(args.hf_path)

    # Quantization runs on float16 inputs; otherwise honor the --dtype flag.
    dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if args.quantize:
        print("[INFO] Quantizing")
        weights, config = quantize(weights, config, args)

    utils.save_model(args.mlx_path, weights, tokenizer, config)
    if args.upload_name is not None:
        utils.upload_to_hub(args.mlx_path, args.upload_name, args.hf_path)