Mirror of https://github.com/ml-explore/mlx-examples.git
Add Support for Quantized Phi-3-mini-4k-instruct GGUF Weights (#717)
* Support for Phi-3 4-bit quantized GGUF weights
* Added a link to the 4-bit quantized model
* Removed some prints
* Added the correct comment
* Removed a print, since the last condition already prints a warning when quantization is None
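For orientation, a minimal, hypothetical way to exercise the new path through models.load() (whose signature appears in the last hunk below) might look like the sketch here. The file and repo names are placeholders, and the (model, tokenizer) return value is an assumption, not something shown in this diff.

# Hypothetical usage sketch only: file/repo names are placeholders and the
# (model, tokenizer) return value is assumed, not shown in this diff.
import models  # llms/gguf_llm/models.py from mlx-examples

gguf_file = "Phi-3-mini-4k-instruct-q4.gguf"      # placeholder file name
repo = "<huggingface-repo-with-the-4-bit-gguf>"   # placeholder repo id

model, tokenizer = models.load(gguf_file, repo)   # assumed return values
# For file types Q4_0/Q4_1, load() selects {"group_size": 32, "bits": 4}
# (see the last hunk below) before building the model.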
@@ -18,6 +18,7 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
+    context_length: int
     num_key_value_heads: int = None
     rope_theta: float = 10000
     rope_traditional: bool = False
@@ -157,6 +158,16 @@ class LlamaModel(nn.Module):
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        # model info
+        print(
+            f"Model info\n"
+            f"==========\n"
+            f"Context length: {args.context_length}\n"
+            f"Vocab size: {args.vocab_size}\n"
+            f"Hidden size: {args.hidden_size}\n"
+            f"Num layers: {args.num_hidden_layers}\n"
+            f"Num attention heads: {args.num_attention_heads}\n"
+        )

     def __call__(
         self,
@@ -196,6 +207,7 @@ class Model(nn.Module):

 def get_config(metadata: dict):
     output = {
+        "context_length": metadata["llama.context_length"],
         "hidden_size": metadata["llama.embedding_length"],
         "num_hidden_layers": metadata["llama.block_count"],
         "num_attention_heads": metadata["llama.attention.head_count"],
@@ -269,9 +281,12 @@ def load(gguf_file: str, repo: str = None):
     elif gguf_ft == 2 or gguf_ft == 3:
         # MOSTLY_Q4_0 or MOSTLY_Q4_1
         quantization = {"group_size": 32, "bits": 4}
+        # print bits value
+        print(f"{quantization['bits']} bits quantized model")
     elif gguf_ft == 7:
         # MOSTLY_Q8_0 = 7
         quantization = {"group_size": 32, "bits": 8}
+        print(f"{quantization['bits']} bits quantized model")
     else:
         quantization = None
         print("[WARNING] Using unsupported GGUF quantization. Casting to float16.")
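The branches above cover GGUF file types 2/3 (Q4_0/Q4_1) and 7 (Q8_0); an earlier branch outside this hunk handles types 0/1 (F32/F16). A table-driven sketch of the same decision, written purely as an illustration rather than the code the repo uses:

# Illustrative restructuring of the elif chain above, not the repo's code.
# GGUF general.file_type -> MLX quantization settings; unknown types fall
# back to None, i.e. the weights get cast to float16.
GGUF_FT_TO_QUANT = {
    0: None,                           # ALL_F32
    1: None,                           # MOSTLY_F16
    2: {"group_size": 32, "bits": 4},  # MOSTLY_Q4_0
    3: {"group_size": 32, "bits": 4},  # MOSTLY_Q4_1
    7: {"group_size": 32, "bits": 8},  # MOSTLY_Q8_0
}

def quantization_for(gguf_ft: int):
    if gguf_ft not in GGUF_FT_TO_QUANT:
        print("[WARNING] Using unsupported GGUF quantization. Casting to float16.")
    return GGUF_FT_TO_QUANT.get(gguf_ft)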