fix mapping for sdxl turbo

This commit is contained in:
Pranav Veldurthi 2024-09-04 09:09:52 -04:00
parent 1fbc9361e4
commit 282304a87d

View File

@ -130,7 +130,17 @@ def map_vae_weights(key, value):
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map attention layers
# Map attention layers in SD-2-1-base:VAE
if "key" in key:
key = key.replace("key", "key_proj")
if "proj_attn" in key:
key = key.replace("proj_attn", "out_proj")
if "query" in key:
key = key.replace("query", "query_proj")
if "value" in key:
key = key.replace("value", "value_proj")
# Map attention layers in SDXL Turbo
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
@ -141,16 +151,6 @@ def map_vae_weights(key, value):
key = key.replace("to_v", "value_proj")
# Map attention layers in SD-2-1-base:VAE
if "key" in key:
key = key.replace("key", "to_k")
if "proj_attn" in key:
key = key.replace("proj_attn", "out_proj")
if "query" in key:
key = key.replace("query", "query_proj")
if "value" in key:
key = key.replace("value", "value_proj")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")