diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py
index 30d0410a..b64e3117 100644
--- a/flux/flux/__init__.py
+++ b/flux/flux/__init__.py
@@ -20,9 +20,11 @@ from .utils import (
 
 
 class FluxPipeline:
-    def __init__(self, name: str):
+    def __init__(self, name: str, t5_padding: bool = True):
         self.dtype = mx.bfloat16
         self.name = name
+        self.t5_padding = t5_padding
+
         self.ae = load_ae(name)
         self.flow = load_flow_model(name)
         self.clip = load_clip(name)
@@ -44,7 +46,7 @@ class FluxPipeline:
         self.clip = load_clip(self.name)
 
     def tokenize(self, text):
-        t5_tokens = self.t5_tokenizer.encode(text)
+        t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
         clip_tokens = self.clip_tokenizer.encode(text)
         return t5_tokens, clip_tokens
 
@@ -235,3 +237,10 @@ class FluxPipeline:
                 if isinstance(module, nn.Linear):
                     loras.append((name, LoRALinear.from_base(module, r=rank)))
             block.update_modules(tree_unflatten(loras))
+
+    def fuse_lora_layers(self):
+        fused_layers = []
+        for name, module in self.flow.named_modules():
+            if isinstance(module, LoRALinear):
+                fused_layers.append((name, module.fuse()))
+        self.flow.update_modules(tree_unflatten(fused_layers))
diff --git a/flux/flux/lora.py b/flux/flux/lora.py
index 2bf6fb69..0cdf1d4f 100644
--- a/flux/flux/lora.py
+++ b/flux/flux/lora.py
@@ -32,9 +32,9 @@ class LoRALinear(nn.Module):
         output_dims, input_dims = weight.shape
         fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
 
-        lora_b = (self.scale * self.lora_b.T).astype(dtype)
-        lora_a = self.lora_a.T.astype(dtype)
-        fused_linear.weight = weight + lora_b @ lora_a
+        lora_b = self.scale * self.lora_b.T
+        lora_a = self.lora_a.T
+        fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
         if bias:
             fused_linear.bias = linear.bias
 
diff --git a/flux/flux/tokenizers.py b/flux/flux/tokenizers.py
index 16f362f6..8716fe4b 100644
--- a/flux/flux/tokenizers.py
+++ b/flux/flux/tokenizers.py
@@ -165,17 +165,17 @@ class T5Tokenizer:
             tokens = [self.bos_token] + tokens
         if append_eos and self.eos_token >= 0:
             tokens.append(self.eos_token)
-        if len(tokens) < self.max_length and self.pad_token >= 0:
+        if pad and len(tokens) < self.max_length and self.pad_token >= 0:
             tokens += [self.pad_token] * (self.max_length - len(tokens))
 
         return tokens
 
-    def encode(self, text):
+    def encode(self, text, pad=True):
         if not isinstance(text, list):
-            return self.encode([text])
+            return self.encode([text], pad=pad)
 
         pad_token = self.pad_token if self.pad_token >= 0 else 0
-        tokens = self.tokenize(text)
+        tokens = self.tokenize(text, pad=pad)
         length = max(len(t) for t in tokens)
         for t in tokens:
             t.extend([pad_token] * (length - len(t)))
diff --git a/flux/flux/utils.py b/flux/flux/utils.py
index e506cda1..5c935b98 100644
--- a/flux/flux/utils.py
+++ b/flux/flux/utils.py
@@ -202,6 +202,6 @@ def load_clip_tokenizer(name: str):
     return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
 
 
-def load_t5_tokenizer(name: str):
+def load_t5_tokenizer(name: str, pad: bool = True):
    model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
    return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
diff --git a/flux/txt2image.py b/flux/txt2image.py
index 5808a059..cd9774a4 100644
--- a/flux/txt2image.py
+++ b/flux/txt2image.py
@@ -27,6 +27,16 @@ def quantization_predicate(name, m):
     return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
 
+def load_adapter(flux, adapter_file, fuse=False):
+    weights, lora_config = mx.load(adapter_file, return_metadata=True)
+    rank = int(lora_config["lora_rank"])
+    num_blocks = int(lora_config["lora_blocks"])
+    flux.linear_to_lora_layers(rank, num_blocks)
+    flux.flow.load_weights(list(weights.items()), strict=False)
+    if fuse:
+        flux.fuse_lora_layers()
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Generate images from a textual prompt using stable diffusion"
     )
@@ -47,12 +57,18 @@
     parser.add_argument("--save-raw", action="store_true")
     parser.add_argument("--seed", type=int)
     parser.add_argument("--verbose", "-v", action="store_true")
+    parser.add_argument("--adapter")
+    parser.add_argument("--fuse-adapter", action="store_true")
+    parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
     args = parser.parse_args()
 
     # Load the models
-    flux = FluxPipeline("flux-" + args.model)
+    flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
     args.steps = args.steps or (50 if args.model == "dev" else 2)
 
+    if args.adapter:
+        load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
+
    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)