CUDA_VISIBLE_DEVICES to local rank

This commit is contained in:
Anastasiia Filippova 2025-08-09 01:43:14 +02:00
parent dadf8d9c93
commit 984cefb14d

View File

@ -435,7 +435,7 @@ def launch_nccl(parser, hosts, args, command):
for rank in range(world_size):
env = base_env.copy()
env["MLX_RANK"] = str(rank)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
p = Popen(command, env=env)
procs.append(p)