mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix the launcher when ran locally (#2147)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							e496c5a4b4
						
					
				
				
					commit
					a3a632d567
				
			| @@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command): | |||||||
|  |  | ||||||
|         # Repeat the stdout and stderr to the local machine |         # Repeat the stdout and stderr to the local machine | ||||||
|         to_read = [p.stdout.fileno(), p.stderr.fileno()] |         to_read = [p.stdout.fileno(), p.stderr.fileno()] | ||||||
|         to_write = [p.stdin.fileno()] |         to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()] | ||||||
|         pidfile = "" |         pidfile = "" | ||||||
|         stdin_buffer = b"" |         stdin_buffer = b"" | ||||||
|  |         stdout_buffer = b"" | ||||||
|  |         stderr_buffer = b"" | ||||||
|         while p.poll() is None: |         while p.poll() is None: | ||||||
|             try: |             try: | ||||||
|                 stdin_buffer += input_queue.get_nowait() |                 stdin_buffer += input_queue.get_nowait() | ||||||
| @@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command): | |||||||
|                 pass |                 pass | ||||||
|             rlist, wlist, _ = select(to_read, to_write, [], 1.0) |             rlist, wlist, _ = select(to_read, to_write, [], 1.0) | ||||||
|             for fd in rlist: |             for fd in rlist: | ||||||
|                 is_stdout = fd == p.stdout.fileno() |  | ||||||
|                 outfile = sys.stdout if is_stdout else sys.stderr |  | ||||||
|                 msg = os.read(fd, 8192).decode(errors="ignore") |                 msg = os.read(fd, 8192).decode(errors="ignore") | ||||||
|  |  | ||||||
|                 # Fetch the PID file first if we haven't already |                 # Fetch the PID file first if we haven't already | ||||||
| @@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command): | |||||||
|                     pidfile, *msg = msg.split("\n", maxsplit=1) |                     pidfile, *msg = msg.split("\n", maxsplit=1) | ||||||
|                     msg = msg[0] if msg else "" |                     msg = msg[0] if msg else "" | ||||||
|  |  | ||||||
|                 outfile.write(msg) |                 is_stdout = fd == p.stdout.fileno() | ||||||
|                 outfile.flush() |                 if is_stdout: | ||||||
|  |                     stdout_buffer += msg.encode() | ||||||
|  |                 else: | ||||||
|  |                     stderr_buffer += msg.encode() | ||||||
|             for fd in wlist: |             for fd in wlist: | ||||||
|                 if len(stdin_buffer) > 0: |                 if fd == p.stdin.fileno() and len(stdin_buffer) > 0: | ||||||
|                     n = os.write(fd, stdin_buffer) |                     n = os.write(fd, stdin_buffer) | ||||||
|                     stdin_buffer = stdin_buffer[n:] |                     stdin_buffer = stdin_buffer[n:] | ||||||
|  |                 elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0: | ||||||
|  |                     n = os.write(fd, stdout_buffer) | ||||||
|  |                     stdout_buffer = stdout_buffer[n:] | ||||||
|  |                 elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0: | ||||||
|  |                     n = os.write(fd, stderr_buffer) | ||||||
|  |                     stderr_buffer = stderr_buffer[n:] | ||||||
|             if stop: |             if stop: | ||||||
|                 p.terminate() |                 p.terminate() | ||||||
|                 break |                 break | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user