diff --git a/python/mlx/_distributed_utils/launch.py b/python/mlx/_distributed_utils/launch.py
index ee0fe6264..6bddeb4ca 100644
--- a/python/mlx/_distributed_utils/launch.py
+++ b/python/mlx/_distributed_utils/launch.py
@@ -47,13 +47,9 @@ class CommandProcess:
 class RemoteProcess(CommandProcess):
     def __init__(self, rank, host, python, cwd, files, env, command):
         is_local = host == "127.0.0.1"
-        script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
-        script_b64 = base64.b64encode(script.encode()).decode()
-        cmd = (
-            f'{python} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
-        )
+        cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
         if not is_local:
-            cmd = f"ssh {host} '{cmd}'"
+            cmd = f"ssh {host} {shlex.quote(cmd)}"
 
         self._host = host
         self._pidfile = None
@@ -92,47 +88,44 @@ class RemoteProcess(CommandProcess):
         self._process.wait()
 
         # Kill the remote program if possible
-        cmd = ""
-        cmd += f"pid=$(cat {self._pidfile}); "
-        cmd += "if ps -p $pid >/dev/null; then "
-        cmd += " kill $pid; "
-        cmd += " echo 1; "
-        cmd += "else "
-        cmd += " echo 0; "
-        cmd += "fi; "
-        cmd += f"rm {self._pidfile}"
+        cmd = RemoteProcess.make_kill_script(self._pidfile)
         if not self._is_local:
-            cmd = f"ssh {self._host} '{cmd}'"
+            cmd = f"ssh {self._host} {shlex.quote(cmd)}"
         c = run(cmd, check=True, shell=True, capture_output=True, text=True)
 
         self._killed = c.stdout.strip() == "1"
 
     @staticmethod
-    def make_monitor_script(rank, cwd, files, env, command):
-        # Imports that are used throughout
+    def make_launch_script(rank, cwd, files, env, command):
+        """Build the shell one-liner that launches ``command`` for ``rank``.
+
+        The script records its PID in a temporary file and prints that
+        file's path, changes directory, exports the requested environment
+        variables, writes every entry of ``files`` to a temporary file
+        whose path is exported under the entry's name, exports ``MLX_RANK``
+        and finally ``exec``s ``command`` so that it keeps the recorded PID.
+        """
         script = ""
-        script += "import os\n"
-        script += "import sys\n"
-        script += "import tempfile\n"
-        script += "from pathlib import Path\n"
 
         # Write the PID to a file so we can kill the process if needed
-        script += "_, pidfile = tempfile.mkstemp() \n"
-        script += "open(pidfile, 'w').write(str(os.getpid()))\n"
-        script += "print(pidfile, flush=True)\n"
+        script += "pidfile=$(mktemp); "
+        script += "echo $$ > $pidfile; "
+        script += "echo $pidfile; "
 
         # Change the working directory if one was requested. Otherwise attempt to
         # change to the current one but don't fail if it wasn't possible.
         d = cwd or os.getcwd()
-        script += f"if Path({repr(d)}).exists():\n"
-        script += f" os.chdir({repr(d)})\n"
+        # Quote for the shell: repr() produces Python escapes, not shell ones.
+        quoted_d = shlex.quote(d)
+        script += f"if [ -d {quoted_d} ]; then "
+        script += f" cd {quoted_d}; "
         if cwd is not None:
-            script += "else:\n"
-            script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
-            script += f" sys.exit(1)\n"
+            script += "else "
+            script += f" echo 'Failed to change directory to' {quoted_d} >&2; "
+            script += " exit 1; "
+        script += "fi; "
 
         # Add the environment variables that were requested
-        script += "env = dict(os.environ)\n"
         for e in env:
             key, *value = e.split("=", maxsplit=1)
             value = shlex.quote(value[0]) if len(value) > 0 else ""
@@ -141,21 +134,41 @@ class RemoteProcess(CommandProcess):
                     f"'{e}' is an invalid environment variable so it is ignored"
                 )
                 continue
-            script += f"env[{repr(key)}] = {repr(value)}\n"
+            script += f"export {key}={value}; "
 
         # Make the temporary files
         for env_name, content in files.items():
-            script += "_, fname = tempfile.mkstemp()\n"
-            script += "with open(fname, 'w') as f:\n"
-            script += f" f.write({repr(content)})\n"
-            script += f"env[{repr(env_name)}] = fname\n"
+            script += "fname=$(mktemp); "
+            # printf writes the content verbatim; echo appends a newline and
+            # some shells expand backslash escapes in its arguments.
+            script += f"printf '%s' {shlex.quote(content)} >$fname; "
+            script += f"export {env_name}=$fname; "
 
         # Finally add the rank
-        script += f"env['MLX_RANK'] = '{rank}'\n"
-        script += "\n"
+        script += f"export MLX_RANK={rank}; "
 
         # Replace the process with the script
-        script += f"command = [{','.join(map(repr, command))}]\n"
-        script += "os.execve(command[0], command, env)\n"
+        # A plain argument list keeps the script free of bash-only arrays.
+        script += f"exec {' '.join(map(shlex.quote, command))}"
+
+        return script
+
+    @staticmethod
+    def make_kill_script(pidfile):
+        """Build the shell one-liner that kills the process listed in ``pidfile``.
+
+        The script prints ``1`` after signalling a process that is still
+        running and ``0`` otherwise, then removes ``pidfile``.
+        """
+        quoted_pidfile = shlex.quote(str(pidfile))
+        script = ""
+        script += f"pid=$(cat {quoted_pidfile}); "
+        script += "if ps -p $pid >/dev/null; then "
+        script += " kill $pid; "
+        script += " echo 1; "
+        script += "else "
+        script += " echo 0; "
+        script += "fi; "
+        script += f"rm {quoted_pidfile}"
 
         return script