mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
This commit is contained in:
@@ -254,7 +254,7 @@ def quantize(weights, config, args):
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
|
||||
# Quantize the model:
|
||||
- nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+ nn.quantize(model, args.q_group_size, args.q_bits)
|
||||
|
||||
# Update the config:
|
||||
quantized_config["quantization"] = {
|
||||
|
@@ -1,4 +1,4 @@
|
||||
- mlx>=0.8
+ mlx>=0.11
|
||||
numba
|
||||
numpy
|
||||
torch
|
||||
|
@@ -32,7 +32,7 @@ def load_model(
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
if quantization is not None:
|
||||
- nn.QuantizedLinear.quantize_module(model, **quantization)
+ nn.quantize(model, **quantization)
|
||||
|
||||
model.update(weights)
|
||||
mx.eval(model.parameters())
|
||||
|
@@ -196,7 +196,7 @@ class TextDecoder(nn.Module):
|
||||
)
|
||||
|
||||
x = self.ln(x)
|
||||
- return x @ self.token_embedding.weight.T, kv_cache, cross_qk
+ return self.token_embedding.as_linear(x), kv_cache, cross_qk
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
|
Reference in New Issue
Block a user