Expose per-backend availability in C++ and python

This commit is contained in:
Angelos Katharopoulos
2025-11-20 15:26:59 -08:00
parent 1216afdc91
commit bfdddd644b
4 changed files with 42 additions and 6 deletions

View File

@@ -107,6 +107,25 @@ bool is_available() {
ibv::is_available(); ibv::is_available();
} }
bool is_available(const std::string& bk) {
if (bk == "any") {
return is_available();
}
if (bk == "mpi") {
return mpi::is_available();
}
if (bk == "ring") {
return ring::is_available();
}
if (bk == "nccl") {
return nccl::is_available();
}
if (bk == "ibv") {
return ibv::is_available();
}
return false;
}
int Group::rank() const { int Group::rank() const {
return group_->rank(); return group_->rank();
} }

View File

@@ -16,6 +16,7 @@ class GroupImpl;
/* Check if a communication backend is available */ /* Check if a communication backend is available */
bool is_available(); bool is_available();
bool is_available(const std::string& bk);
/** /**
* A distributed::Group represents a group of independent mlx processes that * A distributed::Group represents a group of independent mlx processes that

View File

@@ -1088,7 +1088,7 @@ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
const char* rank_str = std::getenv("MLX_RANK"); const char* rank_str = std::getenv("MLX_RANK");
const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE"); const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");
if (!dev_file || !coordinator || !rank_str) { if (!is_available() || !dev_file || !coordinator || !rank_str) {
if (strict) { if (strict) {
std::ostringstream msg; std::ostringstream msg;
msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), " msg << "[ibv] You need to provide via environment variables a rank (MLX_RANK), "

View File

@@ -52,9 +52,25 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"is_available", "is_available",
&mx::distributed::is_available, [](const std::string& backend) {
return mx::distributed::is_available(backend);
},
"backend"_a = "any",
nb::sig("def is_available(backend: str = 'any') -> bool"),
R"pbdoc( R"pbdoc(
Check if a communication backend is available. Check if a communication backend is available.
Note, this function returns whether MLX has the capability of
instantiating that distributed backend not whether it is possible to
create a communication group. For that purpose one should use
``init(strict=True)``.
Args:
backend (str, optional): The name of the backend to check for availability.
It takes the same values as ``init()``. Default: ``any``.
Returns:
bool: Whether the distributed backend is available.
)pbdoc"); )pbdoc");
m.def( m.def(
@@ -79,10 +95,10 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False`` it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize. backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all Possible values ``mpi``, ``ring``, ``nccl``, ``ibv``, ``any``. If
available backends are tried and the first one that succeeds set to ``any`` all available backends are tried and the first one
becomes the global group which will be returned in subsequent that succeeds becomes the global group which will be returned in
calls. Default: ``any`` subsequent calls. Default: ``any``
Returns: Returns:
Group: The group representing all the launched processes. Group: The group representing all the launched processes.