Quantize embedding / Update quantize API (#680)

* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
Author: Awni Hannun
Date: 2024-04-18 18:16:10 -07:00 (committed by GitHub)
Parent: f5f189e48a
Commit: 2146bcd7ee
28 changed files with 108 additions and 190 deletions
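The diffs that follow swap the old QuantizedLinear.quantize_module class method for the nn.quantize function added in mlx 0.11, which can also quantize embeddings. As a rough, hedged sketch of the new call pattern (TinyModel is a hypothetical module, not code from this commit):

import mlx.nn as nn


class TinyModel(nn.Module):
    """Hypothetical module used only to illustrate nn.quantize."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 64)
        self.proj = nn.Linear(64, 64)

    def __call__(self, x):
        return self.proj(self.embed(x))


model = TinyModel()

# Old API (removed throughout this commit):
#     QuantizedLinear.quantize_module(model, group_size=32, bits=8)
# New API: quantizes supported layers in place, embeddings included by default.
nn.quantize(model, group_size=32, bits=8)

print(type(model.embed).__name__)  # QuantizedEmbedding (expected with mlx >= 0.11)
print(type(model.proj).__name__)   # QuantizedLinear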

@@ -4,6 +4,7 @@ import argparse
 import math
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
@@ -34,10 +35,18 @@ if __name__ == "__main__":
     # Load the models
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -45,8 +54,10 @@ if __name__ == "__main__":
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50
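The class_predicate passed above is what keeps the text-encoder embeddings un-quantized: nn.quantize calls it with the module path and the module itself, and only converts modules for which it returns True. A minimal sketch of the pattern, assuming mlx >= 0.11 and using a hypothetical TextEncoderStub that is not part of this repo:

import mlx.nn as nn


class TextEncoderStub(nn.Module):
    """Hypothetical stand-in for the text encoders quantized above."""

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(1000, 64)
        self.out_proj = nn.Linear(64, 64)


encoder = TextEncoderStub()

# Only nn.Linear modules are converted; the token embedding keeps full precision,
# mirroring the class_predicate used for the text encoders in the diff above.
nn.quantize(encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear))

print(type(encoder.token_embedding).__name__)  # Embedding (left as-is)
print(type(encoder.out_proj).__name__)         # QuantizedLinear (expected)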

@@ -1,4 +1,4 @@
-mlx>=0.6
+mlx>=0.11
 huggingface-hub
 regex
 numpy
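The minimum mlx version moves from 0.6 to 0.11 because nn.quantize and quantized embeddings are only available from 0.11 on. A quick, hedged way to confirm the installed version before running the scripts in this commit:

from importlib.metadata import version

# Assumes mlx was installed with pip; the scripts in this commit expect >= 0.11.
print(version("mlx"))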

@@ -3,8 +3,8 @@
 import argparse
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
-from mlx.nn import QuantizedLinear
 from PIL import Image
 from tqdm import tqdm
@@ -34,9 +34,13 @@ if __name__ == "__main__":
     if args.model == "sdxl":
         sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder_1)
-            QuantizedLinear.quantize_module(sd.text_encoder_2)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
@@ -44,8 +48,10 @@ if __name__ == "__main__":
             "stabilityai/stable-diffusion-2-1-base", float16=args.float16
         )
         if args.quantize:
-            QuantizedLinear.quantize_module(sd.text_encoder)
-            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+            nn.quantize(
+                sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(sd.unet, group_size=32, bits=8)
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50
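The "remove lm head if tie weights is true" item in the commit message is the embedding side of this API change: once nn.quantize converts the token embedding, a tied-weight model can project back to the vocabulary through that same (quantized) embedding instead of keeping a separate lm_head. A hedged sketch of the idea, assuming nn.Embedding / nn.QuantizedEmbedding expose as_linear as in recent mlx releases (TiedLM is hypothetical, not code from this commit):

import mlx.core as mx
import mlx.nn as nn


class TiedLM(nn.Module):
    """Hypothetical tied-embedding language model head, for illustration only."""

    def __init__(self, vocab_size=256, dims=64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dims)
        # No separate lm_head: the embedding doubles as the output projection.

    def __call__(self, tokens):
        h = self.embed_tokens(tokens)
        # as_linear applies the (possibly quantized) embedding weights as a
        # projection back to vocab size, replacing the removed lm_head.
        return self.embed_tokens.as_linear(h)


model = TiedLM()
nn.quantize(model, group_size=32, bits=4)  # embed_tokens becomes a quantized layer
logits = model(mx.array([[1, 2, 3]]))
print(logits.shape)  # (1, 3, 256) expected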