diff --git a/flux/flux/model.py b/flux/flux/model.py index d8ad9d9b..df0903d3 100644 --- a/flux/flux/model.py +++ b/flux/flux/model.py @@ -5,6 +5,7 @@ from typing import Optional import mlx.core as mx import mlx.nn as nn +from mlx.nn.layers.distributed import shard_linear from .layers import ( DoubleStreamBlock, @@ -96,6 +97,54 @@ class Flux(nn.Module): new_weights[k] = w return new_weights + def shard(self, group: Optional[mx.distributed.Group] = None): + group = group or mx.distributed.init() + N = group.size() + if N == 1: + return + + for block in self.double_blocks: + block.num_heads //= N + block.img_attn.num_heads //= N + block.txt_attn.num_heads //= N + block.img_attn.qkv = shard_linear( + block.img_attn.qkv, "all-to-sharded", groups=3, group=group + ) + block.txt_attn.qkv = shard_linear( + block.txt_attn.qkv, "all-to-sharded", groups=3, group=group + ) + block.img_attn.proj = shard_linear( + block.img_attn.proj, "sharded-to-all", group=group + ) + block.txt_attn.proj = shard_linear( + block.txt_attn.proj, "sharded-to-all", group=group + ) + block.img_mlp.layers[0] = shard_linear( + block.img_mlp.layers[0], "all-to-sharded", group=group + ) + block.txt_mlp.layers[0] = shard_linear( + block.txt_mlp.layers[0], "all-to-sharded", group=group + ) + block.img_mlp.layers[2] = shard_linear( + block.img_mlp.layers[2], "sharded-to-all", group=group + ) + block.txt_mlp.layers[2] = shard_linear( + block.txt_mlp.layers[2], "sharded-to-all", group=group + ) + + for block in self.single_blocks: + block.num_heads //= N + block.hidden_size //= N + block.linear1 = shard_linear( + block.linear1, + "all-to-sharded", + groups=[1 / 7, 2 / 7, 3 / 7], + group=group, + ) + block.linear2 = shard_linear( + block.linear2, "sharded-to-all", groups=[1 / 5], group=group + ) + def __call__( self, img: mx.array, diff --git a/flux/txt2image.py b/flux/txt2image.py index 5ebec81a..5104c5c0 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -76,6 +76,10 @@ if __name__ == "__main__": nn.quantize(flux.t5, class_predicate=quantization_predicate) nn.quantize(flux.clip, class_predicate=quantization_predicate) + group = mx.distributed.init() + if group.size() > 1: + flux.flow.shard(group) + if args.preload_models: flux.ensure_models_are_loaded()