mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
CUDA_VISIBLE_DEVICES to local rank
This commit is contained in:
@@ -435,7 +435,7 @@ def launch_nccl(parser, hosts, args, command):
|
|||||||
for rank in range(world_size):
|
for rank in range(world_size):
|
||||||
env = base_env.copy()
|
env = base_env.copy()
|
||||||
env["MLX_RANK"] = str(rank)
|
env["MLX_RANK"] = str(rank)
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
||||||
p = Popen(command, env=env)
|
p = Popen(command, env=env)
|
||||||
procs.append(p)
|
procs.append(p)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user