mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Fix the seed for data parallel
This commit is contained in:
parent
a1e259607e
commit
c109d9b596
@ -77,18 +77,24 @@ if __name__ == "__main__":
|
||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||
|
||||
# Figure out what kind of distributed generation we should do
|
||||
group = mx.distributed.init()
|
||||
n_images = args.n_images
|
||||
should_gather = False
|
||||
if group.size() > 1:
|
||||
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
|
||||
flux.flow.shard(group)
|
||||
if args.seed is None:
|
||||
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
|
||||
else:
|
||||
n_images //= group.size()
|
||||
should_gather = True
|
||||
|
||||
# If we are sharding we should have the same seed and if we are doing
|
||||
# data parallel generation we should have different seeds
|
||||
if args.seed is None:
|
||||
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
|
||||
if should_gather:
|
||||
args.seed = args.seed + group.rank()
|
||||
|
||||
if args.preload_models:
|
||||
flux.ensure_models_are_loaded()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user