Update for nn.layers.distributed

Angelos Katharopoulos 2025-03-22 16:35:32 -07:00
parent 02b007f19c
commit 208856520d
2 changed files with 29 additions and 14 deletions

View File

@@ -109,10 +109,10 @@ class Flux(nn.Module):
             block.txt_attn.num_heads //= N
             block.sharding_group = group
             block.img_attn.qkv = shard_linear(
-                block.img_attn.qkv, "all-to-sharded", groups=3, group=group
+                block.img_attn.qkv, "all-to-sharded", segments=3, group=group
             )
             block.txt_attn.qkv = shard_linear(
-                block.txt_attn.qkv, "all-to-sharded", groups=3, group=group
+                block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
             )
             shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
             shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
@@ -131,11 +131,11 @@ class Flux(nn.Module):
             block.linear1 = shard_linear(
                 block.linear1,
                 "all-to-sharded",
-                groups=[1 / 7, 2 / 7, 3 / 7],
+                segments=[1 / 7, 2 / 7, 3 / 7],
                 group=group,
             )
             block.linear2 = shard_linear(
-                block.linear2, "sharded-to-all", groups=[1 / 5], group=group
+                block.linear2, "sharded-to-all", segments=[1 / 5], group=group
             )

     def __call__(
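For context, segments tells the sharding helpers that a linear layer's fused output is made of several logical blocks that must each be split across ranks, rather than slicing the concatenated weight as one piece: segments=3 for the fused QKV projections, and fractional boundaries such as segments=[1 / 7, 2 / 7, 3 / 7] for linear1, whose output packs Q, K and V alongside a 4x-wide MLP, hence sevenths. Below is a minimal sketch of the idea for the integer case; shard_fused_linear_sketch is a hypothetical stand-in for illustration, not MLX's actual shard_linear:

    import mlx.core as mx
    import mlx.nn as nn

    def shard_fused_linear_sketch(linear, segments, rank, size):
        # Hypothetical illustration of "all-to-sharded" with integer segments.
        # Split the fused output dimension into its logical parts first
        # (e.g. Q, K, V), then keep this rank's slice of *each* part, so
        # every rank holds contiguous heads from every segment instead of
        # an arbitrary slice of the raw concatenation.
        w = linear.weight  # shape: (out_features, in_features)
        parts = mx.split(w, segments, axis=0)
        kept = []
        for p in parts:
            step = p.shape[0] // size
            kept.append(p[rank * step : (rank + 1) * step])
        out = nn.Linear(w.shape[1], sum(p.shape[0] for p in kept), bias=False)
        out.weight = mx.concatenate(kept, axis=0)
        return out

The opposite "sharded-to-all" direction used for the projection layers would, in the usual tensor-parallel layout, slice the input dimension instead and sum the partial outputs across the group.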

View File

@@ -62,6 +62,7 @@ if __name__ == "__main__":
     parser.add_argument("--adapter")
     parser.add_argument("--fuse-adapter", action="store_true")
     parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
+    parser.add_argument("--force-shard", action="store_true")
     args = parser.parse_args()

     # Load the models
@@ -77,8 +78,16 @@ if __name__ == "__main__":
         nn.quantize(flux.clip, class_predicate=quantization_predicate)

     group = mx.distributed.init()
+    n_images = args.n_images
+    should_gather = False
     if group.size() > 1:
-        flux.flow.shard(group)
+        if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
+            flux.flow.shard(group)
+            if args.seed is None:
+                args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
+        else:
+            n_images //= group.size()
+            should_gather = True

     if args.preload_models:
         flux.ensure_models_are_loaded()
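In other words, the script now picks between two distributed strategies: when the requested image count splits evenly across the ranks (and --force-shard is not given), each rank generates its own n_images // group.size() images and the results are gathered at the end; otherwise the flow transformer itself is sharded across the group, in which case every rank must denoise the same latents, so a common seed is agreed on by all-summing a random integer across ranks.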
@@ -87,7 +96,7 @@ if __name__ == "__main__":
     latent_size = to_latent_size(args.image_size)
     latents = flux.generate_latents(
         args.prompt,
-        n_images=args.n_images,
+        n_images=n_images,
         num_steps=args.steps,
         latent_size=latent_size,
         guidance=args.guidance,
@@ -97,8 +106,8 @@ if __name__ == "__main__":
     # First we get and eval the conditioning
     conditioning = next(latents)
     mx.eval(conditioning)
-    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()

     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the text encoders.
@@ -106,36 +115,42 @@ if __name__ == "__main__":
     del flux.clip

     # Actual denoising loop
-    for x_t in tqdm(latents, total=args.steps):
+    for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
         mx.eval(x_t)

     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the flow transformer.
     del flux.flow
-    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_generation = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()

     # Decode them into images
     decoded = []
     for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
         decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
         mx.eval(decoded[-1])
-    peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
+    peak_mem_decoding = mx.get_peak_memory() / 1024**3
     peak_mem_overall = max(
         peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
     )

+    # Gather them if each node has different images
+    decoded = mx.concatenate(decoded, axis=0)
+    if should_gather:
+        decoded = mx.distributed.all_gather(decoded)
+    mx.eval(decoded)
+
     if args.save_raw:
         *name, suffix = args.output.split(".")
         name = ".".join(name)
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = (x * 255).astype(mx.uint8)
         for i in range(len(x)):
             im = Image.fromarray(np.array(x[i]))
             im.save(".".join([name, str(i), suffix]))
     else:
         # Arrange them on a grid
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
         B, H, W, C = x.shape
         x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
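As a usage sketch, not part of this commit (the launcher invocation and flag spellings are assumptions based on MLX's mlx.launch helper and the script's argument names):

    # Shard the flow transformer across 4 hosts (model parallel):
    mlx.launch --hosts h1,h2,h3,h4 txt2image.py --force-shard --n-images 1 "a prompt"

    # Or split 8 images across 4 hosts, 2 per rank, gathered at the end:
    mlx.launch --hosts h1,h2,h3,h4 txt2image.py --n-images 8 "a prompt"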