Allow loading from diffusers ckpt

This commit is contained in:
Angelos Katharopoulos 2024-11-22 20:52:50 -08:00
parent 042280ce50
commit 516d0e3af0

View File

@ -85,6 +85,8 @@ class Flux(nn.Module):
def sanitize(self, weights): def sanitize(self, weights):
new_weights = {} new_weights = {}
for k, w in weights.items(): for k, w in weights.items():
if k.startswith("model.diffusion_model."):
k = k[22:]
if k.endswith(".scale"): if k.endswith(".scale"):
k = k[:-6] + ".weight" k = k[:-6] + ".weight"
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: