From 2146bcd7ee76ce7ac46f585801648087726ad904 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 18 Apr 2024 18:16:10 -0700
Subject: [PATCH] Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
---
 llms/gguf_llm/models.py           | 20 +++-----
 llms/llama/convert.py             |  2 +-
 llms/llama/llama.py               |  2 +-
 llms/llama/requirements.txt       |  2 +-
 llms/mistral/convert.py           |  2 +-
 llms/mistral/mistral.py           |  2 +-
 llms/mistral/requirements.txt     |  2 +-
 llms/mixtral/convert.py           |  5 +-
 llms/mixtral/mixtral.py           |  6 +--
 llms/mixtral/requirements.txt     |  2 +-
 llms/mlx_lm/models/cohere.py      |  2 +-
 llms/mlx_lm/models/gemma.py       |  2 +-
 llms/mlx_lm/models/olmo.py        |  2 +-
 llms/mlx_lm/models/qwen2.py       | 13 +++--
 llms/mlx_lm/models/starcoder2.py  | 14 +++---
 llms/mlx_lm/requirements.txt      |  2 +-
 llms/mlx_lm/tokenizer_utils.py    |  1 +
 llms/mlx_lm/tuner/trainer.py      |  5 ++
 llms/mlx_lm/utils.py              | 80 +++++++++++++-----------
 llms/mlx_lm/version.py            |  2 +-
 llms/tests/test_models.py         | 81 -------------------------------
 stable_diffusion/image2image.py   | 21 ++++++--
 stable_diffusion/requirements.txt |  2 +-
 stable_diffusion/txt2image.py     | 18 ++++---
 whisper/convert.py                |  2 +-
 whisper/requirements.txt          |  2 +-
 whisper/whisper/load_models.py    |  2 +-
 whisper/whisper/whisper.py        |  2 +-
 28 files changed, 108 insertions(+), 190 deletions(-)
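Editor's note: the change that runs through every file below is the switch from nn.QuantizedLinear.quantize_module to the newer nn.quantize API, which takes a class_predicate so that nn.Embedding layers can be quantized alongside nn.Linear, and so that only layers whose quantized weights are actually present in a checkpoint are converted. A minimal sketch of the pattern, assuming an MLX module named model, a weights dict loaded from a checkpoint, and a quantization dict holding group_size and bits (the helper name is illustrative, not part of the patch):

import mlx.nn as nn

def quantize_for_loading(model, weights, quantization):
    # Convert only Linear/Embedding modules whose quantized scales exist in
    # the checkpoint; everything else stays in full precision.
    class_predicate = (
        lambda path, module: isinstance(module, (nn.Linear, nn.Embedding))
        and f"{path}.scales" in weights
    )
    nn.quantize(model, **quantization, class_predicate=class_predicate)
    return model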
diff --git a/llms/gguf_llm/models.py b/llms/gguf_llm/models.py
index 4ffbb3fe..b0d07558 100644
--- a/llms/gguf_llm/models.py
+++ b/llms/gguf_llm/models.py
@@ -280,24 +280,16 @@ def load(gguf_file: str, repo: str = None):
     config = get_config(metadata)
     model = Model(ModelArgs(**config))
     if quantization is not None:
-        # quantized the LM head?
-        qm = model if "lm_head.scales" in weights else model.model
-        nn.QuantizedLinear.quantize_module(
-            qm,
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
+            model,
             **quantization,
+            class_predicate=class_predicate,
         )
 
-        def dequantize(k):
-            weight = weights.pop(f"{k}.weight")
-            scales = weights.pop(f"{k}.scales")
-            biases = weights.pop(f"{k}.biases")
-            weights[f"{k}.weight"] = mx.dequantize(
-                weight, scales=scales, biases=biases, **quantization
-            )
-
-        # Dequantize embeddings
-        dequantize("model.embed_tokens")
-
     tokenizer = GGUFTokenizer(metadata)
     model.load_weights(list(weights.items()))
     return model, tokenizer
diff --git a/llms/llama/convert.py b/llms/llama/convert.py
index 6c9dcea4..04c10a5f 100644
--- a/llms/llama/convert.py
+++ b/llms/llama/convert.py
@@ -134,7 +134,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)
 
     # Update the config:
     quantized_config["quantization"] = {
diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index 3e5f78a1..b791a5c2 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -339,7 +339,7 @@ def load_model(model_path):
     quantization = config.pop("quantization", None)
     model = Llama(ModelArgs(**config))
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
     tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
     return model, tokenizer
diff --git a/llms/llama/requirements.txt b/llms/llama/requirements.txt
index 6b458abc..e67f2167 100644
--- a/llms/llama/requirements.txt
+++ b/llms/llama/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy
diff --git a/llms/mistral/convert.py b/llms/mistral/convert.py
index 2aae4fc6..56096f76 100644
--- a/llms/mistral/convert.py
+++ b/llms/mistral/convert.py
@@ -24,7 +24,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)
 
     # Update the config:
     quantized_config["quantization"] = {
diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py
index 24ae730d..d68b8c8a 100644
--- a/llms/mistral/mistral.py
+++ b/llms/mistral/mistral.py
@@ -183,7 +183,7 @@ def load_model(folder: str):
     weights = tree_unflatten(list(weights.items()))
     model = Mistral(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(weights)
     mx.eval(model.parameters())
     return model, tokenizer
diff --git a/llms/mistral/requirements.txt b/llms/mistral/requirements.txt
index 6b458abc..e67f2167 100644
--- a/llms/mistral/requirements.txt
+++ b/llms/mistral/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy
diff --git a/llms/mixtral/convert.py b/llms/mixtral/convert.py
index 52ba1ea6..ac2aedbb 100644
--- a/llms/mixtral/convert.py
+++ b/llms/mixtral/convert.py
@@ -60,13 +60,10 @@ def quantize(weights, config, args):
     model.update(all_weights)
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(
+    nn.quantize(
         model,
         args.q_group_size,
         args.q_bits,
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != 8,
     )
 
     # Extract the subset of quantized weights:
diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index 67486e84..4b45d066 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -217,11 +217,7 @@ def load_model(folder: str):
     weights = tree_unflatten(list(weights.items()))
     model = Mixtral(model_args)
     if quantization is not None:
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        quantization["linear_class_predicate"] = (
-            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
-        )
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(weights)
 
     return model, tokenizer
diff --git a/llms/mixtral/requirements.txt b/llms/mixtral/requirements.txt
index 6b458abc..e67f2167 100644
--- a/llms/mixtral/requirements.txt
+++ b/llms/mixtral/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 9de0599d..dae61760 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -185,7 +185,7 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        out = out @ self.model.embed_tokens.weight.T
+        out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out, cache
 
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index fa6cab9e..ebd8f5e7 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -169,7 +169,7 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        out = out @ self.model.embed_tokens.weight.T
+        out = self.model.embed_tokens.as_linear(out)
         return out, cache
 
     @property
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 541735de..b84b2a38 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -142,7 +142,7 @@ class Transformer(nn.Module):
         h = self.norm(h)
 
         if self.weight_tying:
-            return h @ self.wte.weight.T, cache
+            return self.wte.as_linear(h), cache
 
         return self.ff_out(h), cache
 
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 1e694b20..d95893f9 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -172,7 +172,8 @@ class Model(nn.Module):
         self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -180,11 +181,15 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out, cache
 
     def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
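Editor's note: with tie_word_embeddings enabled, qwen2 above (and starcoder2 below) now drop the separate lm_head and compute logits through the embedding table via Embedding.as_linear, which also works once the embedding has been quantized. A toy sketch of that call, with made-up sizes:

import mlx.core as mx
import mlx.nn as nn

vocab_size, dims = 100, 32
embed = nn.Embedding(vocab_size, dims)

h = embed(mx.array([1, 2, 3]))   # (3, dims) hidden states
logits = embed.as_linear(h)      # (3, vocab_size), equivalent to h @ embed.weight.T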
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index f18160a5..2637a35a 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -149,7 +149,8 @@ class Model(nn.Module):
         self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(
         self,
@@ -157,12 +158,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
-
-    def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        return weights
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out, cache
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 85dfaa53..4e9ab42d 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.10
+mlx>=0.11
 numpy
 transformers>=4.39.3
 protobuf
diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 15b9963e..bfc7bde1 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -74,6 +74,7 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
 
     def __init__(self, tokenizer):
         self._tokenizer = tokenizer
+        self._tokenizer.decode([0])
         self.reset()
 
     def reset(self):
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 2ed7a646..1408cd8f 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -79,6 +79,11 @@ def default_loss(model, inputs, targets, lengths):
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+    if len(dataset) < batch_size:
+        raise ValueError(
+            f"Dataset must have at least batch_size={batch_size}"
+            f" examples but only has {len(dataset)}."
+        )
 
     # Make the batches:
     batch_idx = [
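Editor's note: the new guard in iterate_batches turns a quiet misconfiguration (a dataset smaller than a single batch, which previously could yield no usable batches) into an immediate, descriptive error. A stand-alone version of the check, for illustration only:

def check_batchable(num_examples: int, batch_size: int) -> None:
    # Mirrors the ValueError added to iterate_batches above.
    if num_examples < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size}"
            f" examples but only has {num_examples}."
        )

check_batchable(num_examples=3, batch_size=4)  # raises ValueError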
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index be273e67..e38a0277 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
 from .sample_utils import top_p_sampling
@@ -31,12 +31,6 @@ MODEL_REMAPPING = {
 
 MAX_FILE_SIZE_GB = 5
 
-linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear)
-    and m.weight.shape[0]
-    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-)
-
 
 def _get_classes(config: dict):
     """
@@ -188,14 +182,14 @@ def generate_step(
             repetition_context = repetition_context[-repetition_context_size:]
         return y, prob
 
-    y, prob = _step(y)
+    y, p = _step(y)
+    mx.async_eval(y)
 
     while True:
-        sync = mx.async_eval(y)
-        next_out = _step(y)
-        sync.wait()
-        yield y.item(), prob
-        y, prob = next_out
+        next_y, next_p = _step(y)
+        mx.async_eval(next_y)
+        yield y.item(), p
+        y, p = next_y, next_p
 
 
 def generate(
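Editor's note: the rewritten loop in generate_step keeps one token in flight at all times. The current token y is dispatched with mx.async_eval, the graph for the following token is built while y is still being evaluated, and y.item() then only waits on work that was already scheduled. A schematic of the pattern, not the repository code itself:

import mlx.core as mx

def pipelined_decode(step, y):
    # step(y) builds the lazy graph for the next token.
    mx.async_eval(y)
    while True:
        next_y = step(y)
        mx.async_eval(next_y)   # start evaluating the next token in the background
        yield y.item()          # blocks only on y, which is already being computed
        y = next_y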
""" - try: - with open(model_path / "config.json", "r") as f: - config = json.load(f) - quantization = config.get("quantization", None) - except FileNotFoundError: - logging.error(f"Config file not found in {model_path}") - raise + + config = load_config(model_path) weight_files = glob.glob(str(model_path / "*.safetensors")) if not weight_files: @@ -325,26 +324,17 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: if hasattr(model, "sanitize"): weights = model.sanitize(weights) - if quantization is not None: - # for legacy models that don't have lm_head quant due to non-32 dims - if "lm_head.scales" not in weights.keys(): - vocab_size = config["vocab_size"] - extended_linear_class_predicate = ( - lambda layer: linear_class_predicate(layer) - and layer.weight.shape[0] != vocab_size - ) - nn.QuantizedLinear.quantize_module( - model, - **quantization, - linear_class_predicate=extended_linear_class_predicate, - ) - # for models that have lm_head quant - else: - nn.QuantizedLinear.quantize_module( - model, - **quantization, - linear_class_predicate=linear_class_predicate, - ) + if (quantization := config.get("quantization", None)) is not None: + # Handle legacy models which may not have everything quantized + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize( + model, + **quantization, + class_predicate=class_predicate, + ) model.load_weights(list(weights.items())) @@ -395,10 +385,9 @@ def fetch_from_hub( model_path: Path, lazy: bool = False ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: model = load_model(model_path, lazy) - config = AutoConfig.from_pretrained(model_path) + config = load_config(model_path) tokenizer = load_tokenizer(model_path) - - return model, config.to_dict(), tokenizer + return model, config, tokenizer def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: @@ -543,10 +532,7 @@ def quantize_model( Tuple: Tuple containing quantized weights and config. """ quantized_config = copy.deepcopy(config) - - nn.QuantizedLinear.quantize_module( - model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate - ) + nn.quantize(model, q_group_size, q_bits) quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} quantized_weights = dict(tree_flatten(model.parameters())) diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 1e0aa30d..f907220b 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.9.0" +__version__ = "0.10.0" diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 57fab58d..41565934 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -152,47 +152,6 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) - def test_qwen2_tie_word_embeddings_without_lm_head_weight(self): - from mlx_lm.models import qwen2 - - args = qwen2.ModelArgs( - model_type="qwen2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - tie_word_embeddings=True, - ) - model = qwen2.Model(args) - weights = {"model.embed_tokens.weight": "some_value"} - sanitized_weights = model.sanitize(weights) - self.assertIn("lm_head.weight", sanitized_weights) - self.assertEqual(sanitized_weights["lm_head.weight"], "some_value") - - def test_qwen2_tie_word_embeddings_with_lm_head_weight(self): - from mlx_lm.models import qwen2 - - weights = { - "model.embed_tokens.weight": "some_value", - "lm_head.weight": "existing_value", - } - args = qwen2.ModelArgs( - model_type="qwen2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - tie_word_embeddings=True, - ) - model = qwen2.Model(args) - sanitized_weights = model.sanitize(weights) - self.assertIn("lm_head.weight", sanitized_weights) - self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value") - def test_qwen(self): from mlx_lm.models import qwen @@ -277,46 +236,6 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) - def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self): - from mlx_lm.models import starcoder2 - - args = starcoder2.ModelArgs( - model_type="starcoder2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - num_key_value_heads=4, - tie_word_embeddings=True, - ) - model = starcoder2.Model(args) - weights = {"model.embed_tokens.weight": "some_value"} - sanitized_weights = model.sanitize(weights) - self.assertIn("lm_head.weight", sanitized_weights) - self.assertEqual(sanitized_weights["lm_head.weight"], "some_value") - - def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self): - from mlx_lm.models import starcoder2 - - args = starcoder2.ModelArgs( - model_type="starcoder2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - num_key_value_heads=4, - tie_word_embeddings=True, - ) - model = starcoder2.Model(args) - weights = { - "model.embed_tokens.weight": "some_value", - "lm_head.weight": "existing_value", - } - - sanitized_weights = model.sanitize(weights) - self.assertIn("lm_head.weight", sanitized_weights) - self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value") - def test_cohere(self): from mlx_lm.models import cohere diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py index 802dee57..e470aa81 100644 --- a/stable_diffusion/image2image.py +++ b/stable_diffusion/image2image.py @@ -4,6 +4,7 @@ import argparse import math import mlx.core as mx +import mlx.nn as nn import numpy as np from PIL import Image from tqdm import tqdm @@ -34,10 +35,18 @@ if __name__ == "__main__": # Load the models if args.model == "sdxl": sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16) + if args.quantize: - QuantizedLinear.quantize_module(sd.text_encoder_1) - 
diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py
index 802dee57..e470aa81 100644
--- a/stable_diffusion/image2image.py
+++ b/stable_diffusion/image2image.py
@@ -4,6 +4,7 @@ import argparse
 import math
 
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
@@ -34,10 +35,15 @@ if __name__ == "__main__":
     # Load the models
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
+
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -45,8 +51,10 @@
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50
 
diff --git a/stable_diffusion/requirements.txt b/stable_diffusion/requirements.txt
index ab85c726..d02baf7a 100644
--- a/stable_diffusion/requirements.txt
+++ b/stable_diffusion/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.6
+mlx>=0.11
 huggingface-hub
 regex
 numpy
diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py
index 1566bf6b..26c757f8 100644
--- a/stable_diffusion/txt2image.py
+++ b/stable_diffusion/txt2image.py
@@ -3,8 +3,8 @@
 import argparse
 
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
-from mlx.nn import QuantizedLinear
 from PIL import Image
 from tqdm import tqdm
 
@@ -34,9 +34,13 @@ if __name__ == "__main__":
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -44,8 +48,10 @@
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50
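Editor's note: in the stable diffusion scripts the predicate ignores the checkpoint entirely and simply restricts quantization to Linear layers, leaving embeddings, convolutions, and norms in full precision; the UNet keeps its separate group_size=32, bits=8 settings from the patch. A small sketch of that pattern, assuming a text encoder module named te (the helper name is illustrative):

import mlx.nn as nn

def quantize_text_encoder(te):
    # Quantize only the Linear layers of the encoder with default settings.
    nn.quantize(te, class_predicate=lambda _path, m: isinstance(m, nn.Linear))
    return te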
diff --git a/whisper/convert.py b/whisper/convert.py
index 824b0986..fd208184 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -254,7 +254,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))
 
     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)
 
     # Update the config:
     quantized_config["quantization"] = {
diff --git a/whisper/requirements.txt b/whisper/requirements.txt
index 62f55737..cf9c92aa 100644
--- a/whisper/requirements.txt
+++ b/whisper/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.8
+mlx>=0.11
 numba
 numpy
 torch
diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py
index e2e567a3..2b7efaf0 100644
--- a/whisper/whisper/load_models.py
+++ b/whisper/whisper/load_models.py
@@ -32,7 +32,7 @@ def load_model(
     model = whisper.Whisper(model_args, dtype)
 
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
 
     model.update(weights)
     mx.eval(model.parameters())
diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py
index f5cc3888..e691792c 100644
--- a/whisper/whisper/whisper.py
+++ b/whisper/whisper/whisper.py
@@ -196,7 +196,7 @@ class TextDecoder(nn.Module):
         )
         x = self.ln(x)
 
-        return x @ self.token_embedding.weight.T, kv_cache, cross_qk
+        return self.token_embedding.as_linear(x), kv_cache, cross_qk
 
 
 class Whisper(nn.Module):