Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Quantize example (#162)
* testing quantization
* conversion + quantization working
* one config processor
* quantization in mistral / nits in llama
* args for quantization
* llama / mistral conversion in good shape
* phi2 quantized
* mixtral
* qwen conversion
@@ -178,6 +178,12 @@ class Llama(nn.Module):
         return self.output(x)
 
     def generate(self, x, temp=1.0):
+        def sample(logits):
+            if temp == 0:
+                return mx.argmax(logits, axis=-1)
+            else:
+                return mx.random.categorical(logits * (1 / temp))
+
         cache = []
 
         # Make an additive causal mask. We will need that to process the prompt.
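The new `sample` helper is standard temperature sampling: scaling the logits by `1 / temp` before `mx.random.categorical` sharpens (temp < 1) or flattens (temp > 1) the distribution, while `temp == 0` short-circuits to greedy `mx.argmax`. A standalone sketch of the same idea (the toy logits are illustrative, not from the repo):

```python
import mlx.core as mx

def sample(logits, temp):
    # temp == 0 degenerates to greedy decoding.
    if temp == 0:
        return mx.argmax(logits, axis=-1)
    # Dividing logits by temp rescales the softmax before sampling.
    return mx.random.categorical(logits * (1 / temp))

logits = mx.array([[2.0, 1.0, 0.1]])  # illustrative values
print(sample(logits, 0.0))  # deterministic: index of the max logit
print(sample(logits, 0.8))  # stochastic draw from the tempered softmax
```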
@@ -194,7 +200,7 @@ class Llama(nn.Module):
         x = self.norm(x)
         # We only care about the last logits that generate the next token
         y = self.output(x[:, -1])
-        y = mx.random.categorical(y * (1 / temp))
+        y = sample(y)
 
         # y now has size [1]
         # Since MLX is lazily evaluated nothing is computed yet.
@@ -218,8 +224,7 @@ class Llama(nn.Module):
             # old cache the moment it is not needed anymore.
             x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
         x = self.norm(x)
-        y = self.output(x[:, -1])
-        y = mx.random.categorical(y * (1 / temp))
+        y = sample(self.output(x[:, -1]))
 
         yield y
 
@@ -326,38 +331,46 @@ def few_shot_generate(args):
     print()
 
 
+def sanitize_config(config, weights):
+    config.pop("model_type", None)
+    n_heads = config["n_heads"]
+    if "n_kv_heads" not in config:
+        config["n_kv_heads"] = n_heads
+    if "head_dim" not in config:
+        config["head_dim"] = config["dim"] // n_heads
+    if "hidden_dim" not in config:
+        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
+    if config.get("vocab_size", -1) < 0:
+        config["vocab_size"] = weights["output.weight"].shape[-1]
+    if "rope_theta" not in config:
+        config["rope_theta"] = 10000
+    unused = ["multiple_of", "ffn_dim_multiplier"]
+    for k in unused:
+        config.pop(k, None)
+    return config
+
+
 def load_model(model_path):
     model_path = Path(model_path)
     weights = mx.load(str(model_path / "weights.npz"))
     with open(model_path / "config.json", "r") as f:
-        config = json.loads(f.read())
-        config.pop("model_type", None)
-        n_heads = config["n_heads"]
-        if "n_kv_heads" not in config:
-            config["n_kv_heads"] = n_heads
-        if "head_dim" not in config:
-            config["head_dim"] = config["dim"] // n_heads
-        if "hidden_dim" not in config:
-            config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
-        if config.get("vocab_size", -1) < 0:
-            config["vocab_size"] = weights["output.weight"].shape[-1]
-        if "rope_theta" not in config:
-            config["rope_theta"] = 10000
-        unused = ["multiple_of", "ffn_dim_multiplier"]
-        for k in unused:
-            if k in config:
-                config.pop(k)
+        config = sanitize_config(json.loads(f.read()), weights)
+        quantization = config.pop("quantization", None)
     model = Llama(ModelArgs(**config))
+    if quantization is not None:
+        nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
-    return model
+    tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
+    return model, tokenizer
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Llama inference script")
     parser.add_argument(
-        "model", help="Path to the model directory containing the MLX weights"
+        "--model-path",
+        help="Path to the model directory containing the MLX weights",
+        default="mlx_model",
     )
-    parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
     parser.add_argument(
         "--prompt",
         help="The message to be processed by the model. Ignored when --few-shot is provided.",
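Taken together: `sanitize_config` backfills defaults that older config files omit, and a `"quantization"` entry in config.json (written at conversion time) triggers swapping `nn.Linear` layers for `nn.QuantizedLinear` before the already-quantized weights are loaded. A runnable sketch of both code paths, with `sanitize_config` copied from the hunk above; the toy shapes, the `Tiny` module, and the `group_size=64, bits=4` values are illustrative assumptions, not taken from a real checkpoint:

```python
import mlx.core as mx
import mlx.nn as nn

def sanitize_config(config, weights):  # copied from the hunk above
    config.pop("model_type", None)
    n_heads = config["n_heads"]
    if "n_kv_heads" not in config:
        config["n_kv_heads"] = n_heads
    if "head_dim" not in config:
        config["head_dim"] = config["dim"] // n_heads
    if "hidden_dim" not in config:
        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
    if config.get("vocab_size", -1) < 0:
        config["vocab_size"] = weights["output.weight"].shape[-1]
    if "rope_theta" not in config:
        config["rope_theta"] = 10000
    unused = ["multiple_of", "ffn_dim_multiplier"]
    for k in unused:
        config.pop(k, None)
    return config

# Toy weights, shaped only so the lookups above succeed.
weights = {
    "layers.0.feed_forward.w1.weight": mx.zeros((256, 64)),
    "output.weight": mx.zeros((64, 1000)),
}
config = sanitize_config({"dim": 64, "n_heads": 4, "vocab_size": -1}, weights)
# config now holds n_kv_heads=4, head_dim=16, hidden_dim=256,
# vocab_size=1000, rope_theta=10000

# The quantization hook: swap every nn.Linear for an nn.QuantizedLinear.
# Tiny is a hypothetical stand-in for the repo's Llama model; group_size
# and bits would come from config.json's "quantization" entry. (Later MLX
# releases replace this classmethod with nn.quantize.)
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

model = Tiny()
nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)
print(type(model.proj))  # -> nn.QuantizedLinear
```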
@@ -374,7 +387,7 @@ if __name__ == "__main__":
         "--write-every", type=int, default=1, help="After how many tokens to detokenize"
     )
     parser.add_argument(
-        "--temp", type=float, default=0.8, help="The sampling temperature"
+        "--temp", type=float, default=0.0, help="The sampling temperature"
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
 
@@ -382,9 +395,8 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
     print("[INFO] Loading model from disk.")
-    model = load_model(args.model)
+    model, tokenizer = load_model(args.model_path)
    if args.few_shot:
        few_shot_generate(args)
    else: