Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
This commit is contained in:
Awni Hannun
2024-04-18 18:16:10 -07:00
committed by GitHub
parent f5f189e48a
commit 2146bcd7ee
28 changed files with 108 additions and 190 deletions

View File

@@ -172,7 +172,8 @@ class Model(nn.Module):
self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@@ -180,11 +181,15 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache
def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k