mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 15:47:24 +08:00
CUDA_VISIBLE_DEVICES to local rank
This commit is contained in:
parent
dadf8d9c93
commit
984cefb14d
@ -435,7 +435,7 @@ def launch_nccl(parser, hosts, args, command):
|
|||||||
for rank in range(world_size):
|
for rank in range(world_size):
|
||||||
env = base_env.copy()
|
env = base_env.copy()
|
||||||
env["MLX_RANK"] = str(rank)
|
env["MLX_RANK"] = str(rank)
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
||||||
p = Popen(command, env=env)
|
p = Popen(command, env=env)
|
||||||
procs.append(p)
|
procs.append(p)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user