mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 02:25:33 +08:00
repeat host -> proc per node
This commit is contained in:
parent
389276e2b8
commit
dadf8d9c93
@ -418,7 +418,7 @@ def launch_mpi(parser, hosts, args, command):
|
|||||||
def launch_nccl(parser, hosts, args, command):
|
def launch_nccl(parser, hosts, args, command):
|
||||||
master_host = hosts[0].ips[0]
|
master_host = hosts[0].ips[0]
|
||||||
master_port = args.nccl_port
|
master_port = args.nccl_port
|
||||||
world_size = args.repeat_hosts * len(hosts)
|
world_size = args.nproc_per_node * len(hosts)
|
||||||
|
|
||||||
base_env = os.environ.copy()
|
base_env = os.environ.copy()
|
||||||
base_env.update(
|
base_env.update(
|
||||||
@ -814,6 +814,12 @@ def main():
|
|||||||
default=12345,
|
default=12345,
|
||||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nproc-per-node",
|
||||||
|
type=positive_number,
|
||||||
|
default=1,
|
||||||
|
help="How many processes to run per node (only for nccl backend)",
|
||||||
|
)
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
if rest[0] == "--":
|
if rest[0] == "--":
|
||||||
|
Loading…
Reference in New Issue
Block a user