Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
This commit is contained in:
Awni Hannun
2024-04-18 18:16:10 -07:00
committed by GitHub
parent f5f189e48a
commit 2146bcd7ee
28 changed files with 108 additions and 190 deletions

View File

@@ -152,47 +152,6 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
from mlx_lm.models import qwen2
args = qwen2.ModelArgs(
model_type="qwen2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
tie_word_embeddings=True,
)
model = qwen2.Model(args)
weights = {"model.embed_tokens.weight": "some_value"}
sanitized_weights = model.sanitize(weights)
self.assertIn("lm_head.weight", sanitized_weights)
self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
from mlx_lm.models import qwen2
weights = {
"model.embed_tokens.weight": "some_value",
"lm_head.weight": "existing_value",
}
args = qwen2.ModelArgs(
model_type="qwen2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
tie_word_embeddings=True,
)
model = qwen2.Model(args)
sanitized_weights = model.sanitize(weights)
self.assertIn("lm_head.weight", sanitized_weights)
self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
def test_qwen(self):
from mlx_lm.models import qwen
@@ -277,46 +236,6 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
from mlx_lm.models import starcoder2
args = starcoder2.ModelArgs(
model_type="starcoder2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
num_key_value_heads=4,
tie_word_embeddings=True,
)
model = starcoder2.Model(args)
weights = {"model.embed_tokens.weight": "some_value"}
sanitized_weights = model.sanitize(weights)
self.assertIn("lm_head.weight", sanitized_weights)
self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
from mlx_lm.models import starcoder2
args = starcoder2.ModelArgs(
model_type="starcoder2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
num_key_value_heads=4,
tie_word_embeddings=True,
)
model = starcoder2.Model(args)
weights = {
"model.embed_tokens.weight": "some_value",
"lm_head.weight": "existing_value",
}
sanitized_weights = model.sanitize(weights)
self.assertIn("lm_head.weight", sanitized_weights)
self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
def test_cohere(self):
from mlx_lm.models import cohere