Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
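The "update quantize api" items refer to moving model conversion onto MLX's module-level quantization helper, which also covers embedding layers. A minimal sketch, assuming a recent MLX release that provides nn.quantize and quantized layer swaps (the toy model below is illustrative, not part of this commit):

import mlx.nn as nn

# Toy model: an embedding followed by a projection (illustrative only).
model = nn.Sequential(
    nn.Embedding(1000, 64),
    nn.Linear(64, 64),
)

# nn.quantize walks the module tree and swaps quantizable layers
# (Linear -> QuantizedLinear, Embedding -> QuantizedEmbedding) in place.
nn.quantize(model, group_size=64, bits=4)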
@@ -172,7 +172,8 @@ class Model(nn.Module):
         self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
         self,
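The hunk above makes lm_head optional: with tied word embeddings the model reuses the embedding matrix as the output projection instead of allocating a second vocab_size x hidden_size weight. A minimal standalone sketch of that pattern (the Head class and its arguments are illustrative, not from the commit):

import mlx.nn as nn

class Head(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if not tie_word_embeddings:
            # Only allocate a separate output projection for untied weights.
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def __call__(self, hidden):
        if self.tie_word_embeddings:
            # Reuse the embedding table: logits = hidden @ weight.T
            return self.embed_tokens.as_linear(hidden)
        return self.lm_head(hidden)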
@@ -180,11 +181,15 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out, cache

     def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
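The sanitize change is the flip side: a tied checkpoint may still ship an lm_head.weight that duplicates model.embed_tokens.weight, so it is dropped rather than loaded into a parameter that no longer exists. As a quick check that as_linear matches an explicit matmul (shapes are arbitrary; assumes mlx is installed):

import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(100, 32)         # vocab_size=100, hidden_size=32
h = mx.random.normal((2, 4, 32))    # (batch, seq, hidden)

logits = emb.as_linear(h)           # computes h @ emb.weight.T
assert tuple(logits.shape) == (2, 4, 100)
assert mx.allclose(logits, h @ emb.weight.T)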