Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
This commit is contained in:
Awni Hannun
2024-04-18 18:16:10 -07:00
committed by GitHub
parent f5f189e48a
commit 2146bcd7ee
28 changed files with 108 additions and 190 deletions

View File

@@ -185,7 +185,7 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
out = out @ self.model.embed_tokens.weight.T
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out, cache

View File

@@ -169,7 +169,7 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
out = out @ self.model.embed_tokens.weight.T
out = self.model.embed_tokens.as_linear(out)
return out, cache
@property

View File

@@ -142,7 +142,7 @@ class Transformer(nn.Module):
h = self.norm(h)
if self.weight_tying:
return h @ self.wte.weight.T, cache
return self.wte.as_linear(h), cache
return self.ff_out(h), cache

View File

@@ -172,7 +172,8 @@ class Model(nn.Module):
self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@@ -180,11 +181,15 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache
def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k

View File

@@ -149,7 +149,8 @@ class Model(nn.Module):
self.args = args
self.model_type = args.model_type
self.model = Starcoder2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
sself.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
@@ -157,12 +158,11 @@ class Model(nn.Module):
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
return weights
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache
@property
def layers(self):