Mirror of https://github.com/ml-explore/mlx-examples.git
Finetune all layers
@@ -109,22 +109,22 @@ class FinetuningDataset:
             self.clip_features.append(clip_feat)
 
     def iterate(self, batch_size):
+        xs = mx.concatenate(self.latents)
+        t5 = mx.concatenate(self.t5_features)
+        clip = mx.concatenate(self.clip_features)
+        mx.eval(xs, t5, clip)
         while True:
-            indices = mx.random.randint(0, len(self.latents), (batch_size,)).tolist()
-            x = mx.concatenate([self.latents[i] for i in indices])
-            t5 = mx.concatenate([self.t5_features[i] for i in indices])
-            clip = mx.concatenate([self.clip_features[i] for i in indices])
-            mx.eval(x, t5, clip)
-            yield x, t5, clip
+            indices = mx.random.randint(0, len(self.latents), (batch_size,))
+            yield xs[indices], t5[indices], clip[indices]
 
 
 def linear_to_lora_layers(flux, args):
     lora_layers = []
     rank = args.lora_rank
     for name, mod in flux.flow.named_modules():
-        if ".img_attn" not in name and ".txt_attn" not in name:
-            continue
-        if ".qkv" in name or ".proj" in name:
+        if ("double_blocks" in name or "single_blocks" in name) and isinstance(
+            mod, nn.Linear
+        ):
             lora_layers.append((name, LoRALinear.from_base(mod, r=rank)))
     flux.flow.update_modules(tree_unflatten(lora_layers))
 
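For context, the change to iterate() replaces per-batch Python list indexing and re-concatenation with a single up-front concatenation followed by array indexing. A minimal sketch of that sampling pattern, using placeholder shapes and variable names rather than the trainer's actual data:

import mlx.core as mx

# Cache per-example arrays once (stand-ins for the precomputed latents).
latents = [mx.random.normal((1, 64, 64)) for _ in range(8)]

# Concatenate once up front and force evaluation.
xs = mx.concatenate(latents)  # shape (8, 64, 64)
mx.eval(xs)

# Each batch is drawn by indexing with a random integer array,
# instead of rebuilding a Python list and re-concatenating every step.
batch_size = 4
indices = mx.random.randint(0, xs.shape[0], (batch_size,))
batch = xs[indices]  # shape (4, 64, 64)
mx.eval(batch)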
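The linear_to_lora_layers() change is what the commit title refers to: instead of adapting only the qkv/proj projections of the image and text attention, LoRA layers are now attached to every nn.Linear inside the double_blocks and single_blocks. A small hypothetical helper (not part of the commit) that lists which modules the new filter would wrap, assuming flux.flow exposes named_modules() as used in the diff:

import mlx.nn as nn

def lora_target_names(flux):
    # Collect the module paths matched by the new selection rule:
    # any nn.Linear living under a double_blocks or single_blocks prefix.
    return [
        name
        for name, mod in flux.flow.named_modules()
        if ("double_blocks" in name or "single_blocks" in name)
        and isinstance(mod, nn.Linear)
    ]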