mlx-examples/lora/convert.py
Awni Hannun b8a348c1b8
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
2024-03-23 07:13:51 -07:00

100 lines
2.6 KiB
Python

# Copyright © 2023-2024 Apple Inc.
import argparse
import copy
import mlx.core as mx
import mlx.nn as nn
import models
import utils
from mlx.utils import tree_flatten
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = models.Model(models.ModelArgs.from_dict(config))
model.load_weights(list(weights.items()))
# Quantize the model:
nn.QuantizedLinear.quantize_module(
model,
args.q_group_size,
args.q_bits,
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != 8,
)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert Hugging Face model to MLX format"
)
parser.add_argument(
"--hf-path",
type=str,
help="Path to the Hugging Face model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-name",
help="The name of model to upload to Hugging Face MLX Community",
type=str,
default=None,
)
args = parser.parse_args()
print("[INFO] Loading")
weights, config, tokenizer = utils.fetch_from_hub(args.hf_path)
dtype = mx.float16 if args.quantize else getattr(mx, args.dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
utils.save_model(args.mlx_path, weights, tokenizer, config)
if args.upload_name is not None:
utils.upload_to_hub(args.mlx_path, args.upload_name, args.hf_path)