Initial working distributed flux

This commit is contained in:
Angelos Katharopoulos 2025-03-03 23:17:30 -08:00
parent c243370044
commit 7fbd1619eb
2 changed files with 53 additions and 0 deletions

View File

@ -5,6 +5,7 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear
from .layers import (
DoubleStreamBlock,
@ -96,6 +97,54 @@ class Flux(nn.Module):
new_weights[k] = w
return new_weights
def shard(self, group: Optional[mx.distributed.Group] = None):
    """Shard the transformer blocks tensor-parallel across ``group``.

    Head-parallel sharding: fused QKV / up-projections are split
    column-wise ("all-to-sharded") and output / down-projections are
    split row-wise ("sharded-to-all"), so each attention/MLP pair needs
    only one all-reduce at its output.  Mutates the blocks in place.

    Args:
        group: Distributed group to shard over.  Defaults to the global
            group returned by ``mx.distributed.init()``.

    Raises:
        ValueError: If a block's head count (or a single block's hidden
            size) is not divisible by the group size — without this
            check the ``//=`` below would silently truncate and corrupt
            the model.
    """
    group = group or mx.distributed.init()
    N = group.size()
    if N == 1:
        # Single process: nothing to shard.
        return
    for block in self.double_blocks:
        if block.num_heads % N != 0:
            raise ValueError(
                f"num_heads ({block.num_heads}) must be divisible "
                f"by the distributed group size ({N})"
            )
        block.num_heads //= N
        block.img_attn.num_heads //= N
        block.txt_attn.num_heads //= N
        # groups=3 splits each of the fused Q, K, V partitions evenly
        # so every rank keeps whole heads from all three projections.
        block.img_attn.qkv = shard_linear(
            block.img_attn.qkv, "all-to-sharded", groups=3, group=group
        )
        block.txt_attn.qkv = shard_linear(
            block.txt_attn.qkv, "all-to-sharded", groups=3, group=group
        )
        block.img_attn.proj = shard_linear(
            block.img_attn.proj, "sharded-to-all", group=group
        )
        block.txt_attn.proj = shard_linear(
            block.txt_attn.proj, "sharded-to-all", group=group
        )
        block.img_mlp.layers[0] = shard_linear(
            block.img_mlp.layers[0], "all-to-sharded", group=group
        )
        block.txt_mlp.layers[0] = shard_linear(
            block.txt_mlp.layers[0], "all-to-sharded", group=group
        )
        block.img_mlp.layers[2] = shard_linear(
            block.img_mlp.layers[2], "sharded-to-all", group=group
        )
        block.txt_mlp.layers[2] = shard_linear(
            block.txt_mlp.layers[2], "sharded-to-all", group=group
        )
    for block in self.single_blocks:
        if block.num_heads % N != 0 or block.hidden_size % N != 0:
            raise ValueError(
                f"num_heads ({block.num_heads}) and hidden_size "
                f"({block.hidden_size}) must be divisible by the "
                f"distributed group size ({N})"
            )
        block.num_heads //= N
        block.hidden_size //= N
        # linear1 fuses Q, K, V with the MLP up-projection; the split
        # points 1/7, 2/7, 3/7 assume the output is 7*hidden_size wide
        # (i.e. mlp_ratio == 4).  NOTE(review): confirm mlp_ratio is
        # always 4 for these blocks.
        block.linear1 = shard_linear(
            block.linear1,
            "all-to-sharded",
            groups=[1 / 7, 2 / 7, 3 / 7],
            group=group,
        )
        # linear2 fuses the attention output (hidden) with the MLP
        # down-projection (4*hidden): input is 5*hidden, split at 1/5.
        block.linear2 = shard_linear(
            block.linear2, "sharded-to-all", groups=[1 / 5], group=group
        )
def __call__(
self,
img: mx.array,

View File

@ -76,6 +76,10 @@ if __name__ == "__main__":
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
group = mx.distributed.init()
if group.size() > 1:
flux.flow.shard(group)
if args.preload_models:
flux.ensure_models_are_loaded()