Mirror of https://github.com/ml-explore/mlx-examples.git
Stable diffusion - check model weights shape and support int for "attention_head_dim" (#85)
* Allow integer as attention_head_dim
* Reshape downloaded weights to match model if there is a mismatch
This commit is contained in:
parent 86cae9ba57
commit fc1495abaa
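The second bullet (reshaping downloaded weights when their shape does not match the model) is implemented in parts of this commit that are not shown in the hunks below. As a rough illustration only, with a hypothetical helper name (reshape_to_match) and NumPy standing in for the actual loading code:

import numpy as np

def reshape_to_match(value: np.ndarray, model_shape: tuple) -> np.ndarray:
    # Hypothetical sketch, not the commit's actual code: if the downloaded
    # weight holds the same number of elements as the model parameter but in
    # a different shape, reshape it to match; otherwise pass it through.
    if value.shape != model_shape and value.size == int(np.prod(model_shape)):
        return value.reshape(model_shape)
    return value

# e.g. a checkpoint stores a bias as (1, 320) while the model expects (320,)
w = np.zeros((1, 320), dtype=np.float32)
assert reshape_to_match(w, (320,)).shape == (320,)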
@@ -76,6 +76,10 @@ def map_unet_weights(key, value):
     if "conv_shortcut.weight" in key:
         value = value.squeeze()
 
+    # Transform the weights from 1x1 convs to linear
+    if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
+        value = value.squeeze()
+
     if len(value.shape) == 4:
         value = value.transpose(0, 2, 3, 1)
 
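The new branch above covers checkpoints that store the transformer proj_in/proj_out weights as 1x1 convolutions, which this port loads into linear layers (hence the "1x1 convs to linear" comment). A minimal standalone sketch of the squeeze, with NumPy used only as a stand-in for the array type the mapper actually receives:

import numpy as np

# A 1x1 conv weight as stored in the checkpoint: (out_channels, in_channels, 1, 1)
value = np.random.randn(320, 320, 1, 1).astype(np.float32)

# Squeezing drops the trailing singleton spatial dims, leaving the
# (out_channels, in_channels) matrix that a linear layer expects.
value = value.squeeze()
assert value.shape == (320, 320)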
@@ -184,7 +188,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
             out_channels=config["out_channels"],
             block_out_channels=config["block_out_channels"],
             layers_per_block=[config["layers_per_block"]] * n_blocks,
-            num_attention_heads=config["attention_head_dim"],
+            num_attention_heads=[config["attention_head_dim"]] * n_blocks
+            if isinstance(config["attention_head_dim"], int)
+            else config["attention_head_dim"],
             cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
             norm_num_groups=config["norm_num_groups"],
         )
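UNet configs store "attention_head_dim" either as a single integer or as a per-block list; the inline conditional above normalizes both forms to one entry per block. A small standalone sketch of that expression, assuming n_blocks = 4 and example values:

n_blocks = 4

for attention_head_dim in (8, [5, 10, 20, 20]):
    num_attention_heads = (
        [attention_head_dim] * n_blocks
        if isinstance(attention_head_dim, int)
        else attention_head_dim
    )
    print(num_attention_heads)  # [8, 8, 8, 8], then [5, 10, 20, 20]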