mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Initial working distributed flux
This commit is contained in:
parent c243370044
commit 7fbd1619eb
@@ -5,6 +5,7 @@ from typing import Optional
 
 import mlx.core as mx
 import mlx.nn as nn
+from mlx.nn.layers.distributed import shard_linear
 
 from .layers import (
     DoubleStreamBlock,
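
Note: `shard_linear` is the only new dependency this commit pulls in. Judging from how it is used below, "all-to-sharded" splits a linear layer's output features across the ranks of the group (inputs stay replicated), while "sharded-to-all" splits the input features and sums the partial products back onto every rank, i.e. the usual column-parallel / row-parallel pairing. A minimal sketch of that pattern with plain mlx ops (the sizes and the names `up` / `down` are illustrative, not from this commit):

# Minimal sketch (assumptions, not code from this commit): the tensor-parallel
# pattern that shard_linear's two modes correspond to, written with plain mlx ops.
import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()
r, N = group.rank(), group.size()

up = nn.Linear(1024, 4096)                # "all-to-sharded": split output features
down = nn.Linear(4096, 1024, bias=False)  # "sharded-to-all": split input features

# nn.Linear stores its weight as (output_dims, input_dims): the output split is a
# row split of the weight, the input split is a column split.
up.weight = mx.split(up.weight, N, axis=0)[r]
up.bias = mx.split(up.bias, N, axis=0)[r]
down.weight = mx.split(down.weight, N, axis=1)[r]

x = mx.random.normal((1, 1024))           # replicated input on every rank
h = nn.gelu(up(x))                        # each rank holds 4096 // N features
y = mx.distributed.all_sum(down(h), group=group)  # partial matmuls summed back

Pairing the two modes this way keeps the communication per block down to the final all-reduce.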
@@ -96,6 +97,54 @@ class Flux(nn.Module):
             new_weights[k] = w
         return new_weights
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+        N = group.size()
+        if N == 1:
+            return
+
+        for block in self.double_blocks:
+            block.num_heads //= N
+            block.img_attn.num_heads //= N
+            block.txt_attn.num_heads //= N
+            block.img_attn.qkv = shard_linear(
+                block.img_attn.qkv, "all-to-sharded", groups=3, group=group
+            )
+            block.txt_attn.qkv = shard_linear(
+                block.txt_attn.qkv, "all-to-sharded", groups=3, group=group
+            )
+            block.img_attn.proj = shard_linear(
+                block.img_attn.proj, "sharded-to-all", group=group
+            )
+            block.txt_attn.proj = shard_linear(
+                block.txt_attn.proj, "sharded-to-all", group=group
+            )
+            block.img_mlp.layers[0] = shard_linear(
+                block.img_mlp.layers[0], "all-to-sharded", group=group
+            )
+            block.txt_mlp.layers[0] = shard_linear(
+                block.txt_mlp.layers[0], "all-to-sharded", group=group
+            )
+            block.img_mlp.layers[2] = shard_linear(
+                block.img_mlp.layers[2], "sharded-to-all", group=group
+            )
+            block.txt_mlp.layers[2] = shard_linear(
+                block.txt_mlp.layers[2], "sharded-to-all", group=group
+            )
+
+        for block in self.single_blocks:
+            block.num_heads //= N
+            block.hidden_size //= N
+            block.linear1 = shard_linear(
+                block.linear1,
+                "all-to-sharded",
+                groups=[1 / 7, 2 / 7, 3 / 7],
+                group=group,
+            )
+            block.linear2 = shard_linear(
+                block.linear2, "sharded-to-all", groups=[1 / 5], group=group
+            )
+
     def __call__(
         self,
         img: mx.array,
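
A note on the bookkeeping in `shard`: each attention keeps `num_heads // N` heads because its fused QKV projection is sharded with `groups=3`, which slices the Q, K and V thirds separately so every rank gets a matching fraction of each. The single blocks fuse attention and MLP, so `linear1` produces 3 x hidden_size of Q/K/V plus 4 x hidden_size of MLP up-projection (7 equal parts, hence the 1/7, 2/7, 3/7 boundaries), and `linear2` consumes hidden_size of attention output plus 4 x hidden_size of MLP activation (hence groups=[1/5]). A quick sanity check of those boundaries, assuming the usual Flux dimensions (hidden_size = 3072, MLP ratio 4), which are not stated in this diff:

# Hypothetical sanity check of the shard boundaries above. hidden_size = 3072
# and mlp_ratio = 4 are the usual Flux values but are assumptions here.
hidden_size, mlp_ratio, N = 3072, 4, 2

linear1_out = 3 * hidden_size + mlp_ratio * hidden_size         # fused Q, K, V + MLP up-projection
print([round(f * linear1_out) for f in (1 / 7, 2 / 7, 3 / 7)])  # [3072, 6144, 9216] -> Q | K | V | MLP
print(linear1_out // N)                                         # 10752 output features kept per rank

linear2_in = hidden_size + mlp_ratio * hidden_size              # attention output + MLP activation
print(round(1 / 5 * linear2_in))                                # 3072, the attention slice of the input
print(linear2_in // N)                                          # 7680 input features consumed per rank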
@@ -76,6 +76,10 @@ if __name__ == "__main__":
         nn.quantize(flux.t5, class_predicate=quantization_predicate)
         nn.quantize(flux.clip, class_predicate=quantization_predicate)
 
+    group = mx.distributed.init()
+    if group.size() > 1:
+        flux.flow.shard(group)
+
     if args.preload_models:
         flux.ensure_models_are_loaded()
 
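
Usage note (assumptions, not part of the diff): only the flow transformer is sharded; every rank still runs the same script with identical inputs, and the sharded linears all-reduce their activations so each rank produces the same image while splitting the matmul work. With mlx's MPI backend the script (presumably the flux text-to-image driver, its name is not shown in this diff) would be launched with something like `mpirun -np 2 -- python <script>.py ...`. A minimal skeleton of the per-rank setup:

# Hypothetical skeleton of the per-rank setup implied by the change above.
import mlx.core as mx

group = mx.distributed.init()   # falls back to a single-rank group if not launched distributed
mx.random.seed(0)               # identical seed on every rank so the latents match
print(f"rank {group.rank()} of {group.size()}")
# if group.size() > 1:
#     flux.flow.shard(group)    # as in the diff; `flux` is built by the driver script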