Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
testing quantization
@@ -114,6 +114,40 @@ def tiny_llama(model_path):
     return weights, params
 
 
+def quantize(weights, config):
+    import mlx.nn as nn
+    from llama import Llama, ModelArgs
+
+    quantized_config = copy.deepcopy(config)
+
+    # Load the model
+    config.pop("model_type", None)
+    n_heads = config["n_heads"]
+    if "n_kv_heads" not in config:
+        config["n_kv_heads"] = n_heads
+    if "head_dim" not in config:
+        config["head_dim"] = config["dim"] // n_heads
+    if "hidden_dim" not in config:
+        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
+    if config.get("vocab_size", -1) < 0:
+        config["vocab_size"] = weights["output.weight"].shape[-1]
+    if "rope_theta" not in config:
+        config["rope_theta"] = 10000
+    unused = ["multiple_of", "ffn_dim_multiplier"]
+    for k in unused:
+        config.pop(k, None)
+    model = Llama(ModelArgs(**config))
+    model.update(tree_unflatten(list(weights.items())))
+    nn.QuantizedLinear.quantize_module(model)
+
+    # Update the config
+    quantized_config["quantization"] = {"groups": 128, "width": 4}
+    quantized_weights = tree_flatten(model.parameters())
+    mx.eval(quantized_weights)
+
+    return quantized_weights, quantized_config
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
     parser.add_argument(
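For context, nn.QuantizedLinear.quantize_module(model) swaps every nn.Linear in the model for an nn.QuantizedLinear; with the defaults recorded in the config above, weights are stored as 4-bit integers in groups of 128 elements, each group carrying its own affine scale and bias. A minimal standalone sketch of that group-wise scheme, in NumPy rather than the MLX internals (function names here are illustrative):

import numpy as np

def quantize_group(w, bits=4):
    # Affine quantization of one group: w ~= q * scale + bias,
    # with bias taken as the group minimum. Assumes w is not
    # constant (scale would be zero otherwise).
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / (2**bits - 1)
    q = np.round((w - w_min) / scale).astype(np.uint8)
    return q, scale, w_min

def dequantize_group(q, scale, bias):
    return q.astype(np.float32) * scale + bias

# One group of 128 weights, matching {"groups": 128, "width": 4}
rng = np.random.default_rng(0)
w = rng.standard_normal(128).astype(np.float32)
q, scale, bias = quantize_group(w)
w_hat = dequantize_group(q, scale, bias)
print(np.abs(w - w_hat).max() <= scale / 2 + 1e-6)  # True: error is at most half a step

Smaller groups track local weight statistics more closely at the cost of storing more scales and biases; 128 is the default trade-off used here.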
@@ -130,12 +164,22 @@ if __name__ == "__main__":
         choices=["tiny_llama", "llama"],
         default="llama",
     )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        help="Quantize the model before saving",
+        action="store_true",
+    )
 
     args = parser.parse_args()
 
     model_path = Path(args.model_path)
     weights, params = globals()[args.model_name](model_path)
     params["model_type"] = "llama"
+    if args.quantize:
+        print("[INFO] Quantizing model...")
+        weights, params = quantize(weights, params)
+
     np.savez(str(model_path / "weights.npz"), **weights)
     with open(model_path / "config.json", "w") as fid:
         json.dump(params, fid, indent=4)
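With width 4 this is roughly a 4x saving over fp16 on disk. Back-of-the-envelope, for {"groups": 128, "width": 4} and assuming one fp16 scale and one fp16 bias per group (the per-group storage layout is an assumption here, not something this diff shows):

groups, width = 128, 4
# Per group of 128 weights: 128 packed 4-bit values plus an assumed
# fp16 scale and fp16 bias
bits_per_weight = width + 2 * 16 / groups
print(bits_per_weight)       # 4.25 bits per weight
print(16 / bits_per_weight)  # ~3.76x smaller than fp16

The saved config.json then carries the "quantization" entry alongside the usual model arguments, which is exactly what the load path in the next hunk keys on.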
@@ -345,9 +345,13 @@ def load_model(model_path):
         config["rope_theta"] = 10000
     unused = ["multiple_of", "ffn_dim_multiplier"]
     for k in unused:
-        if k in config:
-            config.pop(k)
+        config.pop(k, None)
+    quantization = config.pop("quantization", None)
     model = Llama(ModelArgs(**config))
+    if quantization is not None:
+        nn.QuantizedLinear.quantize_module(
+            model, groups=quantization["groups"], width=quantization["width"]
+        )
     model.update(tree_unflatten(list(weights.items())))
     return model
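One detail worth noting in load_model: quantize_module runs before model.update. Quantizing replaces each nn.Linear with an nn.QuantizedLinear, which changes the parameter tree (a packed integer weight plus per-group scales and biases instead of a single float weight), so the modules must be swapped first for the saved quantized arrays to line up with the tree. A small sketch of that effect, using the groups/width keywords as they appear in this commit (the printed key names are indicative, not verified output):

import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(256, 256))
print(sorted(k for k, _ in tree_flatten(model.parameters())))
# e.g. ['layers.0.bias', 'layers.0.weight']

nn.QuantizedLinear.quantize_module(model, groups=128, width=4)
print(sorted(k for k, _ in tree_flatten(model.parameters())))
# the layer now also exposes per-group 'scales' and 'biases', and
# 'weight' holds the packed 4-bit values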