mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 18:51:14 +08:00
allow command (#1836)
This commit is contained in:
parent
a62fc1b39f
commit
83a0340fa7
@ -6,6 +6,7 @@ import ipaddress
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
@ -125,7 +126,7 @@ def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
|
|||||||
script += "\n"
|
script += "\n"
|
||||||
|
|
||||||
# Replace the process with the script
|
# Replace the process with the script
|
||||||
script += shlex.join(["exec", sys.executable, *command])
|
script += shlex.join(["exec", *command])
|
||||||
script += "\n"
|
script += "\n"
|
||||||
|
|
||||||
return script
|
return script
|
||||||
@ -210,7 +211,7 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
ring_hosts.append(node)
|
ring_hosts.append(node)
|
||||||
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
||||||
|
|
||||||
log(args.verbose, "Running", shlex.join([sys.executable, *command]))
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
threads = []
|
threads = []
|
||||||
for i, h in enumerate(hosts):
|
for i, h in enumerate(hosts):
|
||||||
@ -261,7 +262,6 @@ def launch_mpi(parser, hosts, args, command):
|
|||||||
*sum((["-x", e] for e in args.env), []),
|
*sum((["-x", e] for e in args.env), []),
|
||||||
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
||||||
"--",
|
"--",
|
||||||
sys.executable,
|
|
||||||
*command,
|
*command,
|
||||||
]
|
]
|
||||||
log(args.verbose, "Running", " ".join(cmd))
|
log(args.verbose, "Running", " ".join(cmd))
|
||||||
@ -323,9 +323,12 @@ def main():
|
|||||||
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
||||||
|
|
||||||
# Check if the script is a file and convert it to a full path
|
# Check if the script is a file and convert it to a full path
|
||||||
script = Path(rest[0])
|
if (script := Path(rest[0])).exists():
|
||||||
if script.exists():
|
rest[0:1] = [sys.executable, str(script.resolve())]
|
||||||
rest[0] = str(script.resolve())
|
elif (command := shutil.which(rest[0])) is not None:
|
||||||
|
rest[0] = command
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||||
|
|
||||||
# Launch
|
# Launch
|
||||||
if args.backend == "ring":
|
if args.backend == "ring":
|
||||||
|
Loading…
Reference in New Issue
Block a user