diff --git a/flux/flux/model.py b/flux/flux/model.py
index bd74a393..c524edf3 100644
--- a/flux/flux/model.py
+++ b/flux/flux/model.py
@@ -109,10 +109,10 @@ class Flux(nn.Module):
             block.txt_attn.num_heads //= N
             block.sharding_group = group
             block.img_attn.qkv = shard_linear(
-                block.img_attn.qkv, "all-to-sharded", groups=3, group=group
+                block.img_attn.qkv, "all-to-sharded", segments=3, group=group
             )
             block.txt_attn.qkv = shard_linear(
-                block.txt_attn.qkv, "all-to-sharded", groups=3, group=group
+                block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
             )
             shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
             shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
@@ -131,11 +131,11 @@ class Flux(nn.Module):
             block.linear1 = shard_linear(
                 block.linear1,
                 "all-to-sharded",
-                groups=[1 / 7, 2 / 7, 3 / 7],
+                segments=[1 / 7, 2 / 7, 3 / 7],
                 group=group,
             )
             block.linear2 = shard_linear(
-                block.linear2, "sharded-to-all", groups=[1 / 5], group=group
+                block.linear2, "sharded-to-all", segments=[1 / 5], group=group
             )
 
     def __call__(
diff --git a/flux/txt2image.py b/flux/txt2image.py
index 5104c5c0..98cd8633 100644
--- a/flux/txt2image.py
+++ b/flux/txt2image.py
@@ -62,6 +62,7 @@ if __name__ == "__main__":
     parser.add_argument("--adapter")
     parser.add_argument("--fuse-adapter", action="store_true")
     parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
+    parser.add_argument("--force-shard", action="store_true")
     args = parser.parse_args()
 
     # Load the models
@@ -77,8 +78,16 @@ if __name__ == "__main__":
         nn.quantize(flux.clip, class_predicate=quantization_predicate)
 
     group = mx.distributed.init()
+    n_images = args.n_images
+    should_gather = False
     if group.size() > 1:
-        flux.flow.shard(group)
+        if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
+            flux.flow.shard(group)
+            if args.seed is None:
+                args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
+        else:
+            n_images //= group.size()
+            should_gather = True
 
     if args.preload_models:
         flux.ensure_models_are_loaded()
@@ -87,7 +96,7 @@ if __name__ == "__main__":
     latent_size = to_latent_size(args.image_size)
     latents = flux.generate_latents(
         args.prompt,
-        n_images=args.n_images,
+        n_images=n_images,
         num_steps=args.steps,
         latent_size=latent_size,
         guidance=args.guidance,
@@ -97,8 +106,8 @@ if __name__ == "__main__":
     # First we get and eval the conditioning
     conditioning = next(latents)
     mx.eval(conditioning)
-    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()
 
     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the text encoders.
@@ -106,36 +115,43 @@ if __name__ == "__main__":
     del flux.clip
 
     # Actual denoising loop
-    for x_t in tqdm(latents, total=args.steps):
+    for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
         mx.eval(x_t)
 
     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the flow transformer.
     del flux.flow
-    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_generation = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()
 
     # Decode them into images
     decoded = []
-    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
+    # x_t holds only this rank's n_images latents, which is fewer than
+    # args.n_images when the images are split across ranks
+    for i in tqdm(range(0, n_images, args.decoding_batch_size)):
         decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
         mx.eval(decoded[-1])
-    peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
+    peak_mem_decoding = mx.get_peak_memory() / 1024**3
     peak_mem_overall = max(
         peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
     )
 
+    # Gather them if each node has different images
+    decoded = mx.concatenate(decoded, axis=0)
+    if should_gather:
+        decoded = mx.distributed.all_gather(decoded)
+    mx.eval(decoded)
+
     if args.save_raw:
         *name, suffix = args.output.split(".")
         name = ".".join(name)
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = (x * 255).astype(mx.uint8)
         for i in range(len(x)):
             im = Image.fromarray(np.array(x[i]))
             im.save(".".join([name, str(i), suffix]))
     else:
         # Arrange them on a grid
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
         B, H, W, C = x.shape
         x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)