mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Merge 282304a87d
into 4b2a0df237
This commit is contained in:
commit
c79c611ea7
@ -130,7 +130,17 @@ def map_vae_weights(key, value):
|
|||||||
if "upsamplers" in key:
|
if "upsamplers" in key:
|
||||||
key = key.replace("upsamplers.0.conv", "upsample")
|
key = key.replace("upsamplers.0.conv", "upsample")
|
||||||
|
|
||||||
# Map attention layers
|
# Map attention layers in SD-2-1-base:VAE
|
||||||
|
if "key" in key:
|
||||||
|
key = key.replace("key", "key_proj")
|
||||||
|
if "proj_attn" in key:
|
||||||
|
key = key.replace("proj_attn", "out_proj")
|
||||||
|
if "query" in key:
|
||||||
|
key = key.replace("query", "query_proj")
|
||||||
|
if "value" in key:
|
||||||
|
key = key.replace("value", "value_proj")
|
||||||
|
|
||||||
|
# Map attention layers in SDXL Turbo
|
||||||
if "to_k" in key:
|
if "to_k" in key:
|
||||||
key = key.replace("to_k", "key_proj")
|
key = key.replace("to_k", "key_proj")
|
||||||
if "to_out.0" in key:
|
if "to_out.0" in key:
|
||||||
@ -140,6 +150,7 @@ def map_vae_weights(key, value):
|
|||||||
if "to_v" in key:
|
if "to_v" in key:
|
||||||
key = key.replace("to_v", "value_proj")
|
key = key.replace("to_v", "value_proj")
|
||||||
|
|
||||||
|
|
||||||
# Map the mid block
|
# Map the mid block
|
||||||
if "mid_block.resnets.0" in key:
|
if "mid_block.resnets.0" in key:
|
||||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||||
|
Loading…
Reference in New Issue
Block a user