Stable Diffusion: check model weight shapes and support int for "attention_head_dim" (#85)

* Allow an integer value for attention_head_dim
* Reshape downloaded weights to match the model when their shapes do not match
Pawel Kowalski 2023-12-15 22:01:02 +01:00 committed by GitHub
parent 86cae9ba57
commit fc1495abaa


@@ -76,6 +76,10 @@ def map_unet_weights(key, value):
     if "conv_shortcut.weight" in key:
         value = value.squeeze()
 
+    # Transform the weights from 1x1 convs to linear
+    if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
+        value = value.squeeze()
+
     if len(value.shape) == 4:
         value = value.transpose(0, 2, 3, 1)
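The proj_in/proj_out change only drops the trailing 1x1 spatial dimensions so those downloaded conv weights can be loaded into linear layers. A minimal standalone sketch with numpy (the real loader operates on the downloaded checkpoint arrays; the key name and helper below are hypothetical):

import numpy as np

def map_proj_weight(key, value):
    # proj_in / proj_out weights arrive as 1x1 convs with shape
    # (out_channels, in_channels, 1, 1); squeezing removes the two
    # trailing singleton dims and leaves a plain linear weight matrix.
    if value.ndim == 4 and ("proj_in" in key or "proj_out" in key):
        value = value.squeeze()
    return value

w = np.zeros((320, 320, 1, 1), dtype=np.float32)
print(map_proj_weight("down_blocks.0.attentions.0.proj_in.weight", w).shape)  # (320, 320)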
@@ -184,7 +188,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
             out_channels=config["out_channels"],
             block_out_channels=config["block_out_channels"],
             layers_per_block=[config["layers_per_block"]] * n_blocks,
-            num_attention_heads=config["attention_head_dim"],
+            num_attention_heads=[config["attention_head_dim"]] * n_blocks
+            if isinstance(config["attention_head_dim"], int)
+            else config["attention_head_dim"],
            cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
             norm_num_groups=config["norm_num_groups"],
         )
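The num_attention_heads change handles the fact that some downloaded configs store "attention_head_dim" as a single int while others store a per-block list, whereas the UNet config here takes one value per block (as the other * n_blocks parameters show). A minimal sketch of that normalization, with a hypothetical helper name:

def per_block_attention_heads(attention_head_dim, n_blocks):
    # Broadcast a scalar config value to one entry per UNet block;
    # a list from the config is passed through unchanged.
    if isinstance(attention_head_dim, int):
        return [attention_head_dim] * n_blocks
    return list(attention_head_dim)

print(per_block_attention_heads(8, 4))                # [8, 8, 8, 8]
print(per_block_attention_heads([5, 10, 20, 20], 4))  # [5, 10, 20, 20]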