Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
Awni Hannun 2024-04-18 18:16:10 -07:00 committed by GitHub
parent f5f189e48a
commit 2146bcd7ee
28 changed files with 108 additions and 190 deletions

View File

@@ -280,24 +280,16 @@ def load(gguf_file: str, repo: str = None):
     config = get_config(metadata)
     model = Model(ModelArgs(**config))
     if quantization is not None:
-        # quantized the LM head?
-        qm = model if "lm_head.scales" in weights else model.model
-        nn.QuantizedLinear.quantize_module(
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
             qm,
             **quantization,
+            class_predicate=class_predicate,
         )
-
-        def dequantize(k):
-            weight = weights.pop(f"{k}.weight")
-            scales = weights.pop(f"{k}.scales")
-            biases = weights.pop(f"{k}.biases")
-            weights[f"{k}.weight"] = mx.dequantize(
-                weight, scales=scales, biases=biases, **quantization
-            )
-
-        # Dequantize embeddings
-        dequantize("model.embed_tokens")
     tokenizer = GGUFTokenizer(metadata)
     model.load_weights(list(weights.items()))
     return model, tokenizer
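The hunk above is the heart of the new API: nn.quantize replaces nn.QuantizedLinear.quantize_module, covers nn.Embedding as well as nn.Linear, and takes a class_predicate(path, module) callback so only layers whose scales actually exist in the checkpoint are swapped for quantized modules. A minimal sketch of that loading path, assuming a hypothetical quantize_from_checkpoint helper, a flat weights dict of pre-quantized tensors, and a quantization dict holding the saved group_size/bits:

import mlx.nn as nn

def quantize_from_checkpoint(model, weights, quantization):
    # Swap in quantized layers only where the checkpoint saved scales; anything
    # the converter left in full precision (e.g. an untied lm_head) is skipped.
    class_predicate = (
        lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
        and f"{p}.scales" in weights
    )
    nn.quantize(model, **quantization, class_predicate=class_predicate)
    model.load_weights(list(weights.items()))
    return model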

View File

@@ -134,7 +134,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)

     # Update the config:
     quantized_config["quantization"] = {

View File

@@ -339,7 +339,7 @@ def load_model(model_path):
     quantization = config.pop("quantization", None)
     model = Llama(ModelArgs(**config))
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(tree_unflatten(list(weights.items())))
     tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
     return model, tokenizer

View File

@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy

View File

@@ -24,7 +24,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)

     # Update the config:
     quantized_config["quantization"] = {

View File

@@ -183,7 +183,7 @@ def load_model(folder: str):
     weights = tree_unflatten(list(weights.items()))
     model = Mistral(model_args)
     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(weights)
     mx.eval(model.parameters())
     return model, tokenizer

View File

@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy

View File

@@ -60,13 +60,10 @@ def quantize(weights, config, args):
     model.update(all_weights)

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(
+    nn.quantize(
         model,
         args.q_group_size,
         args.q_bits,
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != 8,
     )

     # Extract the subset of quantized weights:

View File

@@ -217,11 +217,7 @@ def load_model(folder: str):
     weights = tree_unflatten(list(weights.items()))
     model = Mixtral(model_args)
     if quantization is not None:
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        quantization["linear_class_predicate"] = (
-            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
-        )
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
     model.update(weights)
     return model, tokenizer

View File

@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy

View File

@@ -185,7 +185,7 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        out = out @ self.model.embed_tokens.weight.T
+        out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out, cache
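nn.Embedding.as_linear applies the embedding matrix as the output projection. For a float embedding it is equivalent to out @ embed_tokens.weight.T, but it keeps working once the embedding has been quantized, since a quantized embedding stores packed weights and scales rather than a plain weight to transpose. A small illustration with made-up sizes:

import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(1000, 64)        # hypothetical vocab and hidden sizes
h = mx.random.normal((2, 8, 64))    # (batch, sequence, hidden) activations

logits = emb.as_linear(h)           # same result as h @ emb.weight.T here
print(logits.shape)                 # (2, 8, 1000)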

View File

@@ -169,7 +169,7 @@ class Model(nn.Module):
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        out = out @ self.model.embed_tokens.weight.T
+        out = self.model.embed_tokens.as_linear(out)
         return out, cache

     @property

View File

@@ -142,7 +142,7 @@ class Transformer(nn.Module):
         h = self.norm(h)

         if self.weight_tying:
-            return h @ self.wte.weight.T, cache
+            return self.wte.as_linear(h), cache

         return self.ff_out(h), cache

View File

@@ -172,7 +172,8 @@ class Model(nn.Module):
         self.args = args
         self.model_type = args.model_type
         self.model = Qwen2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
         self,
@@ -180,11 +181,15 @@
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out, cache

     def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
         # Remove unused precomputed rotary freqs
         return {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
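With tied embeddings the module no longer creates an lm_head at all, so sanitize now drops any stray lm_head.weight a converted checkpoint might still carry instead of copying the embedding into it; otherwise load_weights would be handed a parameter the model does not have. A hedged sketch of that path with placeholder shapes:

import mlx.core as mx

tie_word_embeddings = True  # stands in for args.tie_word_embeddings

# Hypothetical checkpoint contents for a tied-embedding export.
weights = {
    "model.embed_tokens.weight": mx.zeros((1000, 64)),
    "lm_head.weight": mx.zeros((1000, 64)),  # redundant copy some exports still ship
}

if tie_word_embeddings:
    weights.pop("lm_head.weight", None)  # what Model.sanitize now does

assert "lm_head.weight" not in weights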

View File

@@ -149,7 +149,8 @@ class Model(nn.Module):
         self.args = args
         self.model_type = args.model_type
         self.model = Starcoder2Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(
         self,
@@ -157,12 +158,11 @@
         cache=None,
     ):
         out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
-
-    def sanitize(self, weights):
-        if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
-            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
-        return weights
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out, cache

     @property
     def layers(self):

View File

@@ -1,4 +1,4 @@
-mlx>=0.10
+mlx>=0.11
 numpy
 transformers>=4.39.3
 protobuf

View File

@@ -74,6 +74,7 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
     def __init__(self, tokenizer):
         self._tokenizer = tokenizer
+        self._tokenizer.decode([0])
         self.reset()

     def reset(self):

View File

@@ -79,6 +79,11 @@ def default_loss(model, inputs, targets, lengths):
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+    if len(dataset) < batch_size:
+        raise ValueError(
+            f"Dataset must have at least batch_size={batch_size}"
+            f" examples but only has {len(dataset)}."
+        )

     # Make the batches:
     batch_idx = [

View File

@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer

 # Local imports
 from .sample_utils import top_p_sampling
@@ -31,12 +31,6 @@ MODEL_REMAPPING = {
 MAX_FILE_SIZE_GB = 5

-linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear)
-    and m.weight.shape[0]
-    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-)

 def _get_classes(config: dict):
     """
@@ -188,14 +182,14 @@ def generate_step(
             repetition_context = repetition_context[-repetition_context_size:]
         return y, prob

-    y, prob = _step(y)
+    y, p = _step(y)

+    mx.async_eval(y)
     while True:
-        sync = mx.async_eval(y)
-        next_out = _step(y)
-        sync.wait()
-        yield y.item(), prob
-        y, prob = next_out
+        next_y, next_p = _step(y)
+        mx.async_eval(next_y)
+        yield y.item(), p
+        y, p = next_y, next_p


 def generate(
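The reworked decode loop overlaps computation with evaluation: the graph for the next token is built and launched with mx.async_eval before the current token is materialized by .item(). A condensed sketch of the same pattern (the wrapper name is illustrative), assuming a _step function that returns the next token and its probability as lazy arrays:

import mlx.core as mx

def pipelined_decode(_step, y):
    y, p = _step(y)
    mx.async_eval(y)               # start evaluating the first token in the background
    while True:
        next_y, next_p = _step(y)  # build the graph for the following token
        mx.async_eval(next_y)      # launch its evaluation asynchronously
        yield y.item(), p          # .item() only waits on the eval already in flight
        y, p = next_y, next_p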
@@ -283,6 +277,16 @@ def generate(
     return detokenizer.text


+def load_config(model_path: Path) -> dict:
+    try:
+        with open(model_path / "config.json", "r") as f:
+            config = json.load(f)
+    except FileNotFoundError:
+        logging.error(f"Config file not found in {model_path}")
+        raise
+    return config
+
+
 def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
     Load and initialize the model from a given path.
@@ -300,13 +304,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         FileNotFoundError: If the weight files (.safetensors) are not found.
         ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-    try:
-        with open(model_path / "config.json", "r") as f:
-            config = json.load(f)
-            quantization = config.get("quantization", None)
-    except FileNotFoundError:
-        logging.error(f"Config file not found in {model_path}")
-        raise
+    config = load_config(model_path)

     weight_files = glob.glob(str(model_path / "*.safetensors"))
     if not weight_files:
@@ -325,26 +324,17 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     if hasattr(model, "sanitize"):
         weights = model.sanitize(weights)

-    if quantization is not None:
-        # for legacy models that don't have lm_head quant due to non-32 dims
-        if "lm_head.scales" not in weights.keys():
-            vocab_size = config["vocab_size"]
-            extended_linear_class_predicate = (
-                lambda layer: linear_class_predicate(layer)
-                and layer.weight.shape[0] != vocab_size
-            )
-            nn.QuantizedLinear.quantize_module(
-                model,
-                **quantization,
-                linear_class_predicate=extended_linear_class_predicate,
-            )
-        # for models that have lm_head quant
-        else:
-            nn.QuantizedLinear.quantize_module(
-                model,
-                **quantization,
-                linear_class_predicate=linear_class_predicate,
-            )
+    if (quantization := config.get("quantization", None)) is not None:
+        # Handle legacy models which may not have everything quantized
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
+            model,
+            **quantization,
+            class_predicate=class_predicate,
+        )

     model.load_weights(list(weights.items()))
@@ -395,10 +385,9 @@ def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
-    config = AutoConfig.from_pretrained(model_path)
+    config = load_config(model_path)
     tokenizer = load_tokenizer(model_path)
-
-    return model, config.to_dict(), tokenizer
+    return model, config, tokenizer


 def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
@@ -543,10 +532,7 @@ def quantize_model(
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
-
-    nn.QuantizedLinear.quantize_module(
-        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
-    )
+    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
     quantized_weights = dict(tree_flatten(model.parameters()))
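On the conversion side, quantize_model now quantizes every supported layer, embeddings included, with one nn.quantize call and records the settings in the config; the predicate-based loading path above reads exactly that record back. Roughly, mirroring the surrounding diff:

import copy

import mlx.nn as nn
from mlx.utils import tree_flatten

def quantize_model(model, config, q_group_size, q_bits):
    quantized_config = copy.deepcopy(config)
    nn.quantize(model, q_group_size, q_bits)  # no per-layer predicate needed anymore
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
    quantized_weights = dict(tree_flatten(model.parameters()))
    return quantized_weights, quantized_config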

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.9.0"
+__version__ = "0.10.0"

View File

@@ -152,47 +152,6 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

-    def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
-        from mlx_lm.models import qwen2
-
-        args = qwen2.ModelArgs(
-            model_type="qwen2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            rms_norm_eps=1e-5,
-            vocab_size=10_000,
-            tie_word_embeddings=True,
-        )
-        model = qwen2.Model(args)
-
-        weights = {"model.embed_tokens.weight": "some_value"}
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
-
-    def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
-        from mlx_lm.models import qwen2
-
-        weights = {
-            "model.embed_tokens.weight": "some_value",
-            "lm_head.weight": "existing_value",
-        }
-        args = qwen2.ModelArgs(
-            model_type="qwen2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            rms_norm_eps=1e-5,
-            vocab_size=10_000,
-            tie_word_embeddings=True,
-        )
-        model = qwen2.Model(args)
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
-
     def test_qwen(self):
         from mlx_lm.models import qwen
@@ -277,46 +236,6 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

-    def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
-        from mlx_lm.models import starcoder2
-
-        args = starcoder2.ModelArgs(
-            model_type="starcoder2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            num_key_value_heads=4,
-            tie_word_embeddings=True,
-        )
-        model = starcoder2.Model(args)
-
-        weights = {"model.embed_tokens.weight": "some_value"}
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")
-
-    def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
-        from mlx_lm.models import starcoder2
-
-        args = starcoder2.ModelArgs(
-            model_type="starcoder2",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            num_key_value_heads=4,
-            tie_word_embeddings=True,
-        )
-        model = starcoder2.Model(args)
-        weights = {
-            "model.embed_tokens.weight": "some_value",
-            "lm_head.weight": "existing_value",
-        }
-        sanitized_weights = model.sanitize(weights)
-        self.assertIn("lm_head.weight", sanitized_weights)
-        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")
-
     def test_cohere(self):
         from mlx_lm.models import cohere

View File

@@ -4,6 +4,7 @@ import argparse
 import math

 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
@@ -34,10 +35,18 @@ if __name__ == "__main__":
     # Load the models
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.text_encoder_1)
+            nn.quantize(sd.text_encoder_2)
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -45,8 +54,10 @@
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50

View File

@@ -1,4 +1,4 @@
-mlx>=0.6
+mlx>=0.11
 huggingface-hub
 regex
 numpy

View File

@@ -3,8 +3,8 @@
 import argparse

 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
-from mlx.nn import QuantizedLinear
 from PIL import Image
 from tqdm import tqdm
@@ -34,9 +34,13 @@
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -44,8 +48,10 @@
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50

View File

@@ -254,7 +254,7 @@ def quantize(weights, config, args):
     model.update(tree_unflatten(list(weights.items())))

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)

     # Update the config:
     quantized_config["quantization"] = {

View File

@@ -1,4 +1,4 @@
-mlx>=0.8
+mlx>=0.11
 numba
 numpy
 torch

View File

@@ -32,7 +32,7 @@ def load_model(
     model = whisper.Whisper(model_args, dtype)

     if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)

     model.update(weights)
     mx.eval(model.parameters())

View File

@@ -196,7 +196,7 @@ class TextDecoder(nn.Module):
         )
         x = self.ln(x)

-        return x @ self.token_embedding.weight.T, kv_cache, cross_qk
+        return self.token_embedding.as_linear(x), kv_cache, cross_qk


 class Whisper(nn.Module):