Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
This commit is contained in:
parent f5f189e48a
commit 2146bcd7ee
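The core of this change set is the switch from the old nn.QuantizedLinear.quantize_module entry point to nn.quantize (hence the mlx>=0.11 requirement bumps below). The new function also covers nn.Embedding layers and accepts a class_predicate callback that receives each sub-module's parameter path and the module itself. The hunks below apply this across the llms, stable_diffusion, and whisper examples. A minimal sketch of the new call follows; the toy model is made up for illustration and is not from the repo.

import mlx.nn as nn

# Toy model purely for illustration; not from the repository.
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64), nn.Linear(64, 1000))

# Old API, removed throughout this commit:
#     nn.QuantizedLinear.quantize_module(model, group_size, bits)
# New API (mlx >= 0.11): quantizes supported layers in place; the optional
# predicate decides per (path, module) which sub-modules are converted.
nn.quantize(
    model,
    group_size=64,
    bits=4,
    class_predicate=lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)),
)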
@@ -280,24 +280,16 @@ def load(gguf_file: str, repo: str = None):
     config = get_config(metadata)
     model = Model(ModelArgs(**config))
     if quantization is not None:
-        # quantized the LM head?
-        qm = model if "lm_head.scales" in weights else model.model
-        nn.QuantizedLinear.quantize_module(
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
             qm,
             **quantization,
+            class_predicate=class_predicate,
         )
-
-        def dequantize(k):
-            weight = weights.pop(f"{k}.weight")
-            scales = weights.pop(f"{k}.scales")
-            biases = weights.pop(f"{k}.biases")
-            weights[f"{k}.weight"] = mx.dequantize(
-                weight, scales=scales, biases=biases, **quantization
-            )
-
-        # Dequantize embeddings
-        dequantize("model.embed_tokens")
-
     tokenizer = GGUFTokenizer(metadata)
     model.load_weights(list(weights.items()))
     return model, tokenizer
@@ -134,7 +134,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)

     # Update the config:
     quantized_config["quantization"] = {
@@ -339,7 +339,7 @@ def load_model(model_path):
     quantization = config.pop("quantization", None)
     model = Llama(ModelArgs(**config))
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
     tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
     return model, tokenizer
@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy
@@ -24,7 +24,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)

     # Update the config:
     quantized_config["quantization"] = {
@@ -183,7 +183,7 @@ def load_model(folder: str):
     weights = tree_unflatten(list(weights.items()))
     model = Mistral(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(weights)
     mx.eval(model.parameters())
     return model, tokenizer
@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy
@@ -60,13 +60,10 @@ def quantize(weights, config, args):
     model.update(all_weights)

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(
+    nn.quantize(
         model,
         args.q_group_size,
         args.q_bits,
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != 8,
     )

     # Extract the subset of quantized weights:
@@ -217,11 +217,7 @@ def load_model(folder: str):
     weights = tree_unflatten(list(weights.items()))
     model = Mixtral(model_args)
     if quantization is not None:
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        quantization["linear_class_predicate"] = (
-            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
-        )
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)

     model.update(weights)
     return model, tokenizer
@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy
@@ -185,7 +185,7 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        out = out @ self.model.embed_tokens.weight.T
+        out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out, cache

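This hunk and the ones below replace the manual tied-output projection (out @ embed_tokens.weight.T) with Embedding.as_linear, which applies the embedding matrix as a linear output layer and keeps working once the embedding has been quantized (a quantized embedding no longer stores a plain float weight to matmul against). A small sketch of the equivalence for an unquantized embedding; the shapes are made up:

import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(1000, 64)                 # vocab 1000, hidden size 64 (made-up sizes)
h = mx.random.normal((2, 8, 64))             # a batch of hidden states

logits_old = h @ emb.weight.T                # previous pattern in these models
logits_new = emb.as_linear(h)                # new pattern from this commit

print(mx.allclose(logits_old, logits_new))   # True for an unquantized embedding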
@@ -169,7 +169,7 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        out = out @ self.model.embed_tokens.weight.T
+        out = self.model.embed_tokens.as_linear(out)
         return out, cache

     @property
@@ -142,7 +142,7 @@ class Transformer(nn.Module):
         h = self.norm(h)

         if self.weight_tying:
-            return h @ self.wte.weight.T, cache
+            return self.wte.as_linear(h), cache

         return self.ff_out(h), cache

@@ -172,6 +172,7 @@ class Model(nn.Module):
         self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
@@ -180,11 +181,15 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out, cache

     def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
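In other words, for a model with tie_word_embeddings=True the commit stops materializing a separate lm_head module: __call__ routes the output projection through model.embed_tokens.as_linear, and sanitize now drops any lm_head.weight found in a checkpoint (instead of copying the embedding weights into it) so that load_weights does not see a parameter the module no longer has. This is the "remove lm head if tie weights is true" item from the commit message; the starcoder2 change below follows the same pattern.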
@@ -149,7 +149,8 @@ class Model(nn.Module):
         self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
         self,
@@ -157,12 +158,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
-
-    def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        return weights
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out, cache

     @property
     def layers(self):
@@ -1,4 +1,4 @@
-mlx>=0.10
+mlx>=0.11
 numpy
 transformers>=4.39.3
 protobuf
@@ -74,6 +74,7 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):

     def __init__(self, tokenizer):
         self._tokenizer = tokenizer
+        self._tokenizer.decode([0])
         self.reset()

     def reset(self):
@@ -79,6 +79,11 @@ def default_loss(model, inputs, targets, lengths):
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+    if len(dataset) < batch_size:
+        raise ValueError(
+            f"Dataset must have at least batch_size={batch_size}"
+            f" examples but only has {len(dataset)}."
+        )

     # Make the batches:
     batch_idx = [
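This guard makes undersized datasets fail fast in the LoRA training loop: a train or validation split with fewer examples than the requested batch size now raises immediately with an explicit message instead of surfacing later as an ill-formed batch (the "error for datasets < batch_size" item in the commit message). The practical fix on the user side is to lower the batch size or add examples.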
@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer

 # Local imports
 from .sample_utils import top_p_sampling
@@ -31,12 +31,6 @@ MODEL_REMAPPING = {

 MAX_FILE_SIZE_GB = 5

-linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear)
-    and m.weight.shape[0]
-    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-)
-

 def _get_classes(config: dict):
     """
@@ -188,14 +182,14 @@ def generate_step(
             repetition_context = repetition_context[-repetition_context_size:]
         return y, prob

-    y, prob = _step(y)
+    y, p = _step(y)

+    mx.async_eval(y)
     while True:
-        sync = mx.async_eval(y)
-        next_out = _step(y)
-        sync.wait()
-        yield y.item(), prob
-        y, prob = next_out
+        next_y, next_p = _step(y)
+        mx.async_eval(next_y)
+        yield y.item(), p
+        y, p = next_y, next_p


 def generate(
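The rewritten loop pipelines decoding: the graph for the next token is built and handed to mx.async_eval before the current token is synchronized with .item(), so host-side work on token t overlaps with device computation of token t+1. A generic sketch of the same pattern, with a stand-in step function that is not from the repo:

import mlx.core as mx

def step(x):
    # Stand-in for one decoding step (the real code samples the next token).
    return (x * x + 1) % 97

y = step(mx.array(3))
mx.async_eval(y)                 # start evaluating the first result
for _ in range(5):
    next_y = step(y)             # build the next step's graph...
    mx.async_eval(next_y)        # ...and launch it in the background
    print(y.item())              # .item() only waits on the already-running y
    y = next_y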
@@ -283,6 +277,16 @@ def generate(
     return detokenizer.text


+def load_config(model_path: Path) -> dict:
+    try:
+        with open(model_path / "config.json", "r") as f:
+            config = json.load(f)
+    except FileNotFoundError:
+        logging.error(f"Config file not found in {model_path}")
+        raise
+    return config
+
+
 def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
     Load and initialize the model from a given path.
@@ -300,13 +304,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         FileNotFoundError: If the weight files (.safetensors) are not found.
         ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-    try:
-        with open(model_path / "config.json", "r") as f:
-            config = json.load(f)
-            quantization = config.get("quantization", None)
-    except FileNotFoundError:
-        logging.error(f"Config file not found in {model_path}")
-        raise
+    config = load_config(model_path)

     weight_files = glob.glob(str(model_path / "*.safetensors"))
     if not weight_files:
@@ -325,25 +324,16 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     if hasattr(model, "sanitize"):
         weights = model.sanitize(weights)

-    if quantization is not None:
-        # for legacy models that don't have lm_head quant due to non-32 dims
-        if "lm_head.scales" not in weights.keys():
-            vocab_size = config["vocab_size"]
-            extended_linear_class_predicate = (
-                lambda layer: linear_class_predicate(layer)
-                and layer.weight.shape[0] != vocab_size
-            )
-            nn.QuantizedLinear.quantize_module(
-                model,
-                **quantization,
-                linear_class_predicate=extended_linear_class_predicate,
-            )
-        # for models that have lm_head quant
-        else:
-            nn.QuantizedLinear.quantize_module(
-                model,
-                **quantization,
-                linear_class_predicate=linear_class_predicate,
-            )
+    if (quantization := config.get("quantization", None)) is not None:
+        # Handle legacy models which may not have everything quantized
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
+            model,
+            **quantization,
+            class_predicate=class_predicate,
+        )

     model.load_weights(list(weights.items()))
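The predicate here is what lets old and new checkpoints load through a single code path: instead of special-casing whether lm_head was quantized (the old vocab-size check), only modules whose quantization scales actually exist in the checkpoint are converted. A small sketch with a made-up weight dict:

import mlx.nn as nn

weights = {
    "layers.0.mlp.weight": ...,        # packed quantized weight in the checkpoint
    "layers.0.mlp.scales": ...,        # scales present -> this module was quantized
    "lm_head.weight": ...,             # exported in full precision, no scales
}
class_predicate = (
    lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) and f"{p}.scales" in weights
)
print(class_predicate("layers.0.mlp", nn.Linear(8, 8)))  # True  -> quantize it
print(class_predicate("lm_head", nn.Linear(8, 8)))       # False -> leave as float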
@@ -395,10 +385,9 @@ def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
-    config = AutoConfig.from_pretrained(model_path)
+    config = load_config(model_path)
     tokenizer = load_tokenizer(model_path)
-    return model, config.to_dict(), tokenizer
+    return model, config, tokenizer


 def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
@@ -543,10 +532,7 @@ def quantize_model(
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-    nn.QuantizedLinear.quantize_module(
-        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
-    )
+    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
     quantized_weights = dict(tree_flatten(model.parameters()))

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.9.0"
+__version__ = "0.10.0"
@@ -152,47 +152,6 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

-    def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
-        from mlx_lm.models import qwen2
-
-        args = qwen2.ModelArgs(
-            model_type="qwen2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            rms_norm_eps=1e-5,
-            vocab_size=10_000,
-            tie_word_embeddings=True,
-        )
-        model = qwen2.Model(args)
-        weights = {"model.embed_tokens.weight": "some_value"}
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
-
-    def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
-        from mlx_lm.models import qwen2
-
-        weights = {
-            "model.embed_tokens.weight": "some_value",
-            "lm_head.weight": "existing_value",
-        }
-        args = qwen2.ModelArgs(
-            model_type="qwen2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            rms_norm_eps=1e-5,
-            vocab_size=10_000,
-            tie_word_embeddings=True,
-        )
-        model = qwen2.Model(args)
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
-
     def test_qwen(self):
         from mlx_lm.models import qwen

@@ -277,46 +236,6 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

-    def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
-        from mlx_lm.models import starcoder2
-
-        args = starcoder2.ModelArgs(
-            model_type="starcoder2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            num_key_value_heads=4,
-            tie_word_embeddings=True,
-        )
-        model = starcoder2.Model(args)
-        weights = {"model.embed_tokens.weight": "some_value"}
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
-
-    def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
-        from mlx_lm.models import starcoder2
-
-        args = starcoder2.ModelArgs(
-            model_type="starcoder2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            num_key_value_heads=4,
-            tie_word_embeddings=True,
-        )
-        model = starcoder2.Model(args)
-        weights = {
-            "model.embed_tokens.weight": "some_value",
-            "lm_head.weight": "existing_value",
-        }
-
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
-
     def test_cohere(self):
         from mlx_lm.models import cohere

@@ -4,6 +4,7 @@ import argparse
 import math

 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
@@ -34,10 +35,18 @@ if __name__ == "__main__":
     # Load the models
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
+
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+
+            nn.quantize(sd.text_encoder_1)
+            nn.quantize(sd.text_encoder_2)
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -45,8 +54,10 @@ if __name__ == "__main__":
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50

@@ -1,4 +1,4 @@
-mlx>=0.6
+mlx>=0.11
 huggingface-hub
 regex
 numpy
@@ -3,8 +3,8 @@
 import argparse

 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
-from mlx.nn import QuantizedLinear
 from PIL import Image
 from tqdm import tqdm

@@ -34,9 +34,13 @@ if __name__ == "__main__":
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -44,8 +48,10 @@ if __name__ == "__main__":
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50

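In both diffusion scripts the text encoders are now quantized with the default settings but restricted to nn.Linear layers (the lambda ignores the path argument), which leaves the CLIP token and position embeddings and the normalization layers in full precision, while the UNet keeps the previous group_size=32, bits=8 configuration.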
@@ -254,7 +254,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)

     # Update the config:
     quantized_config["quantization"] = {
@@ -1,4 +1,4 @@
-mlx>=0.8
+mlx>=0.11
 numba
 numpy
 torch
@@ -32,7 +32,7 @@ def load_model(
     model = whisper.Whisper(model_args, dtype)

     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)

     model.update(weights)
     mx.eval(model.parameters())
@@ -196,7 +196,7 @@ class TextDecoder(nn.Module):
         )

         x = self.ln(x)
-        return x @ self.token_embedding.weight.T, kv_cache, cross_qk
+        return self.token_embedding.as_linear(x), kv_cache, cross_qk


class Whisper(nn.Module):