Add Supported Quantized Phi-3-mini-4k-instruct gguf Weight (#717)

* support for phi-3 4-bit quantized gguf weights

* Added link to 4-bit quantized model

* removed some prints

* Added correct comment

* Added correct comment

* removed print

Since the last condition already prints a warning when quantization is None
Jaward Sesay, 2024-04-30 11:11:32 +08:00 (committed by GitHub)
parent 5513c4e57d
commit 7c0962f4e2
2 changed files with 20 additions and 1 deletion


@@ -47,6 +47,10 @@ Models that have been tested and work include:
- [TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF](https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF),
for quantized models use:
- `tinyllama-1.1b-chat-v1.0.Q8_0.gguf`
- `tinyllama-1.1b-chat-v1.0.Q4_0.gguf`
- [Jaward/phi-3-mini-4k-instruct.Q4_0.gguf](https://huggingface.co/Jaward/phi-3-mini-4k-instruct.Q4_0.gguf),
for the 4-bit quantized phi-3-mini-4k-instruct use:
- `phi-3-mini-4k-instruct.Q4_0.gguf`
[^1]: For more information on GGUF see [the documentation](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md).
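
A minimal usage sketch for the newly listed 4-bit phi-3 weights, based on the `load(gguf_file, repo=None)` helper that appears in the last hunk of this diff; the `models` module name and the `(model, tokenizer)` return value are assumptions for illustration, not something this commit specifies.

```python
# Sketch only: load the 4-bit quantized phi-3 GGUF via the load() helper
# whose signature is shown later in this diff. The "models" module name and
# the (model, tokenizer) return value are assumptions.
import models

model, tokenizer = models.load(
    "phi-3-mini-4k-instruct.Q4_0.gguf",
    repo="Jaward/phi-3-mini-4k-instruct.Q4_0.gguf",
)
```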


@@ -18,6 +18,7 @@ class ModelArgs:
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
context_length: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
@@ -157,6 +158,16 @@ class LlamaModel(nn.Module):
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
# model info
print(
f"Model info\n"
f"==========\n"
f"Context length: {args.context_length}\n"
f"Vocab size: {args.vocab_size}\n"
f"Hidden size: {args.hidden_size}\n"
f"Num layers: {args.num_hidden_layers}\n"
f"Num attention heads: {args.num_attention_heads}\n"
)
def __call__(
self,
@@ -196,6 +207,7 @@ class Model(nn.Module):
def get_config(metadata: dict):
output = {
"context_length": metadata["llama.context_length"],
"hidden_size": metadata["llama.embedding_length"],
"num_hidden_layers": metadata["llama.block_count"],
"num_attention_heads": metadata["llama.attention.head_count"],
@@ -269,9 +281,12 @@ def load(gguf_file: str, repo: str = None):
elif gguf_ft == 2 or gguf_ft == 3:
# MOSTLY_Q4_0 or MOSTLY_Q4_1
quantization = {"group_size": 32, "bits": 4}
# report the selected quantization bit width
print(f"{quantization['bits']} bits quantized model")
elif gguf_ft == 7:
# MOSTLY_Q8_0 = 7
quantization = {"group_size": 32, "bits": 8}
print(f"{quantization['bits']} bits quantized model")
else:
quantization = None
print("[WARNING] Using unsupported GGUF quantization. Casting to float16.")