diff --git a/flux/flux/model.py b/flux/flux/model.py index 18ea70b0..d8ad9d9b 100644 --- a/flux/flux/model.py +++ b/flux/flux/model.py @@ -85,6 +85,8 @@ class Flux(nn.Module): def sanitize(self, weights): new_weights = {} for k, w in weights.items(): + if k.startswith("model.diffusion_model."): + k = k[22:] if k.endswith(".scale"): k = k[:-6] + ".weight" for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: