Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
This commit is contained in:
Awni Hannun
2024-04-18 18:16:10 -07:00
committed by GitHub
parent f5f189e48a
commit 2146bcd7ee
28 changed files with 108 additions and 190 deletions

View File

@@ -15,7 +15,7 @@ import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer
# Local imports
from .sample_utils import top_p_sampling
@@ -31,12 +31,6 @@ MODEL_REMAPPING = {
MAX_FILE_SIZE_GB = 5
linear_class_predicate = (
lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0]
!= 8 # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
)
def _get_classes(config: dict):
"""
@@ -188,14 +182,14 @@ def generate_step(
repetition_context = repetition_context[-repetition_context_size:]
return y, prob
y, prob = _step(y)
y, p = _step(y)
mx.async_eval(y)
while True:
sync = mx.async_eval(y)
next_out = _step(y)
sync.wait()
yield y.item(), prob
y, prob = next_out
next_y, next_p = _step(y)
mx.async_eval(next_y)
yield y.item(), p
y, p = next_y, next_p
def generate(
@@ -283,6 +277,16 @@ def generate(
return detokenizer.text
def load_config(model_path: Path) -> dict:
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
return config
def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
"""
Load and initialize the model from a given path.
@@ -300,13 +304,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
quantization = config.get("quantization", None)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
config = load_config(model_path)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
@@ -325,26 +324,17 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
if quantization is not None:
# for legacy models that don't have lm_head quant due to non-32 dims
if "lm_head.scales" not in weights.keys():
vocab_size = config["vocab_size"]
extended_linear_class_predicate = (
lambda layer: linear_class_predicate(layer)
and layer.weight.shape[0] != vocab_size
)
nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=extended_linear_class_predicate,
)
# for models that have lm_head quant
else:
nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=linear_class_predicate,
)
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(
model,
**quantization,
class_predicate=class_predicate,
)
model.load_weights(list(weights.items()))
@@ -395,10 +385,9 @@ def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy)
config = AutoConfig.from_pretrained(model_path)
config = load_config(model_path)
tokenizer = load_tokenizer(model_path)
return model, config.to_dict(), tokenizer
return model, config, tokenizer
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
@@ -543,10 +532,7 @@ def quantize_model(
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
nn.QuantizedLinear.quantize_module(
model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
quantized_weights = dict(tree_flatten(model.parameters()))