Fix the seed for data parallel

Angelos Katharopoulos 2025-03-22 16:50:28 -07:00
parent a1e259607e
commit c109d9b596


@@ -77,18 +77,24 @@ if __name__ == "__main__":
         nn.quantize(flux.t5, class_predicate=quantization_predicate)
         nn.quantize(flux.clip, class_predicate=quantization_predicate)
 
     # Figure out what kind of distributed generation we should do
     group = mx.distributed.init()
     n_images = args.n_images
     should_gather = False
     if group.size() > 1:
         if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
             flux.flow.shard(group)
-            if args.seed is None:
-                args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
         else:
             n_images //= group.size()
             should_gather = True
 
+    # If we are sharding we should have the same seed and if we are doing
+    # data parallel generation we should have different seeds
+    if args.seed is None:
+        args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
+    if should_gather:
+        args.seed = args.seed + group.rank()
+
     if args.preload_models:
         flux.ensure_models_are_loaded()
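
To make the new behavior concrete, here is a minimal standalone sketch of the logic this diff introduces. It is not the script itself: n_images, force_shard, and seed stand in for args.n_images, args.force_shard, and args.seed, the flux.flow.shard(group) call is replaced by a comment, and the final mx.random.seed call only illustrates how the chosen seed would be consumed.

# Minimal sketch of the seed selection above. Runs single-process too,
# where mx.distributed.init() returns a group of size 1.
import mlx.core as mx

group = mx.distributed.init()
n_images = 8          # stand-in for args.n_images
force_shard = False   # stand-in for args.force_shard
seed = None           # stand-in for args.seed (no --seed given)

should_gather = False
if group.size() > 1:
    if force_shard or n_images < group.size() or n_images % group.size() != 0:
        # Model parallel: every rank works on the same images, so the
        # flow transformer is sharded across ranks (flux.flow.shard(group)).
        pass
    else:
        # Data parallel: split the batch, gather the images at the end.
        n_images //= group.size()
        should_gather = True

# All ranks must agree on one seed: each rank draws a different random
# number, and all_sum yields the same total on every rank.
if seed is None:
    seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
# For data parallel generation, offset by rank so each rank samples
# different noise and the gathered images are not duplicates.
if should_gather:
    seed = seed + group.rank()

mx.random.seed(seed)  # illustration only; the script uses args.seed later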
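
With two ranks and n_images = 8, for instance, each rank renders four images with seeds s and s + 1, while an uneven batch like n_images = 3 falls back to sharding with the identical seed s on both ranks. Before this commit the seed was only agreed upon inside the sharded branch, so data parallel ranks sharing one seed would sample identical noise and produce duplicate images; the per-rank offset keeps a run reproducible from a single seed while making each rank's images distinct.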