allow command (#1836)

This commit is contained in:
Awni Hannun 2025-02-06 10:32:24 -08:00 committed by GitHub
parent a62fc1b39f
commit 83a0340fa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@ import ipaddress
import json import json
import os import os
import shlex import shlex
import shutil
import sys import sys
import tempfile import tempfile
import threading import threading
@ -125,7 +126,7 @@ def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
script += "\n" script += "\n"
# Replace the process with the script # Replace the process with the script
script += shlex.join(["exec", sys.executable, *command]) script += shlex.join(["exec", *command])
script += "\n" script += "\n"
return script return script
@ -210,7 +211,7 @@ def launch_ring(parser, hosts, args, command):
ring_hosts.append(node) ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else "" hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
log(args.verbose, "Running", shlex.join([sys.executable, *command])) log(args.verbose, "Running", shlex.join(command))
threads = [] threads = []
for i, h in enumerate(hosts): for i, h in enumerate(hosts):
@ -261,7 +262,6 @@ def launch_mpi(parser, hosts, args, command):
*sum((["-x", e] for e in args.env), []), *sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []), *sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--", "--",
sys.executable,
*command, *command,
] ]
log(args.verbose, "Running", " ".join(cmd)) log(args.verbose, "Running", " ".join(cmd))
@ -323,9 +323,12 @@ def main():
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts) hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path # Check if the script is a file and convert it to a full path
script = Path(rest[0]) if (script := Path(rest[0])).exists():
if script.exists(): rest[0:1] = [sys.executable, str(script.resolve())]
rest[0] = str(script.resolve()) elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch # Launch
if args.backend == "ring": if args.backend == "ring":