testing quantization

This commit is contained in:
Awni Hannun
2023-12-20 14:21:40 -08:00
parent 6c574dbecf
commit 8b824fb768
2 changed files with 50 additions and 2 deletions

View File

@@ -114,6 +114,40 @@ def tiny_llama(model_path):
return weights, params
def quantize(weights, config):
import mlx.nn as nn
from llama import LLama, ModelArgs
quantized_config = copy.deepcopy(config)
# Load the model
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
config.pop(k, None)
model = Llama(ModelArgs(**config))
model.update(tree_unflatten(list(weights.items())))
nn.QuantizedLinear.quantize_module(model)
# Update the config
quantized_config["quantization"] = {"groups": 128, "width": 4}
quantized_weights = tree_flatten(model.parameters())
mx.eval(quantized_weights)
return quantized_weights, quantized_config
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument(
@@ -130,12 +164,22 @@ if __name__ == "__main__":
choices=["tiny_llama", "llama"],
default="llama",
)
parser.add_argument(
"-q",
"--quantize",
help="Quantize the model before saving",
action="store_true",
)
args = parser.parse_args()
model_path = Path(args.model_path)
weights, params = globals()[args.model_name](model_path)
params["model_type"] = "llama"
if args.quantize:
print("[INFO] Quantizing model...")
weights, params = quantize(weights, params)
np.savez(str(model_path / "weights.npz"), **weights)
with open(model_path / "config.json", "w") as fid:
json.dump(params, fid, indent=4)

View File

@@ -345,9 +345,13 @@ def load_model(model_path):
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
if k in config:
config.pop(k)
config.pop(k, None)
quantization = config.pop("quantization", None)
model = Llama(ModelArgs(**config))
if quantization is not None:
nn.QuantizedLinear.quantize_module(
model, groups=quantization["groups"], width=quantization["width"]
)
model.update(tree_unflatten(list(weights.items())))
return model