Remove python from the launch script

This commit is contained in:
Angelos Katharopoulos
2025-12-09 13:04:37 -08:00
parent 405d30b6e5
commit 9d707ba3b5

View File

@@ -47,13 +47,9 @@ class CommandProcess:
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, python, cwd, files, env, command):
is_local = host == "127.0.0.1"
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = (
f'{python} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
)
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
if not is_local:
cmd = f"ssh {host} '{cmd}'"
cmd = f"ssh {host} {shlex.quote(cmd)}"
self._host = host
self._pidfile = None
@@ -92,47 +88,33 @@ class RemoteProcess(CommandProcess):
self._process.wait()
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {self._pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {self._pidfile}"
cmd = RemoteProcess.make_kill_script(self._pidfile)
if not self._is_local:
cmd = f"ssh {self._host} '{cmd}'"
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_monitor_script(rank, cwd, files, env, command):
# Imports that are used throughout
def make_launch_script(rank, cwd, files, env, command):
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
script += "pidfile=$(mktemp); "
script += "echo $$ > $pidfile; "
script += "echo $pidfile; "
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
script += f"if [[ -d {repr(d)} ]]; then "
script += f" cd {repr(d)}; "
if cwd is not None:
script += "else:\n"
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
script += f" sys.exit(1)\n"
script += "else "
script += f" echo 'Failed to change directory to' {repr(d)} >2; "
script += "fi; "
# Add the environment variables that were requested
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
@@ -141,22 +123,34 @@ class RemoteProcess(CommandProcess):
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
script += f"export {key}={value}; "
# Make the temporary files
for env_name, content in files.items():
script += "_, fname = tempfile.mkstemp()\n"
script += "with open(fname, 'w') as f:\n"
script += f" f.write({repr(content)})\n"
script += f"env[{repr(env_name)}] = fname\n"
script += "fname=$(mktemp); "
script += f"echo {shlex.quote(content)} >$fname; "
script += f"export {env_name}=$fname; "
# Finally add the rank
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
script += f"export MLX_RANK={rank}; "
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
script += 'exec "${cmd[@]}"'
return script
@staticmethod
def make_kill_script(pidfile):
script = ""
script += f"pid=$(cat {pidfile}); "
script += "if ps -p $pid >/dev/null; then "
script += " kill $pid; "
script += " echo 1; "
script += "else "
script += " echo 0; "
script += "fi; "
script += f"rm {pidfile}"
return script