mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
This commit is contained in:
@@ -60,13 +60,10 @@ def quantize(weights, config, args):
     model.update(all_weights)

     # Quantize the model:
-    nn.QuantizedLinear.quantize_module(
+    nn.quantize(
         model,
         args.q_group_size,
         args.q_bits,
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-        and m.weight.shape[0] != 8,
     )

     # Extract the subset of quantized weights:
@@ -217,11 +217,7 @@ def load_model(folder: str):
     weights = tree_unflatten(list(weights.items()))
     model = Mixtral(model_args)
     if quantization is not None:
-        # TODO: Quantize gate matrices when < 32 tiles supported
-        quantization["linear_class_predicate"] = (
-            lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
-        )
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)

     model.update(weights)
     return model, tokenizer
@@ -1,4 +1,4 @@
-mlx>=0.8.0
+mlx>=0.11.0
 sentencepiece
 torch
 numpy
Reference in New Issue
Block a user