mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 06:53:18 +08:00
nccl default for backend=any (#2528)
* nccl default for backend=any * check num gpus + ensure row contiguous for all reduce * comment
This commit is contained in:
@@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string):
|
||||
return ports
|
||||
|
||||
|
||||
def get_num_nvidia_gpus():
|
||||
result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
|
||||
return len(result.stdout.strip().split("\n"))
|
||||
|
||||
|
||||
def extract_rings(hosts, index):
|
||||
def usable_port(i, j, used_ports):
|
||||
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
|
||||
@@ -421,14 +426,16 @@ def launch_nccl(parser, hosts, args, command):
|
||||
master_host = hosts[0].ips[0]
|
||||
|
||||
if master_host != "127.0.0.1":
|
||||
raise ValueError("The NCCL backend only supports localhost for now. ")
|
||||
raise ValueError("The NCCL backend only supports localhost for now.")
|
||||
master_port = args.nccl_port
|
||||
world_size = len(hosts)
|
||||
|
||||
base_env = os.environ.copy()
|
||||
base_env.update(
|
||||
{
|
||||
"NCCL_DEBUG": "INFO",
|
||||
"NCCL_DEBUG": base_env.get(
|
||||
"NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
|
||||
),
|
||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
||||
"NCCL_HOST_IP": master_host,
|
||||
"NCCL_PORT": str(master_port),
|
||||
@@ -436,11 +443,18 @@ def launch_nccl(parser, hosts, args, command):
|
||||
}
|
||||
)
|
||||
procs = []
|
||||
num_gpus = get_num_nvidia_gpus()
|
||||
if num_gpus == 0:
|
||||
raise RuntimeError("Cannot run NCCL backend with no GPUs.")
|
||||
if args.repeat_hosts > num_gpus:
|
||||
raise RuntimeError("NCCL requires a separate GPU per process.")
|
||||
|
||||
try:
|
||||
for rank in range(world_size):
|
||||
env = base_env.copy()
|
||||
env["MLX_RANK"] = str(rank % args.repeat_hosts)
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
|
||||
mlx_rank = str(rank % args.repeat_hosts)
|
||||
env["MLX_RANK"] = mlx_rank
|
||||
env["CUDA_VISIBLE_DEVICES"] = mlx_rank
|
||||
p = Popen(command, env=env)
|
||||
procs.append(p)
|
||||
|
||||
@@ -821,8 +835,6 @@ def main():
|
||||
)
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
if rest[0] == "--":
|
||||
rest.pop(0)
|
||||
|
||||
if args.print_python:
|
||||
print(sys.executable)
|
||||
|
||||
Reference in New Issue
Block a user