diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 828e22efb..0857e286e 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -6,6 +6,7 @@ import ipaddress import json import os import shlex +import shutil import sys import tempfile import threading @@ -125,7 +126,7 @@ def make_monitor_script(rank, hostfile, cwd, env, command, verbose): script += "\n" # Replace the process with the script - script += shlex.join(["exec", sys.executable, *command]) + script += shlex.join(["exec", *command]) script += "\n" return script @@ -210,7 +211,7 @@ def launch_ring(parser, hosts, args, command): ring_hosts.append(node) hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else "" - log(args.verbose, "Running", shlex.join([sys.executable, *command])) + log(args.verbose, "Running", shlex.join(command)) threads = [] for i, h in enumerate(hosts): @@ -261,7 +262,6 @@ def launch_mpi(parser, hosts, args, command): *sum((["-x", e] for e in args.env), []), *sum([shlex.split(arg) for arg in args.mpi_arg], []), "--", - sys.executable, *command, ] log(args.verbose, "Running", " ".join(cmd)) @@ -323,9 +323,12 @@ def main(): hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts) # Check if the script is a file and convert it to a full path - script = Path(rest[0]) - if script.exists(): - rest[0] = str(script.resolve()) + if (script := Path(rest[0])).exists(): + rest[0:1] = [sys.executable, str(script.resolve())] + elif (command := shutil.which(rest[0])) is not None: + rest[0] = command + else: + raise ValueError(f"Invalid script or command {rest[0]}") # Launch if args.backend == "ring":