mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
allow command (#1836)
This commit is contained in:
parent
a62fc1b39f
commit
83a0340fa7
@ -6,6 +6,7 @@ import ipaddress
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
@ -125,7 +126,7 @@ def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
|
||||
script += "\n"
|
||||
|
||||
# Replace the process with the script
|
||||
script += shlex.join(["exec", sys.executable, *command])
|
||||
script += shlex.join(["exec", *command])
|
||||
script += "\n"
|
||||
|
||||
return script
|
||||
@ -210,7 +211,7 @@ def launch_ring(parser, hosts, args, command):
|
||||
ring_hosts.append(node)
|
||||
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
||||
|
||||
log(args.verbose, "Running", shlex.join([sys.executable, *command]))
|
||||
log(args.verbose, "Running", shlex.join(command))
|
||||
|
||||
threads = []
|
||||
for i, h in enumerate(hosts):
|
||||
@ -261,7 +262,6 @@ def launch_mpi(parser, hosts, args, command):
|
||||
*sum((["-x", e] for e in args.env), []),
|
||||
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
||||
"--",
|
||||
sys.executable,
|
||||
*command,
|
||||
]
|
||||
log(args.verbose, "Running", " ".join(cmd))
|
||||
@ -323,9 +323,12 @@ def main():
|
||||
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
||||
|
||||
# Check if the script is a file and convert it to a full path
|
||||
script = Path(rest[0])
|
||||
if script.exists():
|
||||
rest[0] = str(script.resolve())
|
||||
if (script := Path(rest[0])).exists():
|
||||
rest[0:1] = [sys.executable, str(script.resolve())]
|
||||
elif (command := shutil.which(rest[0])) is not None:
|
||||
rest[0] = command
|
||||
else:
|
||||
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||
|
||||
# Launch
|
||||
if args.backend == "ring":
|
||||
|
Loading…
Reference in New Issue
Block a user