mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Fix the launcher when ran locally (#2147)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							e496c5a4b4
						
					
				
				
					commit
					a3a632d567
				
			| @@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command): | ||||
|  | ||||
|         # Repeat the stdout and stderr to the local machine | ||||
|         to_read = [p.stdout.fileno(), p.stderr.fileno()] | ||||
|         to_write = [p.stdin.fileno()] | ||||
|         to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()] | ||||
|         pidfile = "" | ||||
|         stdin_buffer = b"" | ||||
|         stdout_buffer = b"" | ||||
|         stderr_buffer = b"" | ||||
|         while p.poll() is None: | ||||
|             try: | ||||
|                 stdin_buffer += input_queue.get_nowait() | ||||
| @@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command): | ||||
|                 pass | ||||
|             rlist, wlist, _ = select(to_read, to_write, [], 1.0) | ||||
|             for fd in rlist: | ||||
|                 is_stdout = fd == p.stdout.fileno() | ||||
|                 outfile = sys.stdout if is_stdout else sys.stderr | ||||
|                 msg = os.read(fd, 8192).decode(errors="ignore") | ||||
|  | ||||
|                 # Fetch the PID file first if we haven't already | ||||
| @@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command): | ||||
|                     pidfile, *msg = msg.split("\n", maxsplit=1) | ||||
|                     msg = msg[0] if msg else "" | ||||
|  | ||||
|                 outfile.write(msg) | ||||
|                 outfile.flush() | ||||
|                 is_stdout = fd == p.stdout.fileno() | ||||
|                 if is_stdout: | ||||
|                     stdout_buffer += msg.encode() | ||||
|                 else: | ||||
|                     stderr_buffer += msg.encode() | ||||
|             for fd in wlist: | ||||
|                 if len(stdin_buffer) > 0: | ||||
|                 if fd == p.stdin.fileno() and len(stdin_buffer) > 0: | ||||
|                     n = os.write(fd, stdin_buffer) | ||||
|                     stdin_buffer = stdin_buffer[n:] | ||||
|                 elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0: | ||||
|                     n = os.write(fd, stdout_buffer) | ||||
|                     stdout_buffer = stdout_buffer[n:] | ||||
|                 elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0: | ||||
|                     n = os.write(fd, stderr_buffer) | ||||
|                     stderr_buffer = stderr_buffer[n:] | ||||
|             if stop: | ||||
|                 p.terminate() | ||||
|                 break | ||||
|   | ||||
		Reference in New Issue
	
	Block a user