mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Fix the launcher when ran locally (#2147)
This commit is contained in:
parent
e496c5a4b4
commit
a3a632d567
@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
|
|
||||||
# Repeat the stdout and stderr to the local machine
|
# Repeat the stdout and stderr to the local machine
|
||||||
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
||||||
to_write = [p.stdin.fileno()]
|
to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
|
||||||
pidfile = ""
|
pidfile = ""
|
||||||
stdin_buffer = b""
|
stdin_buffer = b""
|
||||||
|
stdout_buffer = b""
|
||||||
|
stderr_buffer = b""
|
||||||
while p.poll() is None:
|
while p.poll() is None:
|
||||||
try:
|
try:
|
||||||
stdin_buffer += input_queue.get_nowait()
|
stdin_buffer += input_queue.get_nowait()
|
||||||
@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
pass
|
pass
|
||||||
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
||||||
for fd in rlist:
|
for fd in rlist:
|
||||||
is_stdout = fd == p.stdout.fileno()
|
|
||||||
outfile = sys.stdout if is_stdout else sys.stderr
|
|
||||||
msg = os.read(fd, 8192).decode(errors="ignore")
|
msg = os.read(fd, 8192).decode(errors="ignore")
|
||||||
|
|
||||||
# Fetch the PID file first if we haven't already
|
# Fetch the PID file first if we haven't already
|
||||||
@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
pidfile, *msg = msg.split("\n", maxsplit=1)
|
pidfile, *msg = msg.split("\n", maxsplit=1)
|
||||||
msg = msg[0] if msg else ""
|
msg = msg[0] if msg else ""
|
||||||
|
|
||||||
outfile.write(msg)
|
is_stdout = fd == p.stdout.fileno()
|
||||||
outfile.flush()
|
if is_stdout:
|
||||||
|
stdout_buffer += msg.encode()
|
||||||
|
else:
|
||||||
|
stderr_buffer += msg.encode()
|
||||||
for fd in wlist:
|
for fd in wlist:
|
||||||
if len(stdin_buffer) > 0:
|
if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
|
||||||
n = os.write(fd, stdin_buffer)
|
n = os.write(fd, stdin_buffer)
|
||||||
stdin_buffer = stdin_buffer[n:]
|
stdin_buffer = stdin_buffer[n:]
|
||||||
|
elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
|
||||||
|
n = os.write(fd, stdout_buffer)
|
||||||
|
stdout_buffer = stdout_buffer[n:]
|
||||||
|
elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
|
||||||
|
n = os.write(fd, stderr_buffer)
|
||||||
|
stderr_buffer = stderr_buffer[n:]
|
||||||
if stop:
|
if stop:
|
||||||
p.terminate()
|
p.terminate()
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user