mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Deleted comments, renamed the function
This commit is contained in:
parent
70f2baf39f
commit
e6ae350999
@ -75,12 +75,12 @@ inline void recvAll(int sock, void* buf, size_t len) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void bootstrapUniqueId(
|
inline void bootstrap_unique_id(
|
||||||
ncclUniqueId& id,
|
ncclUniqueId& id,
|
||||||
int rank,
|
int rank,
|
||||||
int size,
|
int size,
|
||||||
const std::string& initMethod) {
|
const std::string& initMethod) {
|
||||||
// Parse the init method to extract the host and port
|
|
||||||
if (initMethod.rfind("tcp://", 0) != 0)
|
if (initMethod.rfind("tcp://", 0) != 0)
|
||||||
throw;
|
throw;
|
||||||
auto hostport = initMethod.substr(6);
|
auto hostport = initMethod.substr(6);
|
||||||
@ -89,10 +89,8 @@ inline void bootstrapUniqueId(
|
|||||||
int port = std::stoi(hostport.substr(colon + 1));
|
int port = std::stoi(hostport.substr(colon + 1));
|
||||||
|
|
||||||
if (rank == 0) {
|
if (rank == 0) {
|
||||||
// create a unique id on the rank 0
|
|
||||||
CHECK_NCCL(ncclGetUniqueId(&id));
|
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||||
|
|
||||||
// create a socket to send the unique id to all other ranks
|
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
|
||||||
if (sock < 0) {
|
if (sock < 0) {
|
||||||
@ -107,8 +105,6 @@ inline void bootstrapUniqueId(
|
|||||||
serv.sin_port = htons(port);
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
int reuse = 1;
|
int reuse = 1;
|
||||||
// Without this, if I crash or restart your rank-0 process quickly,
|
|
||||||
// the OS might refuse to let you bind to the same port, so reuse
|
|
||||||
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||||
@ -236,7 +232,6 @@ void dispatch_dtype(const array& arr, F&& f) {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
// init communication in the constructor (?)
|
|
||||||
class NCCLGroup : public GroupImpl {
|
class NCCLGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
||||||
@ -334,6 +329,7 @@ class NCCLGroup : public GroupImpl {
|
|||||||
Stream stream,
|
Stream stream,
|
||||||
ncclDataType_t dt,
|
ncclDataType_t dt,
|
||||||
ncclRedOp_t op) {
|
ncclRedOp_t op) {
|
||||||
|
|
||||||
CHECK_NCCL(ncclAllReduce(
|
CHECK_NCCL(ncclAllReduce(
|
||||||
input.data<T>(),
|
input.data<T>(),
|
||||||
output.data<T>(),
|
output.data<T>(),
|
||||||
|
Loading…
Reference in New Issue
Block a user