mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fix the seed for data parallel
This commit is contained in:
parent
a1e259607e
commit
c109d9b596
@ -77,18 +77,24 @@ if __name__ == "__main__":
|
|||||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
||||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||||
|
|
||||||
|
# Figure out what kind of distributed generation we should do
|
||||||
group = mx.distributed.init()
|
group = mx.distributed.init()
|
||||||
n_images = args.n_images
|
n_images = args.n_images
|
||||||
should_gather = False
|
should_gather = False
|
||||||
if group.size() > 1:
|
if group.size() > 1:
|
||||||
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
|
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
|
||||||
flux.flow.shard(group)
|
flux.flow.shard(group)
|
||||||
if args.seed is None:
|
|
||||||
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
|
|
||||||
else:
|
else:
|
||||||
n_images //= group.size()
|
n_images //= group.size()
|
||||||
should_gather = True
|
should_gather = True
|
||||||
|
|
||||||
|
# If we are sharding we should have the same seed and if we are doing
|
||||||
|
# data parallel generation we should have different seeds
|
||||||
|
if args.seed is None:
|
||||||
|
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
|
||||||
|
if should_gather:
|
||||||
|
args.seed = args.seed + group.rank()
|
||||||
|
|
||||||
if args.preload_models:
|
if args.preload_models:
|
||||||
flux.ensure_models_are_loaded()
|
flux.ensure_models_are_loaded()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user