Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Initial working distributed flux
This commit is contained in: parent c243370044, commit 7fbd1619eb
@@ -5,6 +5,7 @@ from typing import Optional
 
 import mlx.core as mx
 import mlx.nn as nn
+from mlx.nn.layers.distributed import shard_linear
 
 from .layers import (
     DoubleStreamBlock,
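The only change in this hunk is the new import: shard_linear wraps an existing nn.Linear so that each rank in a distributed group keeps just its slice of the weight matrix. A minimal sketch of the call in isolation, assuming mx.distributed.init() has a working backend and using an illustrative 4096-wide layer (the size is not taken from the model):

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

group = mx.distributed.init()

# "all-to-sharded": the input stays replicated on every rank, but each
# rank keeps only 4096 // group.size() of the output features.
lin = nn.Linear(4096, 4096)
lin = shard_linear(lin, "all-to-sharded", group=group)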
@@ -96,6 +97,54 @@ class Flux(nn.Module):
             new_weights[k] = w
         return new_weights
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+        N = group.size()
+        if N == 1:
+            return
+
+        for block in self.double_blocks:
+            block.num_heads //= N
+            block.img_attn.num_heads //= N
+            block.txt_attn.num_heads //= N
+            block.img_attn.qkv = shard_linear(
+                block.img_attn.qkv, "all-to-sharded", groups=3, group=group
+            )
+            block.txt_attn.qkv = shard_linear(
+                block.txt_attn.qkv, "all-to-sharded", groups=3, group=group
+            )
+            block.img_attn.proj = shard_linear(
+                block.img_attn.proj, "sharded-to-all", group=group
+            )
+            block.txt_attn.proj = shard_linear(
+                block.txt_attn.proj, "sharded-to-all", group=group
+            )
+            block.img_mlp.layers[0] = shard_linear(
+                block.img_mlp.layers[0], "all-to-sharded", group=group
+            )
+            block.txt_mlp.layers[0] = shard_linear(
+                block.txt_mlp.layers[0], "all-to-sharded", group=group
+            )
+            block.img_mlp.layers[2] = shard_linear(
+                block.img_mlp.layers[2], "sharded-to-all", group=group
+            )
+            block.txt_mlp.layers[2] = shard_linear(
+                block.txt_mlp.layers[2], "sharded-to-all", group=group
+            )
+
+        for block in self.single_blocks:
+            block.num_heads //= N
+            block.hidden_size //= N
+            block.linear1 = shard_linear(
+                block.linear1,
+                "all-to-sharded",
+                groups=[1 / 7, 2 / 7, 3 / 7],
+                group=group,
+            )
+            block.linear2 = shard_linear(
+                block.linear2, "sharded-to-all", groups=[1 / 5], group=group
+            )
+
     def __call__(
         self,
         img: mx.array,
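The method applies the usual column-/row-parallel pairing: every layer marked "all-to-sharded" (the fused qkv projections and the MLP up-projections) leaves its output sharded across ranks, and the matching "sharded-to-all" layer (the attention output projections and the MLP down-projections) consumes that shard, so only the second layer of each pair needs a communication step. The groups argument handles fused weights: groups=3 slices the qkv weight as three concatenated Q/K/V segments so whole heads stay on one rank (hence num_heads //= N), and the fractional boundaries [1 / 7, 2 / 7, 3 / 7] and [1 / 5] appear to match the single blocks' fused q/k/v plus 4x-MLP layout. A rough sketch of the pairing with explicit collectives; the class names are illustrative and shard_linear does this slicing internally:

import mlx.core as mx
import mlx.nn as nn

class ColumnParallel(nn.Module):
    # "all-to-sharded": each rank owns out_dims // N of the output features.
    def __init__(self, in_dims, out_dims, group):
        super().__init__()
        self.proj = nn.Linear(in_dims, out_dims // group.size())

    def __call__(self, x):
        # No communication; the result is left sharded for the next layer.
        return self.proj(x)

class RowParallel(nn.Module):
    # "sharded-to-all": each rank multiplies its input shard, then the
    # partial products are summed across ranks. Bias is omitted here so
    # it is not accidentally added group.size() times.
    def __init__(self, in_dims, out_dims, group):
        super().__init__()
        self.proj = nn.Linear(in_dims // group.size(), out_dims, bias=False)
        self.group = group

    def __call__(self, x):
        return mx.distributed.all_sum(self.proj(x), group=self.group)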
@@ -76,6 +76,10 @@ if __name__ == "__main__":
         nn.quantize(flux.t5, class_predicate=quantization_predicate)
         nn.quantize(flux.clip, class_predicate=quantization_predicate)
 
+    group = mx.distributed.init()
+    if group.size() > 1:
+        flux.flow.shard(group)
+
     if args.preload_models:
         flux.ensure_models_are_loaded()
 