mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00
Compare commits
7 Commits
ac76d1ab2d
...
9efabb380c
Author | SHA1 | Date | |
---|---|---|---|
![]() |
9efabb380c | ||
![]() |
76831ed83d | ||
![]() |
e6ae350999 | ||
![]() |
70f2baf39f | ||
![]() |
71a47bc10d | ||
![]() |
e9fbdd20fb | ||
![]() |
f15a127900 |
@ -16,6 +16,9 @@ parameters:
|
|||||||
linux_release:
|
linux_release:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
|
cuda_release:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_documentation:
|
build_documentation:
|
||||||
@ -104,7 +107,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
echo "stubs"
|
echo "stubs"
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@ -162,7 +165,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
@ -223,7 +226,6 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
|
||||||
python -m venv env
|
python -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
@ -283,7 +285,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
- run:
|
- run:
|
||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
@ -342,7 +344,7 @@ jobs:
|
|||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
pip install . -v
|
pip install . -v
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> \
|
<< parameters.extra_env >> \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
python -m build --wheel
|
python -m build --wheel
|
||||||
@ -356,6 +358,48 @@ jobs:
|
|||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
|
|
||||||
|
build_cuda_release:
|
||||||
|
parameters:
|
||||||
|
python_version:
|
||||||
|
type: string
|
||||||
|
default: "3.9"
|
||||||
|
extra_env:
|
||||||
|
type: string
|
||||||
|
default: "DEV_RELEASE=1"
|
||||||
|
machine:
|
||||||
|
image: linux-cuda-12:default
|
||||||
|
resource_class: gpu.nvidia.small.gen2
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Build wheel
|
||||||
|
command: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
python -m venv env
|
||||||
|
source env/bin/activate
|
||||||
|
pip install auditwheel
|
||||||
|
pip install patchelf
|
||||||
|
pip install build
|
||||||
|
pip install twine
|
||||||
|
<< parameters.extra_env >> \
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
pip install ".[dev]" -v
|
||||||
|
python setup.py generate_stubs
|
||||||
|
<< parameters.extra_env >> \
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
python -m build --wheel
|
||||||
|
bash python/scripts/repair_cuda.sh
|
||||||
|
- run:
|
||||||
|
name: Upload package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
twine upload wheelhouse/*.whl
|
||||||
|
- store_artifacts:
|
||||||
|
path: wheelhouse/
|
||||||
|
|
||||||
workflows:
|
workflows:
|
||||||
build_and_test:
|
build_and_test:
|
||||||
when:
|
when:
|
||||||
@ -625,3 +669,14 @@ workflows:
|
|||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
cuda_test_release:
|
||||||
|
when:
|
||||||
|
and:
|
||||||
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
|
- << pipeline.parameters.cuda_release >>
|
||||||
|
jobs:
|
||||||
|
- build_cuda_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
extra_env: ["PYPI_RELEASE=1"]
|
||||||
|
64
cmake/FindNCCL.cmake
Normal file
64
cmake/FindNCCL.cmake
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Find the nccl libraries
|
||||||
|
#
|
||||||
|
# The following variables are optionally searched for defaults NCCL_ROOT_DIR:
|
||||||
|
# Base directory where all NCCL components are found NCCL_INCLUDE_DIR: Directory
|
||||||
|
# where NCCL header is found NCCL_LIB_DIR: Directory where NCCL library is found
|
||||||
|
#
|
||||||
|
# The following are set after configuration is done: NCCL_FOUND
|
||||||
|
# NCCL_INCLUDE_DIRS NCCL_LIBRARIES
|
||||||
|
#
|
||||||
|
# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks install NCCL
|
||||||
|
# in the same location as the CUDA toolkit. See
|
||||||
|
# https://github.com/caffe2/caffe2/issues/1601
|
||||||
|
|
||||||
|
set(NCCL_ROOT_DIR
|
||||||
|
$ENV{NCCL_ROOT_DIR}
|
||||||
|
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||||
|
|
||||||
|
find_path(
|
||||||
|
NCCL_INCLUDE_DIRS
|
||||||
|
NAMES nccl.h
|
||||||
|
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||||
|
|
||||||
|
if($ENV{USE_STATIC_NCCL})
|
||||||
|
message(
|
||||||
|
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||||
|
set(NCCL_LIBNAME "libnccl_static.a")
|
||||||
|
else()
|
||||||
|
set(NCCL_LIBNAME "nccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
NCCL_LIBRARIES
|
||||||
|
NAMES ${NCCL_LIBNAME}
|
||||||
|
HINTS ${NCCL_LIB_DIR}
|
||||||
|
${NCCL_ROOT_DIR}
|
||||||
|
${NCCL_ROOT_DIR}/lib
|
||||||
|
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||||
|
${NCCL_ROOT_DIR}/lib64
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||||
|
NCCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(NCCL_FOUND)
|
||||||
|
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||||
|
message(
|
||||||
|
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||||
|
file(
|
||||||
|
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||||
|
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||||
|
LIMIT_COUNT 1)
|
||||||
|
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||||
|
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||||
|
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||||
|
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||||
|
endif()
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
endif()
|
@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
|
|||||||
|
|
||||||
conda install conda-forge::mlx
|
conda install conda-forge::mlx
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||||
|
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx-cuda
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
@ -65,6 +75,8 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
.. _python install:
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@ -107,6 +119,8 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
.. _cpp install:
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@ -185,6 +199,7 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@ -213,6 +228,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
|||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists across reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
|
Linux
|
||||||
|
^^^^^
|
||||||
|
|
||||||
|
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||||
|
For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
apt-get update -y
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
From here follow the instructions to install either the :ref:`Python <python
|
||||||
|
install>` or :ref:`C++ <cpp install>` APIs.
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||||
|
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
apt-get update -y
|
||||||
|
apt-get -y install cuda-toolkit-12-9
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
|
||||||
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
CMAKE_BUILD_PARALLEL_LEVEL=8 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
|
To build the C++ package run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ void CommandEncoder::synchronize() {
|
|||||||
std::future<void> f = p->get_future();
|
std::future<void> f = p->get_future();
|
||||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||||
worker_.end_batch();
|
worker_.end_batch();
|
||||||
worker_.commit();
|
commit();
|
||||||
f.wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +54,28 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace distributed {
|
||||||
|
void AllReduce::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
// Here I assume for now that in is donatable and contiguous.
|
||||||
|
// TODO
|
||||||
|
|
||||||
|
auto& input = inputs[0];
|
||||||
|
auto& output = outputs[0];
|
||||||
|
|
||||||
|
output.copy_shared_buffer(input);
|
||||||
|
auto& s = stream();
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Sum:
|
||||||
|
distributed::detail::all_sum(group(), input, output, s);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace distributed
|
||||||
|
|
||||||
#define NO_GPU_MULTI(func) \
|
#define NO_GPU_MULTI(func) \
|
||||||
void func::eval_gpu( \
|
void func::eval_gpu( \
|
||||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||||
@ -97,7 +119,6 @@ NO_GPU_MULTI(CustomKernel)
|
|||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
NO_GPU_MULTI(AllReduce)
|
|
||||||
NO_GPU_MULTI(AllGather)
|
NO_GPU_MULTI(AllGather)
|
||||||
NO_GPU_MULTI(Send)
|
NO_GPU_MULTI(Send)
|
||||||
NO_GPU_MULTI(Recv)
|
NO_GPU_MULTI(Recv)
|
||||||
|
@ -6,3 +6,4 @@ target_sources(
|
|||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
@ -111,6 +112,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = mpi::init(strict);
|
group = mpi::init(strict);
|
||||||
} else if (bk == "ring") {
|
} else if (bk == "ring") {
|
||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
|
} else if (bk == "nccl") {
|
||||||
|
group = nccl::init(strict);
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
group = ring::init(false);
|
group = ring::init(false);
|
||||||
bk_ = "ring";
|
bk_ = "ring";
|
||||||
|
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
8
mlx/distributed/nccl/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
|
||||||
|
find_package(NCCL REQUIRED)
|
||||||
|
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
|
||||||
|
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
|
||||||
|
else()
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
|
||||||
|
endif()
|
382
mlx/distributed/nccl/nccl.cpp
Normal file
382
mlx/distributed/nccl/nccl.cpp
Normal file
@ -0,0 +1,382 @@
|
|||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
#define CHECK_CUDA(cmd) \
|
||||||
|
do { \
|
||||||
|
cudaError_t e = cmd; \
|
||||||
|
if (e != cudaSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"CUDA error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
cudaGetErrorString(e)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define CHECK_NCCL(cmd) \
|
||||||
|
do { \
|
||||||
|
ncclResult_t r = cmd; \
|
||||||
|
if (r != ncclSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"NCCL error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
ncclGetErrorString(r)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||||
|
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t sent = send(sock, ptr, len, 0);
|
||||||
|
if (sent <= 0) {
|
||||||
|
perror("send");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += sent;
|
||||||
|
len -= sent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void recvAll(int sock, void* buf, size_t len) {
|
||||||
|
char* ptr = reinterpret_cast<char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t rec = recv(sock, ptr, len, 0);
|
||||||
|
if (rec <= 0) {
|
||||||
|
perror("recv");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += rec;
|
||||||
|
len -= rec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void bootstrap_unique_id(
|
||||||
|
ncclUniqueId& id,
|
||||||
|
int rank,
|
||||||
|
int size,
|
||||||
|
const std::string& initMethod) {
|
||||||
|
|
||||||
|
if (initMethod.rfind("tcp://", 0) != 0)
|
||||||
|
throw;
|
||||||
|
auto hostport = initMethod.substr(6);
|
||||||
|
auto colon = hostport.find(':');
|
||||||
|
std::string host = hostport.substr(0, colon);
|
||||||
|
int port = std::stoi(hostport.substr(colon + 1));
|
||||||
|
|
||||||
|
if (rank == 0) {
|
||||||
|
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||||
|
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
int reuse = 1;
|
||||||
|
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] bind() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
if (listen(sock, size - 1) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] listen() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int peer = 1; peer < size; ++peer) {
|
||||||
|
int conn = accept(sock, nullptr, nullptr);
|
||||||
|
if (conn < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] accept() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
sendAll(conn, &id, sizeof(id));
|
||||||
|
close(conn);
|
||||||
|
}
|
||||||
|
close(sock);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Here just wanted to make show that rank 0 has enough time to bind
|
||||||
|
// so we will retry to connect until max attempts
|
||||||
|
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] socket() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
hostent* he = gethostbyname(host.c_str());
|
||||||
|
if (!he) {
|
||||||
|
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
||||||
|
}
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
const int max_retries = 30;
|
||||||
|
int attempt = 0;
|
||||||
|
bool connected = false;
|
||||||
|
|
||||||
|
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||||
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
|
0) {
|
||||||
|
connected = true;
|
||||||
|
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||||
|
<< attempt + 1 << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (errno != ECONNREFUSED) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!connected) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||||
|
<< " retries: " << strerror(errno);
|
||||||
|
close(sock);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
recvAll(sock, &id, sizeof(id));
|
||||||
|
close(sock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct type_identity {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_dtype(const array& arr, F&& f) {
|
||||||
|
switch (arr.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
||||||
|
case int8:
|
||||||
|
f(type_identity<int8_t>{}, ncclChar);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
f(type_identity<uint8_t>{}, ncclUint8);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
f(type_identity<int32_t>{}, ncclInt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
f(type_identity<uint32_t>{}, ncclUint32);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
f(type_identity<int64_t>{}, ncclInt64);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
f(type_identity<uint64_t>{}, ncclUint64);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
f(type_identity<float16_t>{}, ncclHalf);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
f(type_identity<float>{}, ncclFloat);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
f(type_identity<double>{}, ncclDouble);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
class NCCLGroup : public GroupImpl {
|
||||||
|
public:
|
||||||
|
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
||||||
|
: rank_(worldRank),
|
||||||
|
size_(worldSize),
|
||||||
|
comm_(nullptr),
|
||||||
|
initMethod_(initMethod) {
|
||||||
|
if (initialized_)
|
||||||
|
return;
|
||||||
|
int ndev;
|
||||||
|
CHECK_CUDA(cudaGetDeviceCount(&ndev));
|
||||||
|
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
|
||||||
|
CHECK_CUDA(cudaStreamCreate(&stream_));
|
||||||
|
|
||||||
|
detail::bootstrapUniqueId(uniqueId_, rank_, size_, initMethod_);
|
||||||
|
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
|
||||||
|
initialized_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
~NCCLGroup() {
|
||||||
|
ncclCommDestroy(comm_);
|
||||||
|
ncclGroupEnd();
|
||||||
|
cudaStreamDestroy(stream_);
|
||||||
|
initialized_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank() override {
|
||||||
|
return rank_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() override {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
|
if (input.size() != output.size()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[nccl] Input and output arrays must have the same size.");
|
||||||
|
}
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
|
throw std::runtime_error("[nccl] Group split not supported.");
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_gather(const array& input, array& output, Stream stream) override {
|
||||||
|
if (input.size() != output.size() / size_) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[nccl] Input size must match output size divided by group size.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void send(const array& input, int dst, Stream stream) override {
|
||||||
|
if (input.size() == 0) {
|
||||||
|
return; // Nothing to send
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv(array& output, int src, Stream stream) override {
|
||||||
|
if (output.size() == 0) {
|
||||||
|
return; // Nothing to receive
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
|
if (input.size() != output.size()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[nccl] Input and output arrays must have the same size.");
|
||||||
|
}
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
|
if (input.size() != output.size()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[nccl] Input and output arrays must have the same size.");
|
||||||
|
}
|
||||||
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void all_reduce_impl(
|
||||||
|
const array& input,
|
||||||
|
array& output,
|
||||||
|
Stream stream,
|
||||||
|
ncclDataType_t dt,
|
||||||
|
ncclRedOp_t op) {
|
||||||
|
|
||||||
|
CHECK_NCCL(ncclAllReduce(
|
||||||
|
input.data<T>(),
|
||||||
|
output.data<T>(),
|
||||||
|
input.size(),
|
||||||
|
dt,
|
||||||
|
op,
|
||||||
|
comm_,
|
||||||
|
stream_));
|
||||||
|
cudaStreamSynchronize(stream_);
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank_, size_;
|
||||||
|
std::string initMethod_;
|
||||||
|
ncclUniqueId uniqueId_;
|
||||||
|
ncclComm_t comm_;
|
||||||
|
cudaStream_t stream_;
|
||||||
|
bool initialized_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
static std::string get_env_var_or_throw(const char* env_var_name) {
|
||||||
|
const char* value = std::getenv(env_var_name);
|
||||||
|
if (value == nullptr) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Required environment variable '" << env_var_name
|
||||||
|
<< "' is not set. "
|
||||||
|
<< "Please set it before initializing the distributed backend.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
return std::string(value);
|
||||||
|
}
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
||||||
|
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
||||||
|
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
||||||
|
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
||||||
|
|
||||||
|
int rank = std::stoi(rank_str);
|
||||||
|
int n_nodes = std::stoi(n_nodes_str);
|
||||||
|
std::string init_method = "tcp://" + host + ":" + port;
|
||||||
|
|
||||||
|
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
|
||||||
|
}
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
12
mlx/distributed/nccl/nccl.h
Normal file
12
mlx/distributed/nccl/nccl.h
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available();
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict = false);
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
20
mlx/distributed/nccl/no_nccl.cpp
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/distributed/nccl/nccl.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
|
if (strict) {
|
||||||
|
throw std::runtime_error("Cannot initialize nccl distributed backend.");
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::nccl
|
@ -31,8 +31,7 @@ array all_sum(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(
|
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
||||||
to_stream(s, Device::cpu), group, AllReduce::Sum),
|
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
17
python/scripts/repair_cuda.sh
Normal file
17
python/scripts/repair_cuda.sh
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
auditwheel repair dist/* \
|
||||||
|
--plat manylinux_2_35_x86_64 \
|
||||||
|
--exclude libcublas* \
|
||||||
|
--exclude libnvrtc*
|
||||||
|
|
||||||
|
cd wheelhouse
|
||||||
|
repaired_wheel=$(find . -name "*.whl" -print -quit)
|
||||||
|
unzip -q "${repaired_wheel}"
|
||||||
|
core_so=$(find mlx -name "core*.so" -print -quit)
|
||||||
|
rpath=$(patchelf --print-rpath "${core_so}")
|
||||||
|
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
|
||||||
|
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
|
||||||
|
|
||||||
|
# Re-zip the repaired wheel
|
||||||
|
zip -r -q "${repaired_wheel}" .
|
8
setup.py
8
setup.py
@ -174,20 +174,26 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
package_dir = {"": "python"}
|
package_dir = {"": "python"}
|
||||||
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
|
||||||
|
install_requires = []
|
||||||
|
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||||
|
if build_cuda:
|
||||||
|
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx",
|
name="mlx-cuda" if build_cuda else "mlx",
|
||||||
version=get_version(),
|
version=get_version(),
|
||||||
author="MLX Contributors",
|
author="MLX Contributors",
|
||||||
author_email="mlx@group.apple.com",
|
author_email="mlx@group.apple.com",
|
||||||
description="A framework for machine learning on Apple silicon.",
|
description="A framework for machine learning on Apple silicon.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
license="MIT",
|
||||||
url="https://github.com/ml-explore/mlx",
|
url="https://github.com/ml-explore/mlx",
|
||||||
packages=packages,
|
packages=packages,
|
||||||
package_dir=package_dir,
|
package_dir=package_dir,
|
||||||
package_data=package_data,
|
package_data=package_data,
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
|
install_requires=install_requires,
|
||||||
extras_require={
|
extras_require={
|
||||||
"dev": [
|
"dev": [
|
||||||
"nanobind==2.4.0",
|
"nanobind==2.4.0",
|
||||||
|
Loading…
Reference in New Issue
Block a user