Enable generation with a trained adapter

This commit is contained in:
Angelos Katharopoulos
2024-10-10 01:34:50 -07:00
parent ecd8828e33
commit 807bd66b80
5 changed files with 36 additions and 11 deletions

View File

@@ -20,9 +20,11 @@ from .utils import (
class FluxPipeline:
def __init__(self, name: str):
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
@@ -44,7 +46,7 @@ class FluxPipeline:
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text)
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
@@ -235,3 +237,10 @@ class FluxPipeline:
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
    """Swap every ``LoRALinear`` in the flow model for its fused layer.

    Walks ``self.flow.named_modules()``, calls ``fuse()`` on each
    ``LoRALinear`` instance and writes the results back into the flow
    model under the same module names.
    """
    # Collect (name, fused_module) pairs first; the model is only
    # modified once, after the traversal has finished.
    fused = [
        (path, layer.fuse())
        for path, layer in self.flow.named_modules()
        if isinstance(layer, LoRALinear)
    ]
    self.flow.update_modules(tree_unflatten(fused))

View File

@@ -32,9 +32,9 @@ class LoRALinear(nn.Module):
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
lora_b = self.scale * self.lora_b.T
lora_a = self.lora_a.T
fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
if bias:
fused_linear.bias = linear.bias

View File

@@ -165,17 +165,17 @@ class T5Tokenizer:
tokens = [self.bos_token] + tokens
if append_eos and self.eos_token >= 0:
tokens.append(self.eos_token)
if len(tokens) < self.max_length and self.pad_token >= 0:
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
tokens += [self.pad_token] * (self.max_length - len(tokens))
return tokens
def encode(self, text):
def encode(self, text, pad=True):
if not isinstance(text, list):
return self.encode([text])
return self.encode([text], pad=pad)
pad_token = self.pad_token if self.pad_token >= 0 else 0
tokens = self.tokenize(text)
tokens = self.tokenize(text, pad=pad)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([pad_token] * (length - len(t)))

View File

@@ -202,6 +202,6 @@ def load_clip_tokenizer(name: str):
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
def load_t5_tokenizer(name: str):
def load_t5_tokenizer(name: str, pad: bool = True):
    """Download the T5 sentencepiece model for ``name`` and wrap it.

    The tokenizer's maximum length is 256 when ``name`` contains
    ``"schnell"`` and 512 otherwise.

    NOTE(review): ``pad`` is accepted but not used anywhere in this
    function; padding is controlled by the ``pad`` argument passed to
    ``T5Tokenizer.encode`` instead. Confirm whether this parameter should
    be forwarded or removed.
    """
    spiece_path = hf_hub_download(
        configs[name].repo_id, "tokenizer_2/spiece.model"
    )
    max_length = 256 if "schnell" in name else 512
    return T5Tokenizer(spiece_path, max_length)

View File

@@ -27,6 +27,16 @@ def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
def load_adapter(flux, adapter_file, fuse=False):
    """Load trained LoRA adapter weights into a Flux pipeline.

    Args:
        flux: Pipeline object exposing ``linear_to_lora_layers``,
            ``fuse_lora_layers`` and a ``flow`` model.
        adapter_file: Path handed to ``mx.load``; its metadata must hold
            the ``"lora_rank"`` and ``"lora_blocks"`` entries.
        fuse: When true, fuse the LoRA layers after loading the weights.
    """
    adapter_weights, metadata = mx.load(adapter_file, return_metadata=True)

    # The flow model must be converted to LoRA layers with the same
    # configuration used for training before the weights can be loaded.
    flux.linear_to_lora_layers(
        int(metadata["lora_rank"]),
        int(metadata["lora_blocks"]),
    )

    # strict=False: presumably the adapter file holds only the LoRA
    # parameters rather than the full flow model -- TODO confirm.
    flux.flow.load_weights(list(adapter_weights.items()), strict=False)

    if not fuse:
        return
    flux.fuse_lora_layers()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
@@ -47,12 +57,18 @@ if __name__ == "__main__":
parser.add_argument("--save-raw", action="store_true")
parser.add_argument("--seed", type=int)
parser.add_argument("--verbose", "-v", action="store_true")
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
args = parser.parse_args()
# Load the models
flux = FluxPipeline("flux-" + args.model)
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
args.steps = args.steps or (50 if args.model == "dev" else 2)
if args.adapter:
load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)