one config processor

This commit is contained in:
Awni Hannun
2023-12-20 15:00:06 -08:00
parent f6684df32a
commit af4d82c93a
2 changed files with 26 additions and 34 deletions

View File

@@ -118,33 +118,21 @@ def tiny_llama(model_path):
def quantize(weights, config):
import mlx.core as mx
import mlx.nn as nn
from llama import Llama, ModelArgs
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten
quantized_config = copy.deepcopy(config)
# Load the model
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
config.pop(k, None)
# Load the model:
config = sanitize_config(config, weights)
model = Llama(ModelArgs(**config))
weights = tree_map(mx.array, weights)
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model)
# Update the config
# Update the config:
quantized_config["quantization"] = {"groups": 64, "width": 4}
quantized_weights = dict(tree_flatten(model.parameters()))

View File

@@ -331,26 +331,30 @@ def few_shot_generate(args):
print()
def sanitize_config(config, weights):
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
config.pop(k, None)
return config
def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
config.pop(k, None)
config = sanitize_config(json.loads(f.read()), weights)
quantization = config.pop("quantization", None)
model = Llama(ModelArgs(**config))
if quantization is not None: