Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
This commit is contained in:
Awni Hannun
2024-04-18 18:16:10 -07:00
committed by GitHub
parent f5f189e48a
commit 2146bcd7ee
28 changed files with 108 additions and 190 deletions

View File

@@ -254,7 +254,7 @@ def quantize(weights, config, args):
model.update(tree_unflatten(list(weights.items())))
# Quantize the model:
-    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
+    nn.quantize(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {

View File

@@ -1,4 +1,4 @@
-mlx>=0.8
+mlx>=0.11
numba
numpy
torch

View File

@@ -32,7 +32,7 @@ def load_model(
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
-        nn.QuantizedLinear.quantize_module(model, **quantization)
+        nn.quantize(model, **quantization)
model.update(weights)
mx.eval(model.parameters())

View File

@@ -196,7 +196,7 @@ class TextDecoder(nn.Module):
)
x = self.ln(x)
-        return x @ self.token_embedding.weight.T, kv_cache, cross_qk
+        return self.token_embedding.as_linear(x), kv_cache, cross_qk
class Whisper(nn.Module):