diff --git a/.circleci/config.yml b/.circleci/config.yml
index 03987d39c..70f4e0fe5 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -405,6 +405,7 @@ jobs:
             sudo dpkg -i cuda-keyring_1.1-1_all.deb
             sudo apt-get update
             sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
+            sudo apt-get install libnccl2 libnccl-dev
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
             sudo apt-get install zip
             pip install auditwheel
diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu
index 2cdf615f5..90eeacd7c 100644
--- a/mlx/backend/cuda/distributed.cu
+++ b/mlx/backend/cuda/distributed.cu
@@ -2,30 +2,35 @@
 
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/primitives.h"
 
 #include <cassert>
 
-namespace mlx::core {
-namespace distributed {
+namespace mlx::core::distributed {
 
 void AllReduce::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
   assert(outputs.size() == 1);
 
-  auto& input = inputs[0];
-  auto& output = outputs[0];
+  auto set_input_output =
+      [s = stream()](const array& in, array& out) -> std::pair<array, array> {
+    if (!in.flags().row_contiguous) {
+      copy_gpu(in, out, CopyType::General, s);
+      return {out, out};
+    } else if (in.is_donatable()) {
+      out.copy_shared_buffer(in);
+      return {in, out};
+    } else {
+      // Not donatable and not copied: allocate the output buffer.
+      out.set_data(allocator::malloc(out.nbytes()));
+      return {in, out};
+    }
+  };
+
+  auto [input, output] = set_input_output(inputs[0], outputs[0]);
 
   auto& encoder = cu::get_command_encoder(stream());
-
-  if (input.is_donatable()) {
-    output.copy_shared_buffer(input);
-  } else {
-    output.set_data(allocator::malloc(output.nbytes()));
-  }
-
   encoder.set_input_array(input);
   encoder.set_output_array(output);
@@ -47,5 +52,4 @@ void AllReduce::eval_gpu(
         "Only all reduce sum, max, and min are supported.");
   }
 }
-} // namespace distributed
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core::distributed
diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp
index 44205e87e..d71ebb9b1 100644
--- a/mlx/distributed/distributed.cpp
+++ b/mlx/distributed/distributed.cpp
@@ -2,6 +2,7 @@
 
 #include <unordered_map>
 
+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
@@ -114,7 +115,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
   }
 
   // Create the requested communication group
-  std::shared_ptr<detail::GroupImpl> group;
+  std::shared_ptr<detail::GroupImpl> group{nullptr};
   std::string bk_ = bk;
   if (bk == "mpi") {
     group = mpi::init(strict);
@@ -123,8 +124,14 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
   } else if (bk == "nccl") {
     group = nccl::init(strict);
   } else if (bk == "any") {
-    group = ring::init(false);
-    bk_ = "ring";
+    if (mlx::core::cu::is_available()) {
+      group = nccl::init(false);
+      bk_ = "nccl";
+    }
+    if (group == nullptr) {
+      group = ring::init(false);
+      bk_ = "ring";
+    }
     if (group == nullptr) {
       group = mpi::init(false);
       bk_ = "mpi";
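
Note: with the distributed.cpp change above, initializing with the default backend now prefers NCCL on CUDA machines and falls back to ring, then MPI. A minimal sketch of how this surfaces through the standard `mlx.core.distributed` Python API (the fallback order in the comment reflects the new "any" path; nothing else here is new API):

    import mlx.core as mx

    # backend="any" (the default) now tries: nccl (if CUDA is available and
    # the NCCL env vars are set) -> ring -> mpi, instead of always ring first.
    group = mx.distributed.init()
    print(f"rank {group.rank()} of {group.size()}")
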
"[Rank " << rank << "] Connected successfully on attempt " - << attempt + 1 << std::endl; - break; + if (do_log) { + std::cout << "[Rank " << rank + << "] Connected successfully on attempt " << attempt + 1 + << std::endl; + break; + } } if (errno != ECONNREFUSED) { break; @@ -331,24 +335,33 @@ bool is_available() { } namespace detail { -static std::string get_env_var_or_throw(const char* env_var_name) { +std::string get_env_var_or_throw(const char* env_var_name, bool strict) { const char* value = std::getenv(env_var_name); - if (value == nullptr) { + if (value == nullptr && strict) { std::ostringstream msg; msg << "[nccl] Required environment variable '" << env_var_name << "' is not set. " << "Please set it before initializing the distributed backend."; throw std::runtime_error(msg.str()); } + if (value == nullptr) { + return ""; + } return std::string(value); } } // namespace detail std::shared_ptr init(bool strict /* = false */) { - std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP"); - std::string port = detail::get_env_var_or_throw("NCCL_PORT"); - std::string rank_str = detail::get_env_var_or_throw("MLX_RANK"); - std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE"); + std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict); + std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict); + std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict); + std::string n_nodes_str = + detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict); + if (!strict && + (host.empty() || port.empty() || rank_str.empty() || + n_nodes_str.empty())) { + return nullptr; + } int rank = std::stoi(rank_str); int n_nodes = std::stoi(n_nodes_str); diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 3dd373eb1..bb0e3c633 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string): return ports +def get_num_nvidia_gpus(): + result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True) + return len(result.stdout.strip().split("\n")) + + def extract_rings(hosts, index): def usable_port(i, j, used_ports): return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None @@ -421,14 +426,16 @@ def launch_nccl(parser, hosts, args, command): master_host = hosts[0].ips[0] if master_host != "127.0.0.1": - raise ValueError("The NCCL backend only supports localhost for now. 
") + raise ValueError("The NCCL backend only supports localhost for now.") master_port = args.nccl_port world_size = len(hosts) base_env = os.environ.copy() base_env.update( { - "NCCL_DEBUG": "INFO", + "NCCL_DEBUG": base_env.get( + "NCCL_DEBUG", "INFO" if args.verbose else "DEBUG" + ), "NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication "NCCL_HOST_IP": master_host, "NCCL_PORT": str(master_port), @@ -436,11 +443,18 @@ def launch_nccl(parser, hosts, args, command): } ) procs = [] + num_gpus = get_num_nvidia_gpus() + if num_gpus == 0: + raise RuntimeError("Cannot run NCCL backend with no GPUs.") + if args.repeat_hosts > num_gpus: + raise RuntimeError("NCCL requires a separate GPU per process.") + try: for rank in range(world_size): env = base_env.copy() - env["MLX_RANK"] = str(rank % args.repeat_hosts) - env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts) + mlx_rank = str(rank % args.repeat_hosts) + env["MLX_RANK"] = mlx_rank + env["CUDA_VISIBLE_DEVICES"] = mlx_rank p = Popen(command, env=env) procs.append(p) @@ -821,8 +835,6 @@ def main(): ) args, rest = parser.parse_known_args() - if rest[0] == "--": - rest.pop(0) if args.print_python: print(sys.executable)