Partial implementation of SDXL, including CLIP with projection; it does not yet produce good output

Pawel Kowalski
2023-12-19 13:21:07 +01:00
parent e091fd2abc
commit c63870d3cb
3 changed files with 76 additions and 18 deletions

View File

@@ -68,3 +68,13 @@ class CLIPTextModel(nn.Module):
# Apply the final layernorm and return
return self.final_layer_norm(x)
class CLIPTextModelWithProjection(CLIPTextModel):
def __init__(self, config: CLIPTextModelConfig):
super().__init__(config)
self.projection = nn.Linear(config.model_dims, config.projection_dims)
def __call__(self, x):
x = super().__call__(x)
return self.projection(x)
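
For orientation, here is a minimal, hypothetical usage sketch of the new class; the config values and token ids are illustrative only and are not part of this change. Note that, as written, the projection is applied to every position of the final hidden state rather than to a single pooled embedding.

import mlx.core as mx

# Illustrative configuration; real values come from the Hugging Face config file.
config = CLIPTextModelConfig(
    num_layers=32,
    model_dims=1280,
    num_heads=20,
    max_length=77,
    vocab_size=49408,
    projection_dims=1280,
)
model = CLIPTextModelWithProjection(config)

# Example token ids (BOS ... EOS); in practice these come from the tokenizer.
tokens = mx.array([[49406, 320, 1125, 49407]])
projected = model(tokens)  # shape: (1, 4, projection_dims)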

View File

@@ -23,6 +23,7 @@ class CLIPTextModelConfig:
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
projection_dims: int = 512
@dataclass

View File

@@ -12,7 +12,7 @@ from safetensors import safe_open as safetensor_open
import mlx.core as mx
from mlx.utils import tree_unflatten
from .clip import CLIPTextModel
from .clip import CLIPTextModel, CLIPTextModelWithProjection
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
from .tokenizer import Tokenizer
from .unet import UNetModel
@@ -31,7 +31,39 @@ _MODELS = {
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
}
},
# See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
"stabilityai/sdxl-turbo": {
"unet_config": "unet/config.json",
"unet": "unet/diffusion_pytorch_model.safetensors",
"text_encoder_config": "text_encoder/config.json",
"text_encoder_2_config": "text_encoder_2/config.json",
"text_encoder": "text_encoder/model.safetensors",
"text_encoder_2": "text_encoder_2/model.safetensors",
"vae_config": "vae/config.json",
"vae": "vae/diffusion_pytorch_model.safetensors",
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
"tokenizer_2_vocab": "tokenizer_2/vocab.json",
"tokenizer_2_merges": "tokenizer_2/merges.txt",
},
# See https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 for the model details and license
"stabilityai/stable-diffusion-xl-base-1.0": {
"unet_config": "unet/config.json",
"unet": "unet/diffusion_pytorch_model.safetensors",
"text_encoder_config": "text_encoder/config.json",
"text_encoder_2_config": "text_encoder_2/config.json",
"text_encoder": "text_encoder/model.safetensors",
"text_encoder_2": "text_encoder_2/model.safetensors",
"vae_config": "vae/config.json",
"vae": "vae/diffusion_pytorch_model.safetensors",
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
"tokenizer_2_vocab": "tokenizer_2/vocab.json",
"tokenizer_2_merges": "tokenizer_2/merges.txt",
},
}
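
Because the second encoder and tokenizer are keyed with a _2 suffix, the loaders changed below can be reused for both by passing the file-prefix name. A hedged usage sketch (repository key from the table above, function signatures from the hunks below):

key = "stabilityai/sdxl-turbo"
text_encoder = load_text_encoder(key)                           # first CLIP text encoder
text_encoder_2 = load_text_encoder(key, name="text_encoder_2")  # encoder with projection
tokenizer = load_tokenizer(key)
tokenizer_2 = load_tokenizer(key, name="tokenizer_2")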
@@ -205,27 +237,42 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
return model
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
def load_text_encoder(
key: str = _DEFAULT_MODEL, float16: bool = False, name="text_encoder"
):
"""Load the stable diffusion text encoder from Hugging Face Hub."""
_check_key(key, "load_text_encoder")
_check_key(key, f"load_{name}")
# Download the config and create the model
text_encoder_config = _MODELS[key]["text_encoder_config"]
text_encoder_config = _MODELS[key][f"{name}_config"]
with open(hf_hub_download(key, text_encoder_config)) as f:
config = json.load(f)
model = CLIPTextModel(
CLIPTextModelConfig(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
# Check whether the text encoder has a projection layer
if "CLIPTextModelWithProjection" in config["architectures"]:
model = CLIPTextModelWithProjection(
CLIPTextModelConfig(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
projection_dims=config["projection_dim"],
)
)
else:
model = CLIPTextModel(
CLIPTextModelConfig(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
)
)
)
# Download the weights and map them into the model
text_encoder_weights = _MODELS[key]["text_encoder"]
text_encoder_weights = _MODELS[key][name]
weight_file = hf_hub_download(key, text_encoder_weights)
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
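
The branch above dispatches on the architectures field of the downloaded config.json. For SDXL, the second text encoder reports CLIPTextModelWithProjection while the first reports CLIPTextModel; an abridged sketch of the former is below, with values as published on the Hub and shown here only for illustration.

# Abridged text_encoder_2/config.json for SDXL (illustrative; the loader
# reads the real file via hf_hub_download).
sdxl_text_encoder_2_config = {
    "architectures": ["CLIPTextModelWithProjection"],
    "num_hidden_layers": 32,
    "hidden_size": 1280,
    "num_attention_heads": 20,
    "max_position_embeddings": 77,
    "vocab_size": 49408,
    "projection_dim": 1280,
}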
@@ -277,14 +324,14 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
)
def load_tokenizer(key: str = _DEFAULT_MODEL):
_check_key(key, "load_tokenizer")
def load_tokenizer(key: str = _DEFAULT_MODEL, name="tokenizer"):
_check_key(key, f"load_{name}")
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
vocab_file = hf_hub_download(key, _MODELS[key][f"{name}_vocab"])
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
merges_file = hf_hub_download(key, _MODELS[key][f"{name}_merges"])
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
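
Putting it together, a hedged end-to-end sketch for encoding one prompt with both SDXL encoders. The loader calls come from this diff; the tokenize method and the output shapes are assumptions about the surrounding package, and how the two embeddings are combined for conditioning is left open here.

import mlx.core as mx

key = "stabilityai/sdxl-turbo"
tokenizer = load_tokenizer(key)
tokenizer_2 = load_tokenizer(key, name="tokenizer_2")
encoder = load_text_encoder(key)
encoder_2 = load_text_encoder(key, name="text_encoder_2")

prompt = "a photo of an astronaut riding a horse"
tokens = mx.array([tokenizer.tokenize(prompt)])      # tokenize() assumed from .tokenizer
tokens_2 = mx.array([tokenizer_2.tokenize(prompt)])
hidden = encoder(tokens)          # (1, seq_len, hidden_size) from the first encoder
projected = encoder_2(tokens_2)   # (1, seq_len, projection_dims) from the second encoder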