Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
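The headline change is that quantization now runs through nn.quantize, which covers embeddings as well as linear layers. A minimal sketch of what that enables, assuming an MLX version where nn.Embedding supports to_quantized (which this commit relies on); the ToyLM class and its layer names are illustrative, not from this repo:

import mlx.core as mx
import mlx.nn as nn


class ToyLM(nn.Module):
    # Minimal stand-in for a language model: embedding + output projection.
    def __init__(self, vocab_size: int = 256, dims: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dims)
        self.lm_head = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        return self.lm_head(self.embed(x))


model = ToyLM()
nn.quantize(model, group_size=64, bits=4)

# Both the embedding and the linear layer are now quantized modules.
print(type(model.embed).__name__, type(model.lm_head).__name__)
print(model(mx.array([1, 2, 3])).shape)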
@@ -15,7 +15,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
 from mlx.utils import tree_flatten
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
 from .sample_utils import top_p_sampling
@@ -31,12 +31,6 @@ MODEL_REMAPPING = {
 
 MAX_FILE_SIZE_GB = 5
 
-linear_class_predicate = (
-    lambda m: isinstance(m, nn.Linear)
-    and m.weight.shape[0]
-    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-)
-
 
 def _get_classes(config: dict):
     """
@@ -188,14 +182,14 @@ def generate_step(
                 repetition_context = repetition_context[-repetition_context_size:]
         return y, prob
 
-    y, prob = _step(y)
+    y, p = _step(y)
 
+    mx.async_eval(y)
     while True:
-        sync = mx.async_eval(y)
-        next_out = _step(y)
-        sync.wait()
-        yield y.item(), prob
-        y, prob = next_out
+        next_y, next_p = _step(y)
+        mx.async_eval(next_y)
+        yield y.item(), p
+        y, p = next_y, next_p
 
 
 def generate(
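Note on the generate_step hunk above: the rewritten loop starts mx.async_eval on the next token's graph before yielding the current token, so evaluation overlaps with the caller's work. A self-contained sketch of the same pattern, using a toy step function in place of the real _step (toy_step and its arithmetic are purely illustrative):

import mlx.core as mx


def toy_step(y: mx.array) -> mx.array:
    # Stand-in for the real _step: lazily builds the "next token" graph.
    return (y * 3 + 1) % 17


def pipelined(y: mx.array):
    y = toy_step(y)
    mx.async_eval(y)           # kick off evaluation of the first result
    while True:
        next_y = toy_step(y)   # build the following step's graph right away...
        mx.async_eval(next_y)  # ...and start evaluating it in the background
        yield y.item()         # .item() only waits on an eval already in flight
        y = next_y


gen = pipelined(mx.array(2))
print([next(gen) for _ in range(5)])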
@@ -283,6 +277,16 @@ def generate(
     return detokenizer.text
 
 
+def load_config(model_path: Path) -> dict:
+    try:
+        with open(model_path / "config.json", "r") as f:
+            config = json.load(f)
+    except FileNotFoundError:
+        logging.error(f"Config file not found in {model_path}")
+        raise
+    return config
+
+
 def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     """
     Load and initialize the model from a given path.
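The load_config helper introduced above becomes the single place that reads config.json; load_model and fetch_from_hub now both call it. A standalone usage sketch (the "./mlx_model" path is a placeholder for a local model directory):

import json
import logging
from pathlib import Path


def load_config(model_path: Path) -> dict:
    # Same helper as introduced in the hunk above.
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config


# Hypothetical usage; "./mlx_model" stands in for a local model directory.
config = load_config(Path("./mlx_model"))
print(config.get("quantization", "not quantized"))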
@@ -300,13 +304,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         FileNotFoundError: If the weight files (.safetensors) are not found.
         ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-    try:
-        with open(model_path / "config.json", "r") as f:
-            config = json.load(f)
-            quantization = config.get("quantization", None)
-    except FileNotFoundError:
-        logging.error(f"Config file not found in {model_path}")
-        raise
+
+    config = load_config(model_path)
 
     weight_files = glob.glob(str(model_path / "*.safetensors"))
     if not weight_files:
@@ -325,26 +324,17 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
     if hasattr(model, "sanitize"):
         weights = model.sanitize(weights)
 
-    if quantization is not None:
-        # for legacy models that don't have lm_head quant due to non-32 dims
-        if "lm_head.scales" not in weights.keys():
-            vocab_size = config["vocab_size"]
-            extended_linear_class_predicate = (
-                lambda layer: linear_class_predicate(layer)
-                and layer.weight.shape[0] != vocab_size
-            )
-            nn.QuantizedLinear.quantize_module(
-                model,
-                **quantization,
-                linear_class_predicate=extended_linear_class_predicate,
-            )
-        # for models that have lm_head quant
-        else:
-            nn.QuantizedLinear.quantize_module(
-                model,
-                **quantization,
-                linear_class_predicate=linear_class_predicate,
-            )
+    if (quantization := config.get("quantization", None)) is not None:
+        # Handle legacy models which may not have everything quantized
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(
+            model,
+            **quantization,
+            class_predicate=class_predicate,
+        )
 
     model.load_weights(list(weights.items()))
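The load-path hunk above replaces the two nn.QuantizedLinear.quantize_module branches with a single nn.quantize call whose class_predicate quantizes a module only when the checkpoint actually ships its scales, which is how legacy checkpoints without a quantized lm_head keep loading. A toy sketch of that predicate in action (the two-layer Sequential and the fake weights dict are illustrative):

import mlx.core as mx
import mlx.nn as nn

# Toy two-layer model; in load_model this would be the full LLM.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

# Pretend the checkpoint only ships quantized scales for the first layer,
# the way a legacy checkpoint without a quantized lm_head would.
weights = {"layers.0.scales": mx.zeros((64, 1))}

# Same shape of predicate as in the hunk: quantize a module only when the
# checkpoint actually contains its scales.
class_predicate = (
    lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) and f"{p}.scales" in weights
)

nn.quantize(model, group_size=64, bits=4, class_predicate=class_predicate)

# Only layers.0 is replaced by a QuantizedLinear; layers.1 stays a plain Linear.
print([type(l).__name__ for l in model.layers])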
@@ -395,10 +385,9 @@ def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
-    config = AutoConfig.from_pretrained(model_path)
+    config = load_config(model_path)
     tokenizer = load_tokenizer(model_path)
 
-    return model, config.to_dict(), tokenizer
+    return model, config, tokenizer
 
 
 def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
@@ -543,10 +532,7 @@ def quantize_model(
         Tuple: Tuple containing quantized weights and config.
     """
     quantized_config = copy.deepcopy(config)
 
-    nn.QuantizedLinear.quantize_module(
-        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
-    )
-
+    nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
     quantized_weights = dict(tree_flatten(model.parameters()))
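quantize_model above now delegates to nn.quantize directly instead of nn.QuantizedLinear.quantize_module with the module-level linear_class_predicate deleted in an earlier hunk. A compact sketch of the new flow with a toy model and config (both stand-ins, not the real LLM objects):

import copy

import mlx.nn as nn
from mlx.utils import tree_flatten

# Stand-ins for the real arguments; quantize_model receives the loaded model
# and the dict from config.json.
model = nn.Sequential(nn.Linear(128, 128))
config = {"model_type": "toy"}
q_group_size, q_bits = 64, 4

quantized_config = copy.deepcopy(config)
nn.quantize(model, q_group_size, q_bits)  # single call replaces quantize_module
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
quantized_weights = dict(tree_flatten(model.parameters()))

# The flattened parameters now carry scales/biases next to the packed weights.
print(sorted(quantized_weights), quantized_config["quantization"])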