From f40152ebc1d3b6e45227ae82ded2ab7006baeb59 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 20 Nov 2025 15:26:59 -0800
Subject: [PATCH] Expose per-backend availability in C++ and python

---
 mlx/distributed/distributed.cpp | 19 +++++++++++++++++++
 mlx/distributed/distributed.h   |  1 +
 mlx/distributed/ibv/ibv.cpp     |  2 +-
 python/src/distributed.cpp      | 26 +++++++++++++++++++++-----
 4 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp
index 74637edb5..0600e6a36 100644
--- a/mlx/distributed/distributed.cpp
+++ b/mlx/distributed/distributed.cpp
@@ -107,6 +107,25 @@ bool is_available() {
       ibv::is_available();
 }
 
+bool is_available(const std::string& bk) {
+  if (bk == "any") {
+    return is_available();
+  }
+  if (bk == "mpi") {
+    return mpi::is_available();
+  }
+  if (bk == "ring") {
+    return ring::is_available();
+  }
+  if (bk == "nccl") {
+    return nccl::is_available();
+  }
+  if (bk == "ibv") {
+    return ibv::is_available();
+  }
+  return false;
+}
+
 int Group::rank() const {
   return group_->rank();
 }
diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h
index fa5c42a1f..a6971dd97 100644
--- a/mlx/distributed/distributed.h
+++ b/mlx/distributed/distributed.h
@@ -16,6 +16,7 @@ class GroupImpl;
 
 /* Check if a communication backend is available */
 bool is_available();
+bool is_available(const std::string& bk);
 
 /**
  * A distributed::Group represents a group of independent mlx processes that
diff --git a/mlx/distributed/ibv/ibv.cpp b/mlx/distributed/ibv/ibv.cpp
index 447ba7d2e..3a52105da 100644
--- a/mlx/distributed/ibv/ibv.cpp
+++ b/mlx/distributed/ibv/ibv.cpp
@@ -1088,7 +1088,7 @@ std::shared_ptr init(bool strict /* = false */) {
   const char* rank_str = std::getenv("MLX_RANK");
   const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");
 
-  if (!dev_file || !coordinator || !rank_str) {
+  if (!is_available() || !dev_file || !coordinator || !rank_str) {
     if (strict) {
       std::ostringstream msg;
       msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "
diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp
index d147c2783..3c6f587f3 100644
--- a/python/src/distributed.cpp
+++ b/python/src/distributed.cpp
@@ -52,9 +52,25 @@ void init_distributed(nb::module_& parent_module) {
 
   m.def(
       "is_available",
-      &mx::distributed::is_available,
+      [](const std::string& backend) {
+        return mx::distributed::is_available(backend);
+      },
+      "backend"_a = "any",
+      nb::sig("def is_available(backend: str = 'any') -> bool"),
       R"pbdoc(
       Check if a communication backend is available.
+
+      Note, this function returns whether MLX has the capability of
+      instantiating that distributed backend not whether it is possible to
+      create a communication group. For that purpose one should use
+      ``init(strict=True)``.
+
+      Args:
+        backend (str, optional): The name of the backend to check for availability.
+          It takes the same values as ``init()``. Default: ``any``.
+
+      Returns:
+        bool: Whether the distributed backend is available.
       )pbdoc");
 
   m.def(
@@ -79,10 +95,10 @@ void init_distributed(nb::module_& parent_module) {
             in case ``mx.distributed.is_available()`` returns False otherwise
             it throws a runtime error. Default: ``False``
           backend (str, optional): Which distributed backend to initialize.
-            Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
-            available backends are tried and the first one that succeeds
-            becomes the global group which will be returned in subsequent
-            calls. Default: ``any``
+            Possible values ``mpi``, ``ring``, ``nccl``, ``ibv``, ``any``. If
+            set to ``any`` all available backends are tried and the first one
+            that succeeds becomes the global group which will be returned in
+            subsequent calls. Default: ``any``
 
         Returns:
           Group: The group representing all the launched processes.