mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove python from the launch script
This commit is contained in:
@@ -47,13 +47,9 @@ class CommandProcess:
|
||||
class RemoteProcess(CommandProcess):
|
||||
def __init__(self, rank, host, python, cwd, files, env, command):
|
||||
is_local = host == "127.0.0.1"
|
||||
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
|
||||
script_b64 = base64.b64encode(script.encode()).decode()
|
||||
cmd = (
|
||||
f'{python} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
||||
)
|
||||
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
|
||||
if not is_local:
|
||||
cmd = f"ssh {host} '{cmd}'"
|
||||
cmd = f"ssh {host} {shlex.quote(cmd)}"
|
||||
|
||||
self._host = host
|
||||
self._pidfile = None
|
||||
@@ -92,47 +88,33 @@ class RemoteProcess(CommandProcess):
|
||||
self._process.wait()
|
||||
|
||||
# Kill the remote program if possible
|
||||
cmd = ""
|
||||
cmd += f"pid=$(cat {self._pidfile}); "
|
||||
cmd += "if ps -p $pid >/dev/null; then "
|
||||
cmd += " kill $pid; "
|
||||
cmd += " echo 1; "
|
||||
cmd += "else "
|
||||
cmd += " echo 0; "
|
||||
cmd += "fi; "
|
||||
cmd += f"rm {self._pidfile}"
|
||||
cmd = RemoteProcess.make_kill_script(self._pidfile)
|
||||
if not self._is_local:
|
||||
cmd = f"ssh {self._host} '{cmd}'"
|
||||
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
|
||||
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
||||
|
||||
self._killed = c.stdout.strip() == "1"
|
||||
|
||||
@staticmethod
|
||||
def make_monitor_script(rank, cwd, files, env, command):
|
||||
# Imports that are used throughout
|
||||
def make_launch_script(rank, cwd, files, env, command):
|
||||
script = ""
|
||||
script += "import os\n"
|
||||
script += "import sys\n"
|
||||
script += "import tempfile\n"
|
||||
script += "from pathlib import Path\n"
|
||||
|
||||
# Write the PID to a file so we can kill the process if needed
|
||||
script += "_, pidfile = tempfile.mkstemp() \n"
|
||||
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
|
||||
script += "print(pidfile, flush=True)\n"
|
||||
script += "pidfile=$(mktemp); "
|
||||
script += "echo $$ > $pidfile; "
|
||||
script += "echo $pidfile; "
|
||||
|
||||
# Change the working directory if one was requested. Otherwise attempt to
|
||||
# change to the current one but don't fail if it wasn't possible.
|
||||
d = cwd or os.getcwd()
|
||||
script += f"if Path({repr(d)}).exists():\n"
|
||||
script += f" os.chdir({repr(d)})\n"
|
||||
script += f"if [[ -d {repr(d)} ]]; then "
|
||||
script += f" cd {repr(d)}; "
|
||||
if cwd is not None:
|
||||
script += "else:\n"
|
||||
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
|
||||
script += f" sys.exit(1)\n"
|
||||
script += "else "
|
||||
script += f" echo 'Failed to change directory to' {repr(d)} >2; "
|
||||
script += "fi; "
|
||||
|
||||
# Add the environment variables that were requested
|
||||
script += "env = dict(os.environ)\n"
|
||||
for e in env:
|
||||
key, *value = e.split("=", maxsplit=1)
|
||||
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
||||
@@ -141,22 +123,34 @@ class RemoteProcess(CommandProcess):
|
||||
f"'{e}' is an invalid environment variable so it is ignored"
|
||||
)
|
||||
continue
|
||||
script += f"env[{repr(key)}] = {repr(value)}\n"
|
||||
script += f"export {key}={value}; "
|
||||
|
||||
# Make the temporary files
|
||||
for env_name, content in files.items():
|
||||
script += "_, fname = tempfile.mkstemp()\n"
|
||||
script += "with open(fname, 'w') as f:\n"
|
||||
script += f" f.write({repr(content)})\n"
|
||||
script += f"env[{repr(env_name)}] = fname\n"
|
||||
script += "fname=$(mktemp); "
|
||||
script += f"echo {shlex.quote(content)} >$fname; "
|
||||
script += f"export {env_name}=$fname; "
|
||||
|
||||
# Finally add the rank
|
||||
script += f"env['MLX_RANK'] = '{rank}'\n"
|
||||
script += "\n"
|
||||
script += f"export MLX_RANK={rank}; "
|
||||
|
||||
# Replace the process with the script
|
||||
script += f"command = [{','.join(map(repr, command))}]\n"
|
||||
script += "os.execve(command[0], command, env)\n"
|
||||
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
|
||||
script += 'exec "${cmd[@]}"'
|
||||
|
||||
return script
|
||||
|
||||
@staticmethod
|
||||
def make_kill_script(pidfile):
|
||||
script = ""
|
||||
script += f"pid=$(cat {pidfile}); "
|
||||
script += "if ps -p $pid >/dev/null; then "
|
||||
script += " kill $pid; "
|
||||
script += " echo 1; "
|
||||
script += "else "
|
||||
script += " echo 0; "
|
||||
script += "fi; "
|
||||
script += f"rm {pidfile}"
|
||||
|
||||
return script
|
||||
|
||||
|
||||
Reference in New Issue
Block a user