This commit is contained in:
Angelos Katharopoulos
2025-02-28 11:34:21 -08:00
committed by GitHub
parent 607181644f
commit 5d68082881
6 changed files with 378 additions and 74 deletions

View File

@@ -297,7 +297,7 @@ def launch_ring(parser, hosts, args, command):
"The ring backend requires IPs to be provided instead of hostnames"
)
port = 5000
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
@@ -669,6 +669,11 @@ def distributed_config():
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
@@ -707,11 +712,25 @@ def main():
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=5000,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)

View File

@@ -2,4 +2,4 @@
from mlx.nn import init, losses
from mlx.nn.layers import *
from mlx.nn.utils import value_and_grad
from mlx.nn.utils import average_gradients, value_and_grad

View File

@@ -68,19 +68,21 @@ void init_distributed(nb::module_& parent_module) {
Example:
import mlx.core as mx
.. code:: python
group = mx.distributed.init(backend="ring")
import mlx.core as mx
group = mx.distributed.init(backend="ring")
Args:
strict (bool, optional): If set to False it returns a singleton group
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Select a specific distributed backend to
initialize. If set to ``any`` then try all available backends and
return the first one that succeeds. Subsequent calls will return
the first backend that was initialized. Default: ``any``
backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent
calls. Default: ``any``
Returns:
Group: The group representing all the launched processes.