Finetune all layers

Author: Angelos Katharopoulos
Date:   2024-10-08 16:31:33 -07:00
parent d9c5fd5ba4
commit 27aaff8f31

@@ -109,22 +109,22 @@ class FinetuningDataset:
         self.clip_features.append(clip_feat)
 
     def iterate(self, batch_size):
+        xs = mx.concatenate(self.latents)
+        t5 = mx.concatenate(self.t5_features)
+        clip = mx.concatenate(self.clip_features)
+        mx.eval(xs, t5, clip)
         while True:
-            indices = mx.random.randint(0, len(self.latents), (batch_size,)).tolist()
-            x = mx.concatenate([self.latents[i] for i in indices])
-            t5 = mx.concatenate([self.t5_features[i] for i in indices])
-            clip = mx.concatenate([self.clip_features[i] for i in indices])
-            mx.eval(x, t5, clip)
-            yield x, t5, clip
+            indices = mx.random.randint(0, len(self.latents), (batch_size,))
+            yield xs[indices], t5[indices], clip[indices]
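
The new iterate() concatenates the cached latents and text features once, evaluates the result, and then builds each batch by indexing random rows, instead of re-concatenating per-example Python lists on every step. Below is a minimal sketch of that sampling pattern; the array shapes and the 100-example cache are assumptions for illustration, not the real flux latents.

import mlx.core as mx

# Stand-in for the cached per-example features (assumed shapes).
latents = [mx.random.normal((1, 256, 64)) for _ in range(100)]

# Concatenate once up front and force evaluation of the big array ...
xs = mx.concatenate(latents)              # (100, 256, 64)
mx.eval(xs)

# ... so that every batch is just a gather of random rows.
def iterate(batch_size):
    while True:
        indices = mx.random.randint(0, xs.shape[0], (batch_size,))
        yield xs[indices]                 # (batch_size, 256, 64)

batch = next(iterate(4))
print(batch.shape)                        # (4, 256, 64)

This trades a one-time concatenation (and the memory for one large array per feature) for much cheaper per-step sampling: the random gather replaces the per-batch list comprehensions and the per-batch mx.eval call.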
 def linear_to_lora_layers(flux, args):
     lora_layers = []
     rank = args.lora_rank
     for name, mod in flux.flow.named_modules():
-        if ".img_attn" not in name and ".txt_attn" not in name:
-            continue
-        if ".qkv" in name or ".proj" in name:
+        if ("double_blocks" in name or "single_blocks" in name) and isinstance(
+            mod, nn.Linear
+        ):
             lora_layers.append((name, LoRALinear.from_base(mod, r=rank)))
 
     flux.flow.update_modules(tree_unflatten(lora_layers))
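
This selection change is what the commit title refers to: the old loop only wrapped the qkv and proj linears inside the img_attn/txt_attn modules, while the new condition wraps every nn.Linear that lives under double_blocks or single_blocks. The sketch below compares the two predicates on a few hypothetical module names (chosen to mirror the double_blocks/single_blocks naming; the real loop additionally checks isinstance(mod, nn.Linear)).

# Hypothetical module names, used only to illustrate the two selection rules.
example_names = [
    "double_blocks.0.img_attn.qkv",
    "double_blocks.0.img_attn.proj",
    "double_blocks.0.img_mlp.0",
    "single_blocks.3.linear1",
    "final_layer.linear",
]

def old_rule(name):
    # Old behaviour: only attention qkv/proj layers were adapted.
    return (".img_attn" in name or ".txt_attn" in name) and (
        ".qkv" in name or ".proj" in name
    )

def new_rule(name):
    # New behaviour: any layer inside the double/single blocks qualifies.
    return "double_blocks" in name or "single_blocks" in name

for name in example_names:
    print(f"{name:32} old={old_rule(name)} new={new_rule(name)}")

Under these assumed names, only the two attention layers pass the old rule, while everything except final_layer.linear passes the new one, so LoRA adapters now also land on the MLP and other block linears rather than the attention projections alone.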