mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
Quantize embedding / Update quantize API (#680)
* more async eval
* quantize embedding / update quantize api
* more updates for quantize
* update for quantize embeddings
* update sd quant API
* update sdxl quants
* error for datasets < batch_size
* async
* fix config loading
* fix quant
* fix tests
* fix req
* remove lm head if tie weights is true
* fix test
This commit is contained in:
@@ -4,6 +4,7 @@ import argparse
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
@@ -34,10 +35,18 @@ if __name__ == "__main__":
|
||||
# Load the models
|
||||
if args.model == "sdxl":
|
||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder_1)
|
||||
QuantizedLinear.quantize_module(sd.text_encoder_2)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
nn.quantize(
|
||||
sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
nn.quantize(
|
||||
sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
|
||||
nn.quantize(sd.text_encoder_1)
|
||||
nn.quantize(sd.text_encoder_2)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 0.0
|
||||
args.steps = args.steps or 2
|
||||
else:
|
||||
@@ -45,8 +54,10 @@ if __name__ == "__main__":
|
||||
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
|
||||
)
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
nn.quantize(
|
||||
sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 7.5
|
||||
args.steps = args.steps or 50
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
mlx>=0.6
|
||||
mlx>=0.11
|
||||
huggingface-hub
|
||||
regex
|
||||
numpy
|
||||
|
@@ -3,8 +3,8 @@
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn import QuantizedLinear
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -34,9 +34,13 @@ if __name__ == "__main__":
|
||||
if args.model == "sdxl":
|
||||
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder_1)
|
||||
QuantizedLinear.quantize_module(sd.text_encoder_2)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
nn.quantize(
|
||||
sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
nn.quantize(
|
||||
sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 0.0
|
||||
args.steps = args.steps or 2
|
||||
else:
|
||||
@@ -44,8 +48,10 @@ if __name__ == "__main__":
|
||||
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
|
||||
)
|
||||
if args.quantize:
|
||||
QuantizedLinear.quantize_module(sd.text_encoder)
|
||||
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
|
||||
nn.quantize(
|
||||
sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
|
||||
)
|
||||
nn.quantize(sd.unet, group_size=32, bits=8)
|
||||
args.cfg = args.cfg or 7.5
|
||||
args.steps = args.steps or 50
|
||||
|
||||
|
Reference in New Issue
Block a user