Stable Diffusion XL working; the "add_embedding" layer is not implemented

This commit is contained in:
Pawel Kowalski
2023-12-20 17:25:05 +01:00
parent fe2291710f
commit dac547367d
6 changed files with 37 additions and 10 deletions

View File

@@ -136,9 +136,9 @@ class StableDiffusionXL(StableDiffusion):
)
if cfg_weight > 1:
tokens2 += (
[self.tokenizer.tokenize(negative_text_2)]
[self.tokenizer_2.tokenize(negative_text_2)]
if text_2
else [self.tokenizer.tokenize(negative_text)]
else [self.tokenizer_2.tokenize(negative_text)]
)
lengths2 = [len(t) for t in tokens2]
N = max(lengths2)

View File

@@ -73,8 +73,8 @@ class CLIPTextModel(nn.Module):
class CLIPTextModelWithProjection(CLIPTextModel):
    """CLIP text model whose output is passed through a linear projection.

    The parent ``CLIPTextModel`` is evaluated first and its result is fed to a
    single ``nn.Linear`` layer mapping ``config.model_dims`` features to
    ``config.projection_dims`` features.
    """

    def __init__(self, config: CLIPTextModelConfig):
        """Build the parent text model and the projection layer.

        Args:
            config: Text-model configuration; ``model_dims`` and
                ``projection_dims`` size the projection layer.
        """
        super().__init__(config)
        # Only one projection layer is kept, under the name `text_projection`.
        # NOTE(review): presumably this matches the parameter name used by the
        # checkpoint being loaded -- confirm against the weight mapper.
        self.text_projection = nn.Linear(config.model_dims, config.projection_dims)

    def __call__(self, x):
        """Run the parent text model on ``x`` and project its output."""
        x = super().__call__(x)
        return self.text_projection(x)

View File

@@ -38,7 +38,21 @@ class UNetConfig:
# Per-block settings: one entry for each UNet resolution block.
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
# A plain int, as annotated; a stray trailing comma had turned the default
# into the one-element tuple (32,).
norm_num_groups: int = 32
# Flat tuple of block type names, indexed per down block when the UNet is
# built; a stray trailing comma had wrapped it in an extra one-element tuple,
# so indexing returned the inner tuple instead of a name.
down_block_types: Tuple[str] = (
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D",
)
# Flat tuple of block type names for the up path.
up_block_types: Tuple[str] = (
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
)
@dataclass

View File

@@ -10,7 +10,7 @@ from mlx.utils import tree_unflatten
from safetensors import safe_open as safetensor_open
import mlx.core as mx
from mlx.utils import tree_unflatten
from mlx.utils import tree_unflatten, tree_flatten
from .clip import CLIPTextModel, CLIPTextModelWithProjection
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
@@ -195,6 +195,12 @@ def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
dtype = np.float16 if float16 else np.float32
with safetensor_open(weight_file, framework="numpy") as f:
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
# debug
bar = tree_flatten(model)
missing_weights = [w[0] for w in weights if w[0] not in [b[0] for b in bar]]
if missing_weights:
print("warning: missing weights")
print(missing_weights)
model.update(tree_unflatten(weights))
@@ -226,6 +232,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
else config["attention_head_dim"],
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
down_block_types=config["down_block_types"],
up_block_types=config["up_block_types"],
transformer_layers_per_block=config["transformer_layers_per_block"],
)
)

View File

@@ -308,7 +308,8 @@ class UNetModel(nn.Module):
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention=(i < len(config.block_out_channels) - 1),
add_cross_attention=config.down_block_types[i]
== "CrossAttnDownBlock2D",
)
for i, (in_channels, out_channels) in enumerate(
zip(block_channels, block_channels[1:])
@@ -357,7 +358,9 @@ class UNetModel(nn.Module):
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention=(i < len(config.block_out_channels) - 1),
# add_cross_attention=(i < len(config.block_out_channels) - 1),
add_cross_attention=list(reversed(config.up_block_types))[i]
== "CrossAttnUpBlock2D",
)
for i, (in_channels, out_channels, prev_out_channels) in reversed(
list(

View File

@@ -6,7 +6,7 @@ import mlx.core as mx
from PIL import Image
from tqdm import tqdm
from stable_diffusion import StableDiffusion
from stable_diffusion import StableDiffusion, StableDiffusionXL
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -22,7 +22,8 @@ if __name__ == "__main__":
parser.add_argument("--output", default="out.png")
args = parser.parse_args()
sd = StableDiffusion()
sd = StableDiffusionXL("stabilityai/stable-diffusion-xl-base-1.0")
# sd = StableDiffusion()
# Generate the latent vectors using diffusion
latents = sd.generate_latents(