mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove python from the launch script
This commit is contained in:
@@ -47,13 +47,9 @@ class CommandProcess:
|
|||||||
class RemoteProcess(CommandProcess):
|
class RemoteProcess(CommandProcess):
|
||||||
def __init__(self, rank, host, python, cwd, files, env, command):
|
def __init__(self, rank, host, python, cwd, files, env, command):
|
||||||
is_local = host == "127.0.0.1"
|
is_local = host == "127.0.0.1"
|
||||||
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
|
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
|
||||||
script_b64 = base64.b64encode(script.encode()).decode()
|
|
||||||
cmd = (
|
|
||||||
f'{python} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
|
||||||
)
|
|
||||||
if not is_local:
|
if not is_local:
|
||||||
cmd = f"ssh {host} '{cmd}'"
|
cmd = f"ssh {host} {shlex.quote(cmd)}"
|
||||||
|
|
||||||
self._host = host
|
self._host = host
|
||||||
self._pidfile = None
|
self._pidfile = None
|
||||||
@@ -92,47 +88,33 @@ class RemoteProcess(CommandProcess):
|
|||||||
self._process.wait()
|
self._process.wait()
|
||||||
|
|
||||||
# Kill the remote program if possible
|
# Kill the remote program if possible
|
||||||
cmd = ""
|
cmd = RemoteProcess.make_kill_script(self._pidfile)
|
||||||
cmd += f"pid=$(cat {self._pidfile}); "
|
|
||||||
cmd += "if ps -p $pid >/dev/null; then "
|
|
||||||
cmd += " kill $pid; "
|
|
||||||
cmd += " echo 1; "
|
|
||||||
cmd += "else "
|
|
||||||
cmd += " echo 0; "
|
|
||||||
cmd += "fi; "
|
|
||||||
cmd += f"rm {self._pidfile}"
|
|
||||||
if not self._is_local:
|
if not self._is_local:
|
||||||
cmd = f"ssh {self._host} '{cmd}'"
|
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
|
||||||
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
||||||
|
|
||||||
self._killed = c.stdout.strip() == "1"
|
self._killed = c.stdout.strip() == "1"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_monitor_script(rank, cwd, files, env, command):
|
def make_launch_script(rank, cwd, files, env, command):
|
||||||
# Imports that are used throughout
|
|
||||||
script = ""
|
script = ""
|
||||||
script += "import os\n"
|
|
||||||
script += "import sys\n"
|
|
||||||
script += "import tempfile\n"
|
|
||||||
script += "from pathlib import Path\n"
|
|
||||||
|
|
||||||
# Write the PID to a file so we can kill the process if needed
|
# Write the PID to a file so we can kill the process if needed
|
||||||
script += "_, pidfile = tempfile.mkstemp() \n"
|
script += "pidfile=$(mktemp); "
|
||||||
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
|
script += "echo $$ > $pidfile; "
|
||||||
script += "print(pidfile, flush=True)\n"
|
script += "echo $pidfile; "
|
||||||
|
|
||||||
# Change the working directory if one was requested. Otherwise attempt to
|
# Change the working directory if one was requested. Otherwise attempt to
|
||||||
# change to the current one but don't fail if it wasn't possible.
|
# change to the current one but don't fail if it wasn't possible.
|
||||||
d = cwd or os.getcwd()
|
d = cwd or os.getcwd()
|
||||||
script += f"if Path({repr(d)}).exists():\n"
|
script += f"if [[ -d {repr(d)} ]]; then "
|
||||||
script += f" os.chdir({repr(d)})\n"
|
script += f" cd {repr(d)}; "
|
||||||
if cwd is not None:
|
if cwd is not None:
|
||||||
script += "else:\n"
|
script += "else "
|
||||||
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
|
script += f" echo 'Failed to change directory to' {repr(d)} >2; "
|
||||||
script += f" sys.exit(1)\n"
|
script += "fi; "
|
||||||
|
|
||||||
# Add the environment variables that were requested
|
# Add the environment variables that were requested
|
||||||
script += "env = dict(os.environ)\n"
|
|
||||||
for e in env:
|
for e in env:
|
||||||
key, *value = e.split("=", maxsplit=1)
|
key, *value = e.split("=", maxsplit=1)
|
||||||
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
||||||
@@ -141,22 +123,34 @@ class RemoteProcess(CommandProcess):
|
|||||||
f"'{e}' is an invalid environment variable so it is ignored"
|
f"'{e}' is an invalid environment variable so it is ignored"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
script += f"env[{repr(key)}] = {repr(value)}\n"
|
script += f"export {key}={value}; "
|
||||||
|
|
||||||
# Make the temporary files
|
# Make the temporary files
|
||||||
for env_name, content in files.items():
|
for env_name, content in files.items():
|
||||||
script += "_, fname = tempfile.mkstemp()\n"
|
script += "fname=$(mktemp); "
|
||||||
script += "with open(fname, 'w') as f:\n"
|
script += f"echo {shlex.quote(content)} >$fname; "
|
||||||
script += f" f.write({repr(content)})\n"
|
script += f"export {env_name}=$fname; "
|
||||||
script += f"env[{repr(env_name)}] = fname\n"
|
|
||||||
|
|
||||||
# Finally add the rank
|
# Finally add the rank
|
||||||
script += f"env['MLX_RANK'] = '{rank}'\n"
|
script += f"export MLX_RANK={rank}; "
|
||||||
script += "\n"
|
|
||||||
|
|
||||||
# Replace the process with the script
|
# Replace the process with the script
|
||||||
script += f"command = [{','.join(map(repr, command))}]\n"
|
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
|
||||||
script += "os.execve(command[0], command, env)\n"
|
script += 'exec "${cmd[@]}"'
|
||||||
|
|
||||||
|
return script
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_kill_script(pidfile):
|
||||||
|
script = ""
|
||||||
|
script += f"pid=$(cat {pidfile}); "
|
||||||
|
script += "if ps -p $pid >/dev/null; then "
|
||||||
|
script += " kill $pid; "
|
||||||
|
script += " echo 1; "
|
||||||
|
script += "else "
|
||||||
|
script += " echo 0; "
|
||||||
|
script += "fi; "
|
||||||
|
script += f"rm {pidfile}"
|
||||||
|
|
||||||
return script
|
return script
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user