Compare commits

..

1 Commits

Author SHA1 Message Date
Angelos Katharopoulos
a9c720e8cd Improve the ring backend initialization 2025-07-11 15:31:28 -07:00
7 changed files with 194 additions and 209 deletions

View File

@@ -7,6 +7,18 @@ parameters:
nightly_build: nightly_build:
type: boolean type: boolean
default: false default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
cuda_release:
type: boolean
default: false
jobs: jobs:
build_documentation: build_documentation:
@@ -289,9 +301,9 @@ jobs:
python_version: python_version:
type: string type: string
default: "3.9" default: "3.9"
build_env: extra_env:
type: string type: string
default: "" default: "DEV_RELEASE=1"
docker: docker:
- image: ubuntu:20.04 - image: ubuntu:20.04
steps: steps:
@@ -320,22 +332,17 @@ jobs:
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
<< parameters.build_env >> pip install . -v << parameters.extra_env >> pip install . -v
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
<< parameters.build_env >> python -m build --wheel << parameters.extra_env >> python -m build --wheel
auditwheel show dist/* auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64 auditwheel repair dist/* --plat manylinux_2_31_x86_64
<< parameters.build_env >> MLX_BUILD_COMMON=1 \ - run:
python -m build --wheel --outdir wheelhouse name: Upload package
- when: command: |
condition: << parameters.build_env >> source env/bin/activate
steps: twine upload wheelhouse/*
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@@ -344,9 +351,9 @@ jobs:
python_version: python_version:
type: string type: string
default: "3.9" default: "3.9"
build_env: extra_env:
type: string type: string
default: "" default: "DEV_RELEASE=1"
machine: machine:
image: linux-cuda-12:default image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
@@ -357,29 +364,25 @@ jobs:
command: | command: |
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
python -m venv env python -m venv env
source env/bin/activate source env/bin/activate
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
<< parameters.build_env >> \ << parameters.extra_env >> \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install ".[dev]" -v pip install ".[dev]" -v
python setup.py generate_stubs python setup.py generate_stubs
<< parameters.build_env >> \ << parameters.extra_env >> \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel python -m build --wheel
bash python/scripts/repair_cuda.sh bash python/scripts/repair_cuda.sh
- when: - run:
condition: << parameters.build_env >> name: Upload package
steps: command: |
- run: source env/bin/activate
name: Upload package twine upload wheelhouse/*.whl
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@@ -391,6 +394,8 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$" pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >> value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
@@ -404,6 +409,8 @@ workflows:
when: when:
and: and:
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- build_release: - build_release:
filters: filters:
@@ -494,17 +501,7 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"] extra_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
prb: prb:
when: when:
@@ -583,21 +580,99 @@ workflows:
- macosx_deployment_target: "15.0" - macosx_deployment_target: "15.0"
xcode_version: "15.0.0" xcode_version: "15.0.0"
python_version: "3.13" python_version: "3.13"
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release: - build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release: - build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -38,16 +38,8 @@ and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell .. code-block:: shell
pip install "mlx[cuda]" pip install mlx-cuda
CPU only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install "mlx[cpu]"
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@@ -22,78 +22,20 @@
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/dtype_utils.h"
#include "mlx/threadpool.h" #include "mlx/threadpool.h"
#ifndef SOL_TCP #ifndef SOL_TCP
#define SOL_TCP IPPROTO_TCP #define SOL_TCP IPPROTO_TCP
#endif #endif
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
using T = bool; \
__VA_ARGS__; \
} break; \
case int8: { \
using T = int8_t; \
__VA_ARGS__; \
} break; \
case int16: { \
using T = int16_t; \
__VA_ARGS__; \
} break; \
case int32: { \
using T = int32_t; \
__VA_ARGS__; \
} break; \
case int64: { \
using T = int64_t; \
__VA_ARGS__; \
} break; \
case uint8: { \
using T = uint8_t; \
__VA_ARGS__; \
} break; \
case uint16: { \
using T = uint16_t; \
__VA_ARGS__; \
} break; \
case uint32: { \
using T = uint32_t; \
__VA_ARGS__; \
} break; \
case uint64: { \
using T = uint64_t; \
__VA_ARGS__; \
} break; \
case bfloat16: { \
using T = bfloat16_t; \
__VA_ARGS__; \
} break; \
case float16: { \
using T = float16_t; \
__VA_ARGS__; \
} break; \
case float32: { \
using T = float; \
__VA_ARGS__; \
} break; \
case float64: { \
using T = double; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \
} break; \
}
namespace mlx::core::distributed::ring { namespace mlx::core::distributed::ring {
constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024; constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
constexpr const size_t ALL_SUM_BUFFERS = 2; constexpr const size_t ALL_SUM_BUFFERS = 2;
constexpr const int CONN_ATTEMPTS = 5; constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000; constexpr const int CONN_WAIT = 1000;
constexpr const int INIT_TIMEOUT = 20000;
using GroupImpl = mlx::core::distributed::detail::GroupImpl; using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json; using json = nlohmann::json;
@@ -503,6 +445,7 @@ std::vector<int> make_connections(
return sockets; return sockets;
} }
template <typename T> template <typename T>
struct SumOp { struct SumOp {
void operator()(const T* input, T* output, size_t N) { void operator()(const T* input, T* output, size_t N) {
@@ -550,19 +493,27 @@ class RingGroup : public GroupImpl {
size_ = nodes.size(); size_ = nodes.size();
int connect_to = (rank_ + 1) % size_; int connect_to = (rank_ + 1) % size_;
// We define the connection order by having the rank_ == size_ - 1 connect // Initialize the ring by making all the connections
// first and accept after. log_info(verbose_, "Rank", rank_, "accepting");
if (rank_ < connect_to) { log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
log_info(verbose_, "Rank", rank_, "accepting"); auto sl = std::async(std::launch::async, accept_connections, nodes[rank_]);
sockets_left_ = std::move(accept_connections(nodes[rank_])); auto sr = std::async(
log_info(verbose_, "Rank", rank_, "connecting to", connect_to); std::launch::async, make_connections, nodes[connect_to], verbose);
sockets_right_ = std::move(make_connections(nodes[connect_to], verbose)); std::future_status status_sl, status_sr;
} else { for (int i = 0; i < 10; i++) {
log_info(verbose_, "Rank", rank_, "connecting to", connect_to); status_sl = sl.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
sockets_right_ = std::move(make_connections(nodes[connect_to], verbose)); status_sr = sl.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
log_info(verbose_, "Rank", rank_, "accepting"); if (status_sl == std::future_status::ready &&
sockets_left_ = std::move(accept_connections(nodes[rank_])); status_sr == std::future_status::ready) {
break;
}
} }
if (status_sl != std::future_status::ready ||
status_sr != std::future_status::ready) {
throw std::runtime_error("[ring] Ring initialization timed out");
}
sockets_left_ = std::move(sl.get());
sockets_right_ = std::move(sr.get());
// Failure if we couldn't make right or left sockets // Failure if we couldn't make right or left sockets
if (sockets_right_.empty()) { if (sockets_right_.empty()) {
@@ -628,18 +579,24 @@ class RingGroup : public GroupImpl {
} }
void all_sum(const array& input, array& output, Stream stream) override { void all_sum(const array& input, array& output, Stream stream) override {
SWITCH_TYPE( dispatch_all_types(output.dtype(), [&](auto type_tag) {
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>())); using T = MLX_GET_TYPE(type_tag);
all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>());
});
} }
void all_max(const array& input, array& output, Stream stream) override { void all_max(const array& input, array& output, Stream stream) override {
SWITCH_TYPE( dispatch_all_types(output.dtype(), [&](auto type_tag) {
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>())); using T = MLX_GET_TYPE(type_tag);
all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>());
});
} }
void all_min(const array& input, array& output, Stream stream) override { void all_min(const array& input, array& output, Stream stream) override {
SWITCH_TYPE( dispatch_all_types(output.dtype(), [&](auto type_tag) {
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>())); using T = MLX_GET_TYPE(type_tag);
all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>());
});
} }
std::shared_ptr<GroupImpl> split(int color, int key = -1) override { std::shared_ptr<GroupImpl> split(int color, int key = -1) override {

View File

@@ -3,19 +3,15 @@
auditwheel repair dist/* \ auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \ --plat manylinux_2_35_x86_64 \
--exclude libcublas* \ --exclude libcublas* \
--exclude libnvrtc* \ --exclude libnvrtc*
-w wheel_tmp
cd wheelhouse
mkdir wheelhouse
cd wheel_tmp
repaired_wheel=$(find . -name "*.whl" -print -quit) repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}" unzip -q "${repaired_wheel}"
rm "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit) core_so=$(find mlx -name "core*.so" -print -quit)
rpath=$(patchelf --print-rpath "${core_so}") rpath=$(patchelf --print-rpath "${core_so}")
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$core_so" patchelf --force-rpath --set-rpath "$rpath" "$core_so"
# Re-zip the repaired wheel # Re-zip the repaired wheel
zip -r -q "../wheelhouse/${repaired_wheel}" . zip -r -q "${repaired_wheel}" .

View File

@@ -4,6 +4,7 @@ import unittest
import mlx.core as mx import mlx.core as mx
import mlx_distributed_tests import mlx_distributed_tests
import mlx_tests
class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@@ -150,4 +151,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() mlx_tests.MLXTestRunner()

View File

@@ -4,6 +4,7 @@ import unittest
import mlx.core as mx import mlx.core as mx
import mlx_distributed_tests import mlx_distributed_tests
import mlx_tests
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):

View File

@@ -5,11 +5,10 @@ import os
import platform import platform
import re import re
import subprocess import subprocess
from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run from subprocess import run
from setuptools import Command, Extension, setup from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
@@ -166,28 +165,19 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
# The information here can also be placed in setup.cfg - better separation of # The information here can also be placed in setup.cfg - better separation of
# logic and declaration, and simpler if you include description/version in a file. # logic and declaration, and simpler if you include description/version in a file.
if __name__ == "__main__": if __name__ == "__main__":
packages = find_namespace_packages(
where="python", exclude=["src", "tests", "tests.*"]
)
package_dir = {"": "python"} package_dir = {"": "python"}
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
packages = [
"mlx",
"mlx.nn",
"mlx.nn.layers",
"mlx.optimizers",
]
is_release = "PYPI_RELEASE" in os.environ
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
build_common = "MLX_BUILD_COMMON" in os.environ
install_requires = [] install_requires = []
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
if build_cuda: if build_cuda:
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"] install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
version = get_version()
_setup = partial( setup(
setup, name="mlx-cuda" if build_cuda else "mlx",
version=version, version=get_version(),
author="MLX Contributors", author="MLX Contributors",
author_email="mlx@group.apple.com", author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.", description="A framework for machine learning on Apple silicon.",
@@ -195,56 +185,29 @@ if __name__ == "__main__":
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
license="MIT", license="MIT",
url="https://github.com/ml-explore/mlx", url="https://github.com/ml-explore/mlx",
packages=packages,
package_dir=package_dir, package_dir=package_dir,
package_data=package_data, package_data=package_data,
include_package_data=True,
install_requires=install_requires,
extras_require={
"dev": [
"nanobind==2.4.0",
"numpy",
"pre-commit",
"setuptools>=42",
"torch",
"typing_extensions",
],
},
entry_points={
"console_scripts": [
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
]
},
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
zip_safe=False, zip_safe=False,
python_requires=">=3.9", python_requires=">=3.9",
install_requires=install_requires,
) )
extras = {
"dev": [
"nanobind==2.4.0",
"numpy",
"pre-commit",
"setuptools>=42",
"torch",
"typing_extensions",
],
}
entry_points = {
"console_scripts": [
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
]
}
if not is_release or build_macos:
_setup(
name="mlx",
include_package_data=True,
packages=packages,
extras_require=extras,
entry_points=entry_points,
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
)
elif build_common:
extras["cpu"] = [f"mlx-cpu=={version}"]
extras["cuda"] = [f"mlx-cuda=={version}"]
_setup(
name="mlx",
packages=["mlx"],
extras_require=extras,
entry_points=entry_points,
exclude_package_data=package_data,
)
else:
_setup(
name="mlx-cuda" if build_cuda else "mlx-cpu",
include_package_data=True,
packages=packages,
extras_require=extras,
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
)