allow command (#1836)

This commit is contained in:
Awni Hannun 2025-02-06 10:32:24 -08:00 committed by GitHub
parent a62fc1b39f
commit 83a0340fa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@ import ipaddress
import json
import os
import shlex
import shutil
import sys
import tempfile
import threading
@ -125,7 +126,7 @@ def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
script += "\n"
# Replace the process with the script
script += shlex.join(["exec", sys.executable, *command])
script += shlex.join(["exec", *command])
script += "\n"
return script
@ -210,7 +211,7 @@ def launch_ring(parser, hosts, args, command):
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
log(args.verbose, "Running", shlex.join([sys.executable, *command]))
log(args.verbose, "Running", shlex.join(command))
threads = []
for i, h in enumerate(hosts):
@ -261,7 +262,6 @@ def launch_mpi(parser, hosts, args, command):
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
sys.executable,
*command,
]
log(args.verbose, "Running", " ".join(cmd))
@ -323,9 +323,12 @@ def main():
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
script = Path(rest[0])
if script.exists():
rest[0] = str(script.resolve())
if (script := Path(rest[0])).exists():
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":