This commit is contained in:
Angelos Katharopoulos
2025-02-28 11:34:21 -08:00
committed by GitHub
parent 607181644f
commit 5d68082881
6 changed files with 378 additions and 74 deletions

View File

@@ -297,7 +297,7 @@ def launch_ring(parser, hosts, args, command):
"The ring backend requires IPs to be provided instead of hostnames"
)
port = 5000
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
@@ -669,6 +669,11 @@ def distributed_config():
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
@@ -707,11 +712,25 @@ def main():
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=5000,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)

View File

@@ -2,4 +2,4 @@
from mlx.nn import init, losses
from mlx.nn.layers import *
from mlx.nn.utils import value_and_grad
from mlx.nn.utils import average_gradients, value_and_grad