Compare commits


1 Commit

Author: Angelos Katharopoulos
SHA1: a9c720e8cd
Message: Improve the ring backend initialization
Date: 2025-07-11 15:31:28 -07:00
6 changed files with 69 additions and 156 deletions

View File

@@ -336,11 +336,10 @@ jobs:
             pip install typing_extensions
             python setup.py generate_stubs
             << parameters.extra_env >> python -m build --wheel
-            << parameters.extra_env >> MLX_BUILD_COMMON=1 python -m build --wheel
             auditwheel show dist/*
             auditwheel repair dist/* --plat manylinux_2_31_x86_64
       - run:
-          name: Upload packages
+          name: Upload package
           command: |
             source env/bin/activate
             twine upload wheelhouse/*

View File

@@ -38,16 +38,8 @@ and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:

 .. code-block:: shell

-    pip install "mlx[cuda]"
+    pip install mlx-cuda
-
-CPU only (Linux)
-^^^^^^^^^^^^^^^^
-
-For a CPU-only version of MLX that runs on Linux use:
-
-.. code-block:: shell
-
-    pip install "mlx[cpu]"

 Troubleshooting
 ^^^^^^^^^^^^^^^

View File

@@ -22,78 +22,20 @@
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/dtype_utils.h"
 #include "mlx/threadpool.h"

 #ifndef SOL_TCP
 #define SOL_TCP IPPROTO_TCP
 #endif

-#define SWITCH_TYPE(x, ...) \
-  switch ((x).dtype()) { \
-    case bool_: { \
-      using T = bool; \
-      __VA_ARGS__; \
-    } break; \
-    case int8: { \
-      using T = int8_t; \
-      __VA_ARGS__; \
-    } break; \
-    case int16: { \
-      using T = int16_t; \
-      __VA_ARGS__; \
-    } break; \
-    case int32: { \
-      using T = int32_t; \
-      __VA_ARGS__; \
-    } break; \
-    case int64: { \
-      using T = int64_t; \
-      __VA_ARGS__; \
-    } break; \
-    case uint8: { \
-      using T = uint8_t; \
-      __VA_ARGS__; \
-    } break; \
-    case uint16: { \
-      using T = uint16_t; \
-      __VA_ARGS__; \
-    } break; \
-    case uint32: { \
-      using T = uint32_t; \
-      __VA_ARGS__; \
-    } break; \
-    case uint64: { \
-      using T = uint64_t; \
-      __VA_ARGS__; \
-    } break; \
-    case bfloat16: { \
-      using T = bfloat16_t; \
-      __VA_ARGS__; \
-    } break; \
-    case float16: { \
-      using T = float16_t; \
-      __VA_ARGS__; \
-    } break; \
-    case float32: { \
-      using T = float; \
-      __VA_ARGS__; \
-    } break; \
-    case float64: { \
-      using T = double; \
-      __VA_ARGS__; \
-    } break; \
-    case complex64: { \
-      using T = complex64_t; \
-      __VA_ARGS__; \
-    } break; \
-  }

 namespace mlx::core::distributed::ring {

 constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
 constexpr const size_t ALL_SUM_BUFFERS = 2;
 constexpr const int CONN_ATTEMPTS = 5;
 constexpr const int CONN_WAIT = 1000;
+constexpr const int INIT_TIMEOUT = 20000;

 using GroupImpl = mlx::core::distributed::detail::GroupImpl;
 using json = nlohmann::json;
@@ -503,6 +445,7 @@ std::vector<int> make_connections(
   return sockets;
 }

 template <typename T>
 struct SumOp {
   void operator()(const T* input, T* output, size_t N) {
@@ -550,19 +493,27 @@ class RingGroup : public GroupImpl {
     size_ = nodes.size();

     int connect_to = (rank_ + 1) % size_;

-    // We define the connection order by having the rank_ == size_ - 1 connect
-    // first and accept after.
-    if (rank_ < connect_to) {
-      log_info(verbose_, "Rank", rank_, "accepting");
-      sockets_left_ = std::move(accept_connections(nodes[rank_]));
-      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
-    } else {
-      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
-      sockets_right_ = std::move(make_connections(nodes[connect_to], verbose));
-      log_info(verbose_, "Rank", rank_, "accepting");
-      sockets_left_ = std::move(accept_connections(nodes[rank_]));
-    }
+    // Initialize the ring by making all the connections
+    log_info(verbose_, "Rank", rank_, "accepting");
+    log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
+    auto sl = std::async(std::launch::async, accept_connections, nodes[rank_]);
+    auto sr = std::async(
+        std::launch::async, make_connections, nodes[connect_to], verbose);
+    std::future_status status_sl, status_sr;
+    for (int i = 0; i < 10; i++) {
+      status_sl = sl.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
+      status_sr = sr.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
+      if (status_sl == std::future_status::ready &&
+          status_sr == std::future_status::ready) {
+        break;
+      }
+    }
+    if (status_sl != std::future_status::ready ||
+        status_sr != std::future_status::ready) {
+      throw std::runtime_error("[ring] Ring initialization timed out");
+    }
+    sockets_left_ = std::move(sl.get());
+    sockets_right_ = std::move(sr.get());

     // Failure if we couldn't make right or left sockets
     if (sockets_right_.empty()) {
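The hunk above replaces the rank-ordered accept/connect handshake with two concurrent operations polled under a deadline, so a missing or stuck peer fails initialization with an error instead of hanging the process in a blocking accept(). Below is a minimal standalone sketch of that pattern, not the MLX sources: accept_connections and make_connections are stand-in stubs, and INIT_TIMEOUT mirrors the constant added in this commit.

// Sketch: overlap accept and connect, bound the wait with a timeout.
#include <chrono>
#include <future>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <vector>

constexpr int INIT_TIMEOUT = 20000; // per-future budget in milliseconds

std::vector<int> accept_connections() {
  // Stub for the blocking accept() loop on this rank's listening socket.
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
  return {3, 4};
}

std::vector<int> make_connections() {
  // Stub for the blocking connect() loop to the next rank in the ring.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return {5, 6};
}

int main() {
  // Run both directions at once so no rank sits in accept() waiting for a
  // peer that is itself blocked in accept().
  auto sl = std::async(std::launch::async, accept_connections);
  auto sr = std::async(std::launch::async, make_connections);

  // Poll the two futures in slices; each gets at most INIT_TIMEOUT in total,
  // so a dead peer becomes an error instead of a hang.
  std::future_status status_sl{}, status_sr{};
  for (int i = 0; i < 10; i++) {
    status_sl = sl.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
    status_sr = sr.wait_for(std::chrono::milliseconds(INIT_TIMEOUT / 10));
    if (status_sl == std::future_status::ready &&
        status_sr == std::future_status::ready) {
      break;
    }
  }
  if (status_sl != std::future_status::ready ||
      status_sr != std::future_status::ready) {
    throw std::runtime_error("[ring] Ring initialization timed out");
  }

  auto sockets_left = sl.get();
  auto sockets_right = sr.get();
  std::cout << "left=" << sockets_left.size()
            << " right=" << sockets_right.size() << "\n";
  return 0;
}

Since wait_for returns immediately once its future is ready, a rank whose connect side finishes early keeps polling its accept side; the worst case is roughly twice INIT_TIMEOUT because each iteration waits on both futures in turn.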
@@ -628,18 +579,24 @@ class RingGroup : public GroupImpl {
   }

   void all_sum(const array& input, array& output, Stream stream) override {
-    SWITCH_TYPE(
-        output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
+    dispatch_all_types(output.dtype(), [&](auto type_tag) {
+      using T = MLX_GET_TYPE(type_tag);
+      all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>());
+    });
   }

   void all_max(const array& input, array& output, Stream stream) override {
-    SWITCH_TYPE(
-        output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
+    dispatch_all_types(output.dtype(), [&](auto type_tag) {
+      using T = MLX_GET_TYPE(type_tag);
+      all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>());
+    });
   }

   void all_min(const array& input, array& output, Stream stream) override {
-    SWITCH_TYPE(
-        output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
+    dispatch_all_types(output.dtype(), [&](auto type_tag) {
+      using T = MLX_GET_TYPE(type_tag);
+      all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>());
+    });
   }

   std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
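These hunks swap the hand-rolled SWITCH_TYPE macro for dispatch_all_types from mlx/dtype_utils.h (included in the first hunk): the helper maps the runtime dtype to a tag object and invokes a generic lambda, which recovers the static type via MLX_GET_TYPE. A reduced, self-contained sketch of that pattern follows; Dtype, TypeTag, and the three-type switch are illustrative stand-ins for the real machinery, which covers every MLX dtype.

// Sketch: tag-based dtype dispatch in place of a switch macro.
#include <cstddef>
#include <cstdint>
#include <iostream>

enum class Dtype { int32, float32, float64 };

template <typename T>
struct TypeTag {
  using type = T;
};

// Map the runtime dtype to a tag and invoke the generic callable with it.
template <typename F>
void dispatch_all_types(Dtype dt, F&& f) {
  switch (dt) {
    case Dtype::int32:
      f(TypeTag<int32_t>{});
      break;
    case Dtype::float32:
      f(TypeTag<float>{});
      break;
    case Dtype::float64:
      f(TypeTag<double>{});
      break;
  }
}

template <typename T>
struct SumOp {
  void operator()(const T* input, T* output, size_t N) {
    while (N-- > 0) {
      *output++ += *input++;
    }
  }
};

int main() {
  float in[3] = {1.0f, 2.0f, 3.0f};
  float out[3] = {};
  dispatch_all_types(Dtype::float32, [&](auto type_tag) {
    // Recover the static type from the tag; MLX_GET_TYPE expands to
    // essentially typename decltype(type_tag)::type.
    using T = typename decltype(type_tag)::type;
    SumOp<T>{}(reinterpret_cast<const T*>(in), reinterpret_cast<T*>(out), 3);
  });
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n"; // 1 2 3
  return 0;
}

The gain over the macro is that the dtype-to-type mapping lives in one shared helper, and each call site is an ordinary generic lambda that the compiler type-checks for every dtype.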

View File

@@ -4,6 +4,7 @@ import unittest

 import mlx.core as mx
 import mlx_distributed_tests
+import mlx_tests


 class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@@ -150,4 +151,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):

 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

View File

@@ -4,6 +4,7 @@ import unittest

 import mlx.core as mx
 import mlx_distributed_tests
+import mlx_tests


 class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):

View File

@@ -5,11 +5,10 @@ import os
 import platform
 import re
 import subprocess
-from functools import partial
 from pathlib import Path
 from subprocess import run

-from setuptools import Command, Extension, setup
+from setuptools import Command, Extension, find_namespace_packages, setup
 from setuptools.command.build_ext import build_ext
@@ -166,27 +165,19 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
 # The information here can also be placed in setup.cfg - better separation of
 # logic and declaration, and simpler if you include description/version in a file.
 if __name__ == "__main__":
+    packages = find_namespace_packages(
+        where="python", exclude=["src", "tests", "tests.*"]
+    )
     package_dir = {"": "python"}
     package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
-    packages = [
-        "mlx",
-        "mlx.nn",
-        "mlx.optimizers",
-    ]
-    is_release = "PYPI_RELEASE" in os.environ
-    build_macos = platform.system() == "Darwin"
-    build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
-    build_common = "MLX_BUILD_COMMON" in os.environ
     install_requires = []
+    build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
     if build_cuda:
         install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
-    version = get_version()
-    _setup = partial(
-        setup,
-        version=version,
+    setup(
+        name="mlx-cuda" if build_cuda else "mlx",
+        version=get_version(),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",
@@ -194,57 +185,29 @@ if __name__ == "__main__":
         long_description_content_type="text/markdown",
         license="MIT",
         url="https://github.com/ml-explore/mlx",
+        packages=packages,
         package_dir=package_dir,
         package_data=package_data,
+        include_package_data=True,
+        install_requires=install_requires,
+        extras_require={
+            "dev": [
+                "nanobind==2.4.0",
+                "numpy",
+                "pre-commit",
+                "setuptools>=42",
+                "torch",
+                "typing_extensions",
+            ],
+        },
+        entry_points={
+            "console_scripts": [
+                "mlx.launch = mlx.distributed_run:main",
+                "mlx.distributed_config = mlx.distributed_run:distributed_config",
+            ]
+        },
+        ext_modules=[CMakeExtension("mlx.core")],
+        cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
         zip_safe=False,
         python_requires=">=3.9",
-        install_requires=install_requires,
     )
-
-    extras = {
-        "dev": [
-            "nanobind==2.4.0",
-            "numpy",
-            "pre-commit",
-            "setuptools>=42",
-            "torch",
-            "typing_extensions",
-        ],
-    }
-    entry_points = {
-        "console_scripts": [
-            "mlx.launch = mlx.distributed_run:main",
-            "mlx.distributed_config = mlx.distributed_run:distributed_config",
-        ]
-    }
-
-    test = "-awni-test"
-    if not is_release or build_macos:
-        _setup(
-            name="mlx" + test,
-            include_package_data=True,
-            packages=packages,
-            extras_require=extras,
-            entry_points=entry_points,
-            ext_modules=[CMakeExtension("mlx.core")],
-            cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
-        )
-    elif build_common:
-        extras["cpu"] = [f"mlx-cpu{test}=={version}"]
-        extras["cuda"] = [f"mlx-cuda{test}=={version}"]
-        _setup(
-            name="mlx" + test,
-            packages=["mlx"],
-            extras_require=extras,
-            entry_points=entry_points,
-            exclude_package_data=package_data,
-        )
-    else:
-        _setup(
-            name="mlx-cuda" if build_cuda else "mlx-cpu" + test,
-            include_package_data=True,
-            packages=packages,
-            extras_require=extras,
-            ext_modules=[CMakeExtension("mlx.core")],
-            cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
-        )