From c109d9b596c08b01e195f32aa6f76e9c1afba40c Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 22 Mar 2025 16:50:28 -0700 Subject: [PATCH] Fix the seed for data parallel --- flux/txt2image.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/flux/txt2image.py b/flux/txt2image.py index 2d3857e2..fd59b711 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -77,18 +77,24 @@ if __name__ == "__main__": nn.quantize(flux.t5, class_predicate=quantization_predicate) nn.quantize(flux.clip, class_predicate=quantization_predicate) + # Figure out what kind of distributed generation we should do group = mx.distributed.init() n_images = args.n_images should_gather = False if group.size() > 1: if args.force_shard or n_images < group.size() or n_images % group.size() != 0: flux.flow.shard(group) - if args.seed is None: - args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item() else: n_images //= group.size() should_gather = True + # If we are sharding we should have the same seed and if we are doing + # data parallel generation we should have different seeds + if args.seed is None: + args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item() + if should_gather: + args.seed = args.seed + group.rank() + if args.preload_models: flux.ensure_models_are_loaded()