Mirror of https://github.com/ml-explore/mlx.git, synced 2025-08-24 22:36:39 +08:00
nccl default for backend=any (#2528)

* nccl default for backend=any
* check num gpus + ensure row contiguous for all reduce
* comment
This commit is contained in:
parent 5722c147de
commit 068a4612e9
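
In user-facing terms, the headline change is that a CUDA build now resolves backend="any" to NCCL before trying the ring and MPI backends. A minimal sketch of what that looks like from Python (assuming a CUDA-enabled MLX build and a script started under mlx.launch so the rank variables are set):

import mlx.core as mx

# On a CUDA build, "any" now tries NCCL first, then ring, then MPI.
world = mx.distributed.init(strict=False, backend="any")
print(f"rank {world.rank()} of {world.size()}")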
CI config (`jobs:` section):

@@ -405,6 +405,7 @@ jobs:
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 sudo apt-get update
 sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
+sudo apt-get install libnccl2 libnccl-dev
 sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
 sudo apt-get install zip
 pip install auditwheel
CUDA backend (`AllReduce::eval_gpu`):

@@ -2,30 +2,35 @@
 
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/primitives.h"
 
 #include <cassert>
 
-namespace mlx::core {
-namespace distributed {
+namespace mlx::core::distributed {
 
 void AllReduce::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() == 1);
   assert(outputs.size() == 1);
 
-  auto& input = inputs[0];
-  auto& output = outputs[0];
+  auto set_input_output =
+      [s = stream()](const array& in, array& out) -> std::pair<array, array> {
+    if (!in.flags().row_contiguous) {
+      copy_gpu(in, out, CopyType::General, s);
+      return {out, out};
+    } else if (in.is_donatable()) {
+      out.copy_shared_buffer(in);
+      return {in, out};
+    } else {
+      return {in, out};
+    }
+  };
+
+  auto [input, output] = set_input_output(inputs[0], outputs[0]);
 
   auto& encoder = cu::get_command_encoder(stream());
 
-  if (input.is_donatable()) {
-    output.copy_shared_buffer(input);
-  } else {
-    output.set_data(allocator::malloc(output.nbytes()));
-  }
-
   encoder.set_input_array(input);
   encoder.set_output_array(output);

@@ -47,5 +52,4 @@ void AllReduce::eval_gpu(
           "Only all reduce sum, max, and min are supported.");
   }
 }
-} // namespace distributed
-} // namespace mlx::core
+} // namespace mlx::core::distributed
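
The row-contiguity guard matters because NCCL reduces a flat, dense buffer, so a strided view must be materialized before the collective runs. A quick way to exercise the path from Python (a sketch, assuming a CUDA build with NCCL available and ranks launched via mlx.launch):

import mlx.core as mx

group = mx.distributed.init(strict=False)
x = mx.transpose(mx.arange(12).reshape(3, 4))  # transposed view: not row contiguous
y = mx.distributed.all_sum(x, group=group)     # eval_gpu now copies to a dense buffer first
mx.eval(y)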
Distributed backend selection (`Group init`):

@@ -2,6 +2,7 @@
 
 #include <unordered_map>
 
+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"

@@ -114,7 +115,7 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
   }
 
   // Create the requested communication group
-  std::shared_ptr<detail::GroupImpl> group;
+  std::shared_ptr<detail::GroupImpl> group{nullptr};
   std::string bk_ = bk;
   if (bk == "mpi") {
     group = mpi::init(strict);

@@ -123,8 +124,14 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
   } else if (bk == "nccl") {
     group = nccl::init(strict);
   } else if (bk == "any") {
+    if (mlx::core::cu::is_available()) {
+      group = nccl::init(false);
+      bk_ = "nccl";
+    }
+    if (group == nullptr) {
       group = ring::init(false);
       bk_ = "ring";
+    }
     if (group == nullptr) {
       group = mpi::init(false);
       bk_ = "mpi";
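
The "any" path always probes leniently (strict=false), so a failed NCCL init falls through to ring and then MPI, while an explicit backend keeps its strictness. From Python (a sketch of both modes):

import mlx.core as mx

# Lenient: on a CUDA build tries NCCL, then falls back to ring, then MPI.
world = mx.distributed.init(strict=False, backend="any")

# Strict: raises if the NCCL environment (NCCL_HOST_IP, NCCL_PORT,
# MLX_RANK, MLX_WORLD_SIZE) is not fully configured.
# world = mx.distributed.init(strict=True, backend="nccl")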
NCCL backend (bootstrap and environment handling):

@@ -204,14 +204,18 @@ inline void bootstrap_unique_id(
   int attempt = 0;
   bool connected = false;
 
+  bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
   for (attempt = 0; attempt < max_retries; ++attempt) {
     if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
         0) {
       connected = true;
-      std::cout << "[Rank " << rank << "] Connected successfully on attempt "
-                << attempt + 1 << std::endl;
+      if (do_log) {
+        std::cout << "[Rank " << rank
+                  << "] Connected successfully on attempt " << attempt + 1
+                  << std::endl;
+      }
       break;
     }
     if (errno != ECONNREFUSED) {
       break;
     }

@@ -331,24 +335,33 @@ bool is_available() {
 }
 
 namespace detail {
-static std::string get_env_var_or_throw(const char* env_var_name) {
+std::string get_env_var_or_throw(const char* env_var_name, bool strict) {
   const char* value = std::getenv(env_var_name);
-  if (value == nullptr) {
+  if (value == nullptr && strict) {
     std::ostringstream msg;
     msg << "[nccl] Required environment variable '" << env_var_name
         << "' is not set. "
         << "Please set it before initializing the distributed backend.";
     throw std::runtime_error(msg.str());
   }
+  if (value == nullptr) {
+    return "";
+  }
   return std::string(value);
 }
 } // namespace detail
 
 std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
-  std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
-  std::string port = detail::get_env_var_or_throw("NCCL_PORT");
-  std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
-  std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
+  std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP", strict);
+  std::string port = detail::get_env_var_or_throw("NCCL_PORT", strict);
+  std::string rank_str = detail::get_env_var_or_throw("MLX_RANK", strict);
+  std::string n_nodes_str =
+      detail::get_env_var_or_throw("MLX_WORLD_SIZE", strict);
+  if (!strict &&
+      (host.empty() || port.empty() || rank_str.empty() ||
+       n_nodes_str.empty())) {
+    return nullptr;
+  }
 
   int rank = std::stoi(rank_str);
   int n_nodes = std::stoi(n_nodes_str);
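
One caveat worth noting in the logging gate: std::getenv returns a char*, so `std::getenv("NCCL_DEBUG") == "INFO"` compares pointers rather than contents and will generally be false even when the variable is set to INFO; a std::string or strcmp comparison is the robust form. The intended gate, expressed in Python for reference (a sketch):

import os

# Log bootstrap progress only when the user asked for verbose NCCL output.
do_log = os.environ.get("NCCL_DEBUG") == "INFO"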
mlx.launch (Python launcher):

@@ -55,6 +55,11 @@ def parse_hardware_ports(ports_string):
     return ports
 
 
+def get_num_nvidia_gpus():
+    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
+    return len(result.stdout.strip().split("\n"))
+
+
 def extract_rings(hosts, index):
     def usable_port(i, j, used_ports):
         return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None

@@ -421,14 +426,16 @@ def launch_nccl(parser, hosts, args, command):
     master_host = hosts[0].ips[0]
 
     if master_host != "127.0.0.1":
-        raise ValueError("The NCCL backend only supports localhost for now. ")
+        raise ValueError("The NCCL backend only supports localhost for now.")
     master_port = args.nccl_port
     world_size = len(hosts)
 
     base_env = os.environ.copy()
     base_env.update(
         {
-            "NCCL_DEBUG": "INFO",
+            "NCCL_DEBUG": base_env.get(
+                "NCCL_DEBUG", "INFO" if args.verbose else "DEBUG"
+            ),
             "NCCL_SOCKET_IFNAME": "lo",  # Use loopback for local communication
             "NCCL_HOST_IP": master_host,
             "NCCL_PORT": str(master_port),

@@ -436,11 +443,18 @@ def launch_nccl(parser, hosts, args, command):
         }
     )
     procs = []
+    num_gpus = get_num_nvidia_gpus()
+    if num_gpus == 0:
+        raise RuntimeError("Cannot run NCCL backend with no GPUs.")
+    if args.repeat_hosts > num_gpus:
+        raise RuntimeError("NCCL requires a separate GPU per process.")
+
     try:
         for rank in range(world_size):
             env = base_env.copy()
-            env["MLX_RANK"] = str(rank % args.repeat_hosts)
-            env["CUDA_VISIBLE_DEVICES"] = str(rank % args.repeat_hosts)
+            mlx_rank = str(rank % args.repeat_hosts)
+            env["MLX_RANK"] = mlx_rank
+            env["CUDA_VISIBLE_DEVICES"] = mlx_rank
             p = Popen(command, env=env)
             procs.append(p)

@@ -821,8 +835,6 @@ def main():
     )
 
     args, rest = parser.parse_known_args()
-    if rest[0] == "--":
-        rest.pop(0)
 
     if args.print_python:
         print(sys.executable)
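
Taken together, the launcher changes give each local rank its own GPU by pinning CUDA_VISIBLE_DEVICES to the rank. The per-rank environment it builds looks roughly like this (a self-contained sketch grounded in the hunks above; train.py and the port value are hypothetical):

import os
from subprocess import Popen, run

def get_num_nvidia_gpus():
    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
    return len(result.stdout.strip().split("\n"))

world_size = get_num_nvidia_gpus()
base_env = dict(
    os.environ,
    NCCL_HOST_IP="127.0.0.1",          # NCCL bootstrap is localhost-only for now
    NCCL_PORT="12345",                 # hypothetical port
    MLX_WORLD_SIZE=str(world_size),
)
procs = []
for rank in range(world_size):
    # One process per GPU, each pinned to its device.
    env = dict(base_env, MLX_RANK=str(rank), CUDA_VISIBLE_DEVICES=str(rank))
    procs.append(Popen(["python", "train.py"], env=env))  # train.py: hypothetical
for p in procs:
    p.wait()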