mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
This commit is contained in:
@@ -152,47 +152,6 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
    """With tie_word_embeddings=True and no lm_head weight supplied,
    sanitize() should add an "lm_head.weight" entry that carries the
    embedding weight's value."""
    from mlx_lm.models import qwen2

    config = qwen2.ModelArgs(
        model_type="qwen2",
        vocab_size=10_000,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=4,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        tie_word_embeddings=True,
    )
    tied_model = qwen2.Model(config)

    # Only the embedding weight is provided; sanitize() must fill in the head.
    raw_weights = {"model.embed_tokens.weight": "some_value"}
    cleaned = tied_model.sanitize(raw_weights)

    self.assertIn("lm_head.weight", cleaned)
    self.assertEqual(cleaned["lm_head.weight"], "some_value")
|
||||
|
||||
def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
    """With tie_word_embeddings=True and an lm_head weight already present,
    sanitize() should keep the existing "lm_head.weight" value untouched."""
    from mlx_lm.models import qwen2

    # Both the embedding and an explicit head weight are supplied.
    raw_weights = {
        "model.embed_tokens.weight": "some_value",
        "lm_head.weight": "existing_value",
    }

    config = qwen2.ModelArgs(
        model_type="qwen2",
        vocab_size=10_000,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=4,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        tie_word_embeddings=True,
    )
    cleaned = qwen2.Model(config).sanitize(raw_weights)

    self.assertIn("lm_head.weight", cleaned)
    self.assertEqual(cleaned["lm_head.weight"], "existing_value")
|
||||
|
||||
def test_qwen(self):
|
||||
from mlx_lm.models import qwen
|
||||
|
||||
@@ -277,46 +236,6 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
    """With tie_word_embeddings=True and no lm_head weight supplied,
    sanitize() should add an "lm_head.weight" entry that carries the
    embedding weight's value."""
    from mlx_lm.models import starcoder2

    config = starcoder2.ModelArgs(
        model_type="starcoder2",
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        tie_word_embeddings=True,
    )
    tied_model = starcoder2.Model(config)

    # Only the embedding weight is provided; sanitize() must fill in the head.
    raw_weights = {"model.embed_tokens.weight": "some_value"}
    cleaned = tied_model.sanitize(raw_weights)

    self.assertIn("lm_head.weight", cleaned)
    self.assertEqual(cleaned["lm_head.weight"], "some_value")
|
||||
|
||||
def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
    """With tie_word_embeddings=True and an lm_head weight already present,
    sanitize() should keep the existing "lm_head.weight" value untouched."""
    from mlx_lm.models import starcoder2

    config = starcoder2.ModelArgs(
        model_type="starcoder2",
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        tie_word_embeddings=True,
    )
    tied_model = starcoder2.Model(config)

    # Both the embedding and an explicit head weight are supplied.
    raw_weights = {
        "model.embed_tokens.weight": "some_value",
        "lm_head.weight": "existing_value",
    }

    cleaned = tied_model.sanitize(raw_weights)

    self.assertIn("lm_head.weight", cleaned)
    self.assertEqual(cleaned["lm_head.weight"], "existing_value")
|
||||
|
||||
def test_cohere(self):
|
||||
from mlx_lm.models import cohere
|
||||
|
||||
|
Reference in New Issue
Block a user