Quantize example (#162)

* testing quantization

* conversion + quantization working

* one config processor

* quantization in mistral / nits in llama

* args for quantization

* llama / mistral conversion in good shape

* phi2 quantized

* mixtral

* qwen conversion
This commit is contained in:
Awni Hannun
2023-12-21 12:59:37 -08:00
committed by GitHub
parent 4c9db80ed2
commit 3cf436b529
17 changed files with 553 additions and 126 deletions

View File

@@ -178,6 +178,12 @@ class Llama(nn.Module):
return self.output(x)
def generate(self, x, temp=1.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
cache = []
# Make an additive causal mask. We will need that to process the prompt.
@@ -194,7 +200,7 @@ class Llama(nn.Module):
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
y = sample(y)
# y now has size [1]
# Since MLX is lazily evaluated nothing is computed yet.
@@ -218,8 +224,7 @@ class Llama(nn.Module):
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
y = sample(self.output(x[:, -1]))
yield y
@@ -326,38 +331,46 @@ def few_shot_generate(args):
print()
def sanitize_config(config, weights):
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
config.pop(k, None)
return config
def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads
if "head_dim" not in config:
config["head_dim"] = config["dim"] // n_heads
if "hidden_dim" not in config:
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
if config.get("vocab_size", -1) < 0:
config["vocab_size"] = weights["output.weight"].shape[-1]
if "rope_theta" not in config:
config["rope_theta"] = 10000
unused = ["multiple_of", "ffn_dim_multiplier"]
for k in unused:
if k in config:
config.pop(k)
config = sanitize_config(json.loads(f.read()), weights)
quantization = config.pop("quantization", None)
model = Llama(ModelArgs(**config))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
return model
tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument(
"model", help="Path to the model directory containing the MLX weights"
"--model-path",
help="Path to the model directory containing the MLX weights",
default="mlx_model",
)
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument(
"--prompt",
help="The message to be processed by the model. Ignored when --few-shot is provided.",
@@ -374,7 +387,7 @@ if __name__ == "__main__":
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
)
parser.add_argument(
"--temp", type=float, default=0.8, help="The sampling temperature"
"--temp", type=float, default=0.0, help="The sampling temperature"
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
@@ -382,9 +395,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
print("[INFO] Loading model from disk.")
model = load_model(args.model)
model, tokenizer = load_model(args.model_path)
if args.few_shot:
few_shot_generate(args)
else: