mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits: main...cd4b12ce1b (13 commits)

Commits:
cd4b12ce1b
425043ccca
95d92af8a0
bfdddd644b
1216afdc91
04e94d78bb
60d4e8b2a8
c5745fddd2
e937a8033f
4dfe02d7c6
5c2cff9329
325dab9559
67e454ab0a

.github/actions/build-macos/action.yml (vendored, 2 changed lines)

@@ -11,7 +11,7 @@ runs:
       shell: bash -l {0}
       run: |
         pip install --upgrade pip
-        pip install cmake setuptools nanobind==2.10.2
+        pip install cmake setuptools nanobind==2.4.0
         pip install -e . -v

     - name: Generate package stubs

.github/actions/setup-linux/action.yml (vendored, 22 changed lines)

@@ -10,29 +10,23 @@ inputs:
     description: 'Version of python to set up'
     required: false
     default: '3.10'
-  use-ccache:
-    description: 'Whether to enable ccache'
-    required: false
-    default: 'true'

 runs:
   using: "composite"
   steps:
+    - name: Use ccache
+      if: ${{ runner.arch == 'x86_64' }}
+      uses: hendrikmuhs/ccache-action@v1.2
+      with:
+        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
+        max-size: 1GB
+
     - name: Install common dependencies
       shell: bash
       run: |
         sudo apt-get update
         sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip

-    - name: Use ccache
-      if: ${{ inputs.use-ccache == 'true' }}
-      uses: hendrikmuhs/ccache-action@v1.2
-      with:
-        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
-        max-size: 1GB
-        # ccache-action bug: running "apt-get update" fails on large arm runner.
-        update-package-index: false
-
     - uses: actions/setup-python@v6
       with:
         python-version: ${{ inputs.python-version }}

@@ -42,7 +36,7 @@ runs:
       run: |
         python -m venv .venv
         source .venv/bin/activate
-        pip install setuptools cmake nanobind==2.10.2
+        pip install setuptools cmake nanobind==2.4.0
        echo PATH=$PATH >> $GITHUB_ENV
        # Make cmake search .venv for nanobind
        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

.github/workflows/nightly.yml (vendored, 6 changed lines)

@@ -23,14 +23,14 @@ jobs:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: "x86_64"
       - name: Upload mlx artifacts
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
           name: linux-wheels-${{ matrix.python_version }}
           path: wheelhouse/mlx-*.whl
           retention-days: 7
       - name: Upload mlx-cpu artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
           name: mlx-cpu
           path: wheelhouse/mlx_cpu-*.whl

@@ -89,7 +89,7 @@ jobs:
         with:
           toolkit: 'cuda-12.9'
       - name: Upload artifacts
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl

.github/workflows/release.yml (vendored, 24 changed lines)

@@ -57,20 +57,19 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           python-version: ${{ matrix.python_version }}
-          use-ccache: false
       - uses: ./.github/actions/build-linux-release
         with:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: ${{ matrix.arch }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
          overwrite: true
          name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
          path: wheelhouse/mlx-*.whl
       - name: Upload CPU artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
          overwrite: true
          name: mlx-cpu-${{ matrix.arch }}

@@ -96,7 +95,7 @@ jobs:
         shell: bash -l {0}
         run: |
           pip install --upgrade pip
-          pip install cmake setuptools nanobind==2.10.2
+          pip install cmake setuptools nanobind==2.4.0
           pip install -e . -v
       - name: Generate package stubs
         shell: bash -l {0}

@@ -114,14 +113,14 @@ jobs:
           macos-target: 15.0
           build-backend: ${{ matrix.python-version == '3.10' }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
           overwrite: true
           name: mac-wheels-${{ matrix.python-version }}
           path: dist/mlx-*.whl
       - name: Upload Metal artifacts
         if: matrix.python-version == '3.10'
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
           overwrite: true
           name: mlx-metal

@@ -142,13 +141,12 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           toolkit: ${{ matrix.toolkit }}
-          use-ccache: false
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
           arch: ${{ matrix.arch }}
       - name: Upload artifacts
-        uses: actions/upload-artifact@v6
+        uses: actions/upload-artifact@v5
         with:
           overwrite: true
           name: mlx-cuda

@@ -164,12 +162,12 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx
    steps:
-      - uses: actions/download-artifact@v7
+      - uses: actions/download-artifact@v6
        with:
          pattern: linux-wheels-*
          merge-multiple: true
          path: dist
-      - uses: actions/download-artifact@v7
+      - uses: actions/download-artifact@v6
        with:
          pattern: mac-wheels-*
          merge-multiple: true

@@ -191,7 +189,7 @@ jobs:
      name: pypi
      url: https://pypi.org/p/mlx-cuda
    steps:
-      - uses: actions/download-artifact@v7
+      - uses: actions/download-artifact@v6
        with:
          name: mlx-cuda
          path: dist

@@ -212,7 +210,7 @@ jobs:
      name: pypi
      url: https://pypi.org/p/mlx-cpu
    steps:
-      - uses: actions/download-artifact@v7
+      - uses: actions/download-artifact@v6
        with:
          pattern: mlx-cpu-*
          merge-multiple: true

@@ -234,7 +232,7 @@ jobs:
      name: pypi
      url: https://pypi.org/p/mlx-metal
    steps:
-      - uses: actions/download-artifact@v7
+      - uses: actions/download-artifact@v6
        with:
          name: mlx-metal
          path: dist

@@ -119,6 +119,10 @@ if(MLX_BUILD_METAL)
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
     OUTPUT_VARIABLE MACOS_SDK_VERSION
     OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+  execute_process(
+    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-path"
+    OUTPUT_VARIABLE CMAKE_OSX_SYSROOT
+    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

   if(${MACOS_SDK_VERSION} LESS 14.0)
     message(

@@ -273,7 +277,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(
-    Python 3.10
+    Python 3.8
    COMPONENTS Interpreter Development.Module
    REQUIRED)
  execute_process(

@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
 .. code-block:: shell

   >>> a = mx.arange(1000).reshape(10, 10, 10)
-  >>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
+  >>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1

 The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
 selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.

@@ -3,6 +3,6 @@ requires = [
     "setuptools>=42",
     "cmake>=3.25",
     "mlx>=0.18.0",
-    "nanobind==2.10.2",
+    "nanobind==2.4.0",
 ]
 build-backend = "setuptools.build_meta"

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.10.2
+nanobind==2.4.0

@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
     // - Donatable
     // - Not a constant
     if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-        in.is_donatable() && !is_constant(i)) {
+        in.is_donatable() && is_constant(i)) {
       outputs[o++].copy_shared_buffer(in);
     }
     // Get representative input flags to properly set non-donated outputs

@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
     // - Not a constant
     if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
         in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-        !is_constant(i)) {
+        is_constant(i)) {
       outputs[o].copy_shared_buffer(
           in, outputs[o].strides(), in.flags(), in.data_size());
       o++;

@@ -291,17 +291,6 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
        num_keys,
        kshape = keys.shape(),
        kstrides = keys.strides()]() mutable {
-    auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
-      if (4 * loc + 4 <= bytes_per_key) {
-        reinterpret_cast<uint32_t*>(cptr)[loc] = v;
-      } else {
-        std::copy(
-            reinterpret_cast<char*>(&v),
-            reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
-            cptr + 4 * loc);
-      }
-    };
-
     size_t out_skip = (bytes_per_key + 4 - 1) / 4;
     auto half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;

@@ -321,12 +310,18 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       if (count.first < half_size) {
         auto rb = random::threefry2x32_hash(key, count);
         ptr[count.first++] = rb.first;
-        copy_remaining(cptr, count.second, rb.second);
+        if (bytes_per_key % 4 > 0) {
+          std::copy(
+              reinterpret_cast<char*>(&rb.second),
+              reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
+              cptr + 4 * count.second);
+        } else {
+          ptr[count.second] = rb.second;
+        }
       }
       if (!even) {
         count.second = 0;
-        copy_remaining(
-            cptr, half_size, random::threefry2x32_hash(key, count).first);
+        ptr[half_size] = random::threefry2x32_hash(key, count).first;
       }
     }
   });

@@ -3,9 +3,5 @@
 #include "mlx/backend/cpu/simd/base_simd.h"

 #ifdef MLX_USE_ACCELERATE
-#if defined(__x86_64__)
-// the accelerate_simd implementation require neon -- use base implementation
-#else
 #include "mlx/backend/cpu/simd/accelerate_simd.h"
 #endif
-#endif

@@ -338,23 +338,18 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
     }
     cudaGraphNodeType type;
     CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
-    switch (type) {
-      case cudaGraphNodeTypeGraph: {
+    if (type == cudaGraphNodeTypeGraph) {
       // Try to be updatable for a structure like graph -> graph -> kernel
       cudaGraph_t child;
       CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
       auto [subkey, sub_is_updatable] = subgraph_to_key(child);
       is_updatable &= sub_is_updatable;
       key += subkey;
-        break;
-      }
-      case cudaGraphNodeTypeHost:
-        key += "H";
-        break;
-      case cudaGraphNodeTypeMemset:
+    } else if (type == cudaGraphNodeTypeMemset) {
       key += "M";
-        break;
-      case cudaGraphNodeTypeKernel: {
+    } else if (type != cudaGraphNodeTypeKernel) {
+      is_updatable = false;
+    } else {
       cudaLaunchAttributeValue cluster_dim;
       CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
           node, cudaLaunchAttributeClusterDimension, &cluster_dim));

@@ -365,16 +360,6 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
         key += "K";
         key += std::to_string(cluster_dim.clusterDim.x);
       }
-      break;
-    }
-    case cudaGraphNodeTypeWaitEvent:
-      key += "W";
-      break;
-    case cudaGraphNodeTypeEventRecord:
-      key += "R";
-      break;
-    default:
-      is_updatable = false;
     }
   }
   key += ")";

@@ -2,11 +2,7 @@

 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
-#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
 #include "mlx/backend/cuda/quantized/quantized.h"
-#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
-#include "mlx/backend/cuda/vector_types.cuh"
 #include "mlx/dtype_utils.h"

 #include <cooperative_groups.h>

@@ -17,6 +13,17 @@
 namespace mlx::core {
 namespace cu {

+template <int bits>
+struct Quantize {
+  __device__ uint8_t operator()(float x) {
+    if constexpr (bits == 8) {
+      return __nv_fp8_e4m3(x).__x;
+    } else {
+      return __nv_fp4_e2m1(x).__x;
+    }
+  }
+};
+
 template <int bits>
 struct Dequantize {
   __device__ float operator()(uint8_t x) {

@@ -30,40 +37,29 @@ struct Dequantize {

 namespace cg = cooperative_groups;

-template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
-__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
-  using Tx2 = Vector2_t<T>;
-  using Tx4 = Vector4_t<T>;
-  uint32_t rbits = 0; // reserved bits for future use
+template <typename T, int group_size, int bits, bool use_mx_scale>
+__global__ void
+fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
   auto block_size = cg::this_thread_block().dim_threads();
   auto block_idx = cg::this_thread_block().group_index();
   auto idx_in_block = cg::this_thread_block().thread_index();

   auto tidx = block_idx.x * block_size.x + idx_in_block.x;
   auto tidy = block_idx.y * block_size.y + idx_in_block.y;
-  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;

-  size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
-  size_t base_idx = thread_idx * group_size;
-  if (base_idx >= size) {
+  auto grid_dim_x =
+      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
+  size_t index = tidx + grid_dim_x * size_t(tidy);
+  if (index >= size) {
     return;
   }

-  auto w_tile = load_vector<group_size, T>(w, thread_idx);
-  float scale = 0.0f;
-
-  Tx2 amax_2x = Tx2{0.0f, 0.0f};
-#pragma unroll
-  for (int i = 0; i < group_size; i += 2) {
-    auto pair = Tx2{w_tile[i], w_tile[i + 1]};
-    abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
-  }
-
-  scale = static_cast<float>(
-      max(fabsf(static_cast<float>(amax_2x.x)),
-          fabsf(static_cast<float>(amax_2x.y))));
-
+  float w_thread = w[index];
+  cg::greater<float> max_op;
+  auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
+  float scale = cg::reduce(warp, abs(w_thread), max_op);
   scale /= bits == 4 ? 6.0f : 448.0f;
   // Convert to mx scale or nv scale
   using ScaleType =

@@ -72,24 +68,21 @@ __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
   uint8_t q_scale = s.__x;
   scale = float(s);

-  scales[thread_idx] = q_scale;
-  constexpr int elem_per_byte = bits == 8 ? 1 : 2;
-  AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
-
-#pragma unroll
-  for (int i = 0; i < group_size / 4; i++) {
-    Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
-    if constexpr (bits == 8) {
-      uint32_t quantized_val =
-          scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
-      *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
-    } else {
-      uint16_t quantized_val =
-          scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
-      *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
-    }
+  // Write out the scales
+  size_t gindex = index / group_size;
+  if (index % group_size == 0) {
+    scales[gindex] = q_scale;
+  }
+
+  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
+  if (bits == 4) {
+    uint8_t sval = warp.shfl_down(output, 1);
+    output |= sval << bits;
   }
+  constexpr int pack_factor = bits == 8 ? 1 : 2;
+  if (index % pack_factor == 0) {
+    out[index / pack_factor] = output;
   }
-  store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
 }

@@ -149,16 +142,15 @@ void fp_quantize(
   dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     if constexpr (!std::is_same_v<T, double>) {
-      auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
+      auto kernel = cu::fp_quantize<T, 32, 4, true>;
       if (bits == 8) {
-        kernel = cu::fp_quantize<T, 32, 8, true, false>;
+        kernel = cu::fp_quantize<T, 32, 8, true>;
       } else if (group_size == 16) {
-        kernel = cu::fp_quantize<T, 16, 4, false, false>;
+        kernel = cu::fp_quantize<T, 16, 4, false>;
       }
       bool large = w.size() > UINT_MAX;
       auto [num_blocks, block_dims] =
-          get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
-
+          get_launch_args(w.size(), w.shape(), w.strides(), large);
       enc.add_kernel_node(
           kernel,
           num_blocks,

@@ -1,32 +0,0 @@
-#pragma once
-
-#include <cuda.h>
-#include <cuda_fp8.h>
-#include <cuda_runtime.h>
-#include "mlx/backend/cuda/vector_types.cuh"
-
-namespace mlx::core::cu {
-
-// TODO implement fast path
-template <typename T>
-__device__ __forceinline__ uint32_t
-scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
-  uint32_t out_fp8x4 = 0;
-  float4 scaled;
-  scaled.x = static_cast<float>(input.x) * scale;
-  scaled.y = static_cast<float>(input.y) * scale;
-  scaled.z = static_cast<float>(input.z) * scale;
-  scaled.w = static_cast<float>(input.w) * scale;
-  out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
-  return out_fp8x4;
-}
-
-// Place holder for future fast path implementation
-template <typename T, bool USE_SR>
-__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
-    const Vector4_t<T> input,
-    const float scale,
-    uint32_t rbits) {
-  return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
-}
-} // namespace mlx::core::cu

@@ -1,334 +0,0 @@
-#pragma once
-
-#include <cuda.h>
-#include <cuda_fp4.h>
-#include <cuda_runtime.h>
-#include "mlx/backend/cuda/vector_types.cuh"
-
-namespace mlx::core::cu {
-
-using bf16x4 = Vector4_t<__nv_bfloat16>;
-using fp16x4 = Vector4_t<__half>;
-using f32x4 = Vector4_t<float>;
-
-template <typename T>
-__device__ __forceinline__ uint16_t
-scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
-  // Fallback implementation for architectures that do not support cvt
-  // instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
-  uint16_t out_fp4x4 = 0;
-  fp32x4 scaled;
-  scaled.x = static_cast<float>(input.x) * scale;
-  scaled.y = static_cast<float>(input.y) * scale;
-  scaled.z = static_cast<float>(input.z) * scale;
-  scaled.w = static_cast<float>(input.w) * scale;
-  uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
-  uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
-  uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
-  uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
-  out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
-      (static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
-      static_cast<uint16_t>(q0);
-  return out_fp4x4;
-}
-
-#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
-    defined(__CUDA_ARCH_SPECIFIC__)
-
-__device__ __forceinline__ uint16_t
-scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
-  uint16_t out_fp4x4 = 0;
-  asm volatile(
-      "{\n"
-      ".reg.b16 x0_bf16; \n\t" // first bf16
-      ".reg.b16 x1_bf16; \n\t" // second bf16
-      ".reg.b16 x2_bf16; \n\t" // third bf16
-      ".reg.b16 x3_bf16; \n\t" // fourth bf16
-      ".reg.b32 x0; \n\t" // to hold scaled first
-      ".reg.b32 x1; \n\t" // to hold scaled second
-      ".reg.b32 x2; \n\t" // to hold scaled third
-      ".reg.b32 x3; \n\t" // to hold scaled fourth
-      ".reg.b64 x01; \n\t" // to hold vector mul
-      ".reg.b64 x23; \n\t"
-      ".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
-      ".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
-      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
-      "cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
-      "cvt.f32.bf16 x1, x1_bf16; \n\t"
-      "cvt.f32.bf16 x2, x2_bf16; \n\t"
-      "cvt.f32.bf16 x3, x3_bf16; \n\t"
-      "mov.b64 x01, {x0, x1}; \n\t"
-      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
-      "mov.b64 x23, {x2, x3}; \n\t"
-      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
-      "mov.b64 {x0, x1}, x01; \n\t"
-      "mov.b64 {x2, x3}, x23; \n\t"
-      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
-      // pair
-      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
-      // pair
-      "mov.b16 %0, {q0, q1}; \n\t" // pack to output
-      "}"
-      : "=h"(out_fp4x4)
-      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
-        "l"(reinterpret_cast<const uint64_t&>(
-            scale))); // here cast is needed becuase an asm operand must have
-                      // scalar type
-  return out_fp4x4;
-}
-
-__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
-    const bf16x4 input_bf16x4,
-    const float2 scale,
-    uint32_t rbits) {
-  uint16_t out_fp4x4 = 0;
-  asm volatile(
-      "{\n"
-      ".reg.b16 x0_bf16; \n\t"
-      ".reg.b16 x1_bf16; \n\t"
-      ".reg.b16 x2_bf16; \n\t"
-      ".reg.b16 x3_bf16; \n\t"
-      ".reg.b32 x0; \n\t"
-      ".reg.b32 x1; \n\t"
-      ".reg.b32 x2; \n\t"
-      ".reg.b32 x3; \n\t"
-      ".reg.b64 x01; \n\t"
-      ".reg.b64 x23; \n\t"
-      ".reg.b16 q0; \n\t"
-      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
-      "cvt.f32.bf16 x0, x0_bf16; \n\t"
-      "cvt.f32.bf16 x1, x1_bf16; \n\t"
-      "cvt.f32.bf16 x2, x2_bf16; \n\t"
-      "cvt.f32.bf16 x3, x3_bf16; \n\t"
-      "mov.b64 x01, {x0, x1}; \n\t"
-      "mul.f32x2 x01, x01, %2; \n\t"
-      "mov.b64 x23, {x2, x3}; \n\t"
-      "mul.f32x2 x23, x23, %2; \n\t"
-      "mov.b64 {x0, x1}, x01; \n\t"
-      "mov.b64 {x2, x3}, x23; \n\t"
-      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
-      "}"
-      : "=h"(out_fp4x4)
-      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
-        "l"(reinterpret_cast<const uint64_t&>(scale)),
-        "r"(rbits));
-  return out_fp4x4;
-}
-
-__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
-    const float2 input_fp32x2_0,
-    const float2 input_fp32x2_1,
-    const float2 scale) {
-  uint16_t out_fp4x4 = 0;
-  asm volatile(
-      "{\n"
-      ".reg.b32 x0; \n\t"
-      ".reg.b32 x1; \n\t"
-      ".reg.b32 x2; \n\t"
-      ".reg.b32 x3; \n\t"
-      ".reg.b64 x01; \n\t"
-      ".reg.b64 x23; \n\t"
-      ".reg.b8 q0; \n\t"
-      ".reg.b8 q1; \n\t"
-      "mov.b64 x01, {%1, %2}; \n\t"
-      "mul.f32x2 x01, x01, %5; \n\t"
-      "mov.b64 x23, {%3, %4}; \n\t"
-      "mul.f32x2 x23, x23, %5; \n\t"
-      "mov.b64 {x0, x1}, x01; \n\t"
-      "mov.b64 {x2, x3}, x23; \n\t"
-      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
-      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
-      "mov.b16 %0, {q0, q1}; \n\t"
-      "}"
-      : "=h"(out_fp4x4)
-      : "f"(input_fp32x2_0.x),
-        "f"(input_fp32x2_0.y),
-        "f"(input_fp32x2_1.x),
-        "f"(input_fp32x2_1.y),
-        "l"(reinterpret_cast<const uint64_t&>(scale)));
-  return out_fp4x4;
-}
-
-__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
-    const float2 input_fp32x2_0,
-    const float2 input_fp32x2_1,
-    const float2 scale,
-    uint32_t rbits) {
-  uint16_t out_fp4x4 = 0;
-  asm volatile(
-      "{\n"
-      ".reg.b32 x0; \n\t"
-      ".reg.b32 x1; \n\t"
-      ".reg.b32 x2; \n\t"
-      ".reg.b32 x3; \n\t"
-      ".reg.b64 x01; \n\t"
-      ".reg.b64 x23; \n\t"
-      ".reg.b16 q0; \n\t"
-      "mov.b64 x01, {%1, %2}; \n\t"
-      "mul.f32x2 x01, x01, %5; \n\t"
-      "mov.b64 x23, {%3, %4}; \n\t"
-      "mul.f32x2 x23, x23, %5; \n\t"
-      "mov.b64 {x0, x1}, x01; \n\t"
-      "mov.b64 {x2, x3}, x23; \n\t"
-      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
-      "}"
-      : "=h"(out_fp4x4)
-      : "f"(input_fp32x2_0.x),
-        "f"(input_fp32x2_0.y),
-        "f"(input_fp32x2_1.x),
-        "f"(input_fp32x2_1.y),
-        "l"(reinterpret_cast<const uint64_t&>(scale)),
-        "r"(rbits));
-  return out_fp4x4;
-}
-
-__device__ __forceinline__ uint16_t
-scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
-  uint16_t out_fp4x4 = 0;
-  asm volatile(
-      "{\n"
-      ".reg.b16 x0_fp16; \n\t"
-      ".reg.b16 x1_fp16; \n\t"
-      ".reg.b16 x2_fp16; \n\t"
-      ".reg.b16 x3_fp16; \n\t"
-      ".reg.b32 x0; \n\t"
-      ".reg.b32 x1; \n\t"
-      ".reg.b32 x2; \n\t"
-      ".reg.b32 x3; \n\t"
-      ".reg.b64 x01; \n\t"
-      ".reg.b64 x23; \n\t"
-      ".reg.b8 q0; \n\t"
-      ".reg.b8 q1; \n\t"
-      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
-      "cvt.f32.f16 x0, x0_fp16; \n\t"
-      "cvt.f32.f16 x1, x1_fp16; \n\t"
-      "cvt.f32.f16 x2, x2_fp16; \n\t"
-      "cvt.f32.f16 x3, x3_fp16; \n\t"
-      "mov.b64 x01, {x0, x1}; \n\t"
-      "mul.f32x2 x01, x01, %2; \n\t"
-      "mov.b64 x23, {x2, x3}; \n\t"
-      "mul.f32x2 x23, x23, %2; \n\t"
-      "mov.b64 {x0, x1}, x01; \n\t"
-      "mov.b64 {x2, x3}, x23; \n\t"
-      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
-      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
-      "mov.b16 %0, {q0, q1}; \n\t"
-      "}"
-      : "=h"(out_fp4x4)
-      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
-        "l"(reinterpret_cast<const uint64_t&>(scale)));
-  return out_fp4x4;
-}
-
-__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
-    const fp16x4 input_fp16x4,
-    const float2 scale,
-    uint32_t rbits) {
-  uint16_t out_fp4x4 = 0;
-  asm volatile(
-      "{\n"
-      ".reg.b16 x0_fp16; \n\t"
-      ".reg.b16 x1_fp16; \n\t"
-      ".reg.b16 x2_fp16; \n\t"
-      ".reg.b16 x3_fp16; \n\t"
-      ".reg.b32 x0; \n\t"
-      ".reg.b32 x1; \n\t"
-      ".reg.b32 x2; \n\t"
-      ".reg.b32 x3; \n\t"
-      ".reg.b64 x01; \n\t"
-      ".reg.b64 x23; \n\t"
-      ".reg.b16 q0; \n\t"
-      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
-      "cvt.f32.f16 x0, x0_fp16; \n\t"
-      "cvt.f32.f16 x1, x1_fp16; \n\t"
-      "cvt.f32.f16 x2, x2_fp16; \n\t"
-      "cvt.f32.f16 x3, x3_fp16; \n\t"
-      "mov.b64 x01, {x0, x1}; \n\t"
-      "mul.f32x2 x01, x01, %2; \n\t"
-      "mov.b64 x23, {x2, x3}; \n\t"
-      "mul.f32x2 x23, x23, %2; \n\t"
-      "mov.b64 {x0, x1}, x01; \n\t"
-      "mov.b64 {x2, x3}, x23; \n\t"
-      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
-      "}"
-      : "=h"(out_fp4x4)
-      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
-        "l"(reinterpret_cast<const uint64_t&>(scale)),
-        "r"(rbits));
-  return out_fp4x4;
-}
-
-template <bool USE_SR>
-__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
-    const bf16x4 input,
-    const float scale,
-    uint32_t rbits) {
-  float2 scale_fp32x2 = make_float2(scale, scale);
-  if constexpr (USE_SR) {
-    return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
-  } else {
-    return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
-  }
-}
-
-template <bool USE_SR>
-__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
-    const fp16x4 input,
-    const float scale,
-    uint32_t rbits) {
-  float2 scale_fp32x2 = make_float2(scale, scale);
-  if constexpr (USE_SR) {
-    return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
-  } else {
-    return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
-  }
-}
-
-template <bool USE_SR>
-__device__ __forceinline__ uint16_t
-scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
-  float2 scale_fp32x2 = make_float2(scale, scale);
-  float2 input_fp32x2_0 = make_float2(input.x, input.y);
-  float2 input_fp32x2_1 = make_float2(input.z, input.w);
-
-  if constexpr (USE_SR) {
-    return scale_cvt_fp32x4_to_fp4x4_rs(
-        input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
-  } else {
-    return scale_cvt_fp32x4_to_fp4x4_rn(
-        input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
-  }
-}
-
-template <typename T, bool USE_SR>
-__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
-    const Vector4_t<T> input,
-    const float scale,
-    uint32_t rbits) {
-  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
-    return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
-  } else if constexpr (std::is_same<T, __half>::value) {
-    return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
-  } else {
-    return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
-  }
-}
-#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
-       // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
-
-template <typename T, bool USE_SR>
-__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
-    const Vector4_t<T> input,
-    const float scale,
-    uint32_t rbits) {
-#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
-    (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
-  return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
-#else
-  static_assert(
-      !USE_SR,
-      "Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
-  return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
-#endif
-}
-} // namespace mlx::core::cu

@@ -15,22 +15,6 @@ inline constexpr __device__ short get_bytes_per_pack() {
   return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
 }

-template <typename T>
-__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
-  if constexpr (
-      (std::is_same<T, __nv_bfloat162>::value) ||
-      (std::is_same<T, __half2>::value)) {
-    T a = x1;
-    T b = x2;
-    out = __hmax2(__habs2(a), __habs2(b));
-  } else if constexpr (std::is_same<T, float2>::value) {
-    float2 a = x1;
-    float2 b = x2;
-    out.x = fmaxf(fabsf(a.x), fabsf(b.x));
-    out.y = fmaxf(fabsf(a.y), fabsf(b.y));
-  }
-}
-
 } // namespace cu

 template <typename F>

@@ -3,10 +3,31 @@
 #pragma once

 #include "mlx/backend/cuda/steel/utils.cuh"
-#include "mlx/backend/cuda/vector_types.cuh"

 namespace mlx::core::cu {

+// Map types to their vector of 2 type float -> float2, double -> double2 etc
+template <typename T>
+struct Vector2;
+template <>
+struct Vector2<double> {
+  using type = double2;
+};
+template <>
+struct Vector2<float> {
+  using type = float2;
+};
+template <>
+struct Vector2<__half> {
+  using type = __half2;
+};
+template <>
+struct Vector2<__nv_bfloat16> {
+  using type = __nv_bfloat162;
+};
+template <typename T>
+using Vector2_t = typename Vector2<T>::type;
+
 /**
  * The basic building block for Ampere mmas. A 16x16 tile distributed across
  * the warp.

@@ -1,48 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#pragma once
-
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-
-namespace mlx::core::cu {
-
-template <typename T>
-struct Vector2;
-
-template <>
-struct Vector2<double> {
-  using type = double2;
-};
-
-template <>
-struct Vector2<float> {
-  using type = float2;
-};
-
-template <>
-struct Vector2<__half> {
-  using type = __half2;
-};
-
-template <>
-struct Vector2<__nv_bfloat16> {
-  using type = __nv_bfloat162;
-};
-
-template <typename T>
-using Vector2_t = typename Vector2<T>::type;
-
-template <typename T>
-struct Vector4 {
-  T x, y, z, w;
-};
-
-template <typename T>
-using Vector4_t = Vector4<T>;
-
-using bf16x4 = Vector4_t<__nv_bfloat16>;
-using fp16x4 = Vector4_t<__half>;
-using fp32x4 = Vector4_t<float>;
-
-} // namespace mlx::core::cu

@@ -347,7 +347,7 @@ template <
             MMAFrag_mask_t::load_safe(
                 mfrag,
                 mask,
-                int64_t(mask_params->M_strides[2]),
+                int(mask_params->M_strides[2]),
                 Int<1>{},
                 params->qL,
                 params->kL,

@@ -346,7 +346,7 @@ template <
           MSubTile mfrag;
           mfrag.load_safe(
               mask,
-              int64_t(mask_params->M_strides[2]),
+              int(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,

@@ -105,20 +105,17 @@ struct BaseMMAFrag<T, 8, 8> {
       LimY lim_y,
       OffX off_x = Int<0>{},
       OffY off_y = Int<0>{}) {
-    src += off_x * str_x + off_y * str_y;
     STEEL_PRAGMA_UNROLL
     for (short i = 0; i < kElemRows; i++) {
       STEEL_PRAGMA_UNROLL
       for (short j = 0; j < kElemCols; j++) {
         if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
-          dst[i * kElemCols + j] = static_cast<T>(src[0]);
+          dst[i * kElemCols + j] =
+              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
         } else {
           dst[i * kElemCols + j] = T(0);
         }
-        src += str_y;
       }
-      src -= kElemCols * str_y;
-      src += str_x;
     }
   }


@@ -4,6 +4,11 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)

+if(MLX_BUILD_CPU AND NOT WIN32)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
+endif()
+
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)

@@ -5,6 +5,7 @@
 #include "mlx/backend/cuda/cuda.h"
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
+#include "mlx/distributed/jaccl/jaccl.h"
 #include "mlx/distributed/mpi/mpi.h"
 #include "mlx/distributed/nccl/nccl.h"
 #include "mlx/distributed/ring/ring.h"

@@ -102,7 +103,27 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail

 bool is_available() {
-  return mpi::is_available() || ring::is_available() || nccl::is_available();
+  return mpi::is_available() || ring::is_available() || nccl::is_available() ||
+      jaccl::is_available();
+}
+
+bool is_available(const std::string& bk) {
+  if (bk == "any") {
+    return is_available();
+  }
+  if (bk == "mpi") {
+    return mpi::is_available();
+  }
+  if (bk == "ring") {
+    return ring::is_available();
+  }
+  if (bk == "nccl") {
+    return nccl::is_available();
+  }
+  if (bk == "jaccl") {
+    return jaccl::is_available();
+  }
+  return false;
 }

 int Group::rank() const {

@@ -135,6 +156,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
     group = ring::init(strict);
   } else if (bk == "nccl") {
     group = nccl::init(strict);
+  } else if (bk == "jaccl") {
+    group = jaccl::init(strict);
   } else if (bk == "any") {
     if (mlx::core::cu::is_available()) {
       group = nccl::init(false);

@@ -148,13 +171,17 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
       group = mpi::init(false);
       bk_ = "mpi";
     }
+    if (group == nullptr) {
+      group = jaccl::init(false);
+      bk_ = "jaccl";
+    }
     if (group == nullptr && strict) {
       throw std::runtime_error("[distributed] Couldn't initialize any backend");
     }
   } else {
     std::ostringstream msg;
-    msg << "[distributed] The only valid values for backend are 'any', 'mpi' "
-        << "and 'ring' but '" << bk << "' was provided.";
+    msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
+        << "'jaccl' and 'ring' but '" << bk << "' was provided.";
     throw std::invalid_argument(msg.str());
   }

@@ -16,6 +16,7 @@ class GroupImpl;

 /* Check if a communication backend is available */
 bool is_available();
+bool is_available(const std::string& bk);

 /**
  * A distributed::Group represents a group of independent mlx processes that

mlx/distributed/jaccl/CMakeLists.txt (new file, 8 lines)

@@ -0,0 +1,8 @@
+if(MLX_BUILD_CPU
+   AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
+   AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
+  target_link_libraries(mlx PRIVATE rdma)
+else()
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
+endif()

mlx/distributed/jaccl/jaccl.cpp (new file, 1123 lines): file diff suppressed because it is too large.

mlx/distributed/jaccl/jaccl.h (new file, 12 lines)

@@ -0,0 +1,12 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/distributed/distributed.h"
+
+namespace mlx::core::distributed::jaccl {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
+
+bool is_available();
+std::shared_ptr<GroupImpl> init(bool strict = false);
+
+} // namespace mlx::core::distributed::jaccl

mlx/distributed/jaccl/no_jaccl.cpp (new file, 20 lines)

@@ -0,0 +1,20 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/distributed/jaccl/jaccl.h"
+
+namespace mlx::core::distributed::jaccl {
+
+using GroupImpl = mlx::core::distributed::detail::GroupImpl;
+
+bool is_available() {
+  return false;
+}
+
+std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
+  if (strict) {
+    throw std::runtime_error("Cannot initialize jaccl distributed backend.");
+  }
+  return nullptr;
+}
+
+} // namespace mlx::core::distributed::jaccl

mlx/distributed/reduction_ops.h (new file, 38 lines)

@@ -0,0 +1,38 @@
+// Copyright © 2025 Apple Inc.
+
+namespace mlx::core::distributed::detail {
+
+template <typename T>
+struct SumOp {
+  void operator()(const T* input, T* output, size_t N) const {
+    while (N-- > 0) {
+      *output += *input;
+      input++;
+      output++;
+    }
+  }
+};
+
+template <typename T>
+struct MaxOp {
+  void operator()(const T* input, T* output, size_t N) const {
+    while (N-- > 0) {
+      *output = std::max(*output, *input);
+      input++;
+      output++;
+    }
+  }
+};
+
+template <typename T>
+struct MinOp {
+  void operator()(const T* input, T* output, size_t N) const {
+    while (N-- > 0) {
+      *output = std::min(*output, *input);
+      input++;
+      output++;
+    }
+  }
+};
+
+} // namespace mlx::core::distributed::detail
@@ -1,9 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <arpa/inet.h>
|
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <netdb.h>
|
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <netinet/tcp.h>
|
#include <netinet/tcp.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
@@ -22,6 +19,8 @@
|
|||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/distributed/reduction_ops.h"
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
#include "mlx/threadpool.h"
|
#include "mlx/threadpool.h"
|
||||||
|
|
||||||
#ifndef SOL_TCP
|
#ifndef SOL_TCP
|
||||||
@@ -94,6 +93,7 @@ constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
|
|||||||
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
constexpr const size_t ALL_SUM_BUFFERS = 2;
|
||||||
constexpr const int CONN_ATTEMPTS = 5;
|
constexpr const int CONN_ATTEMPTS = 5;
|
||||||
constexpr const int CONN_WAIT = 1000;
|
constexpr const int CONN_WAIT = 1000;
|
||||||
|
constexpr const char* RING_TAG = "[ring]";
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
@@ -296,55 +296,6 @@ class CommunicationThreads {
|
|||||||
std::unordered_map<int, SocketThread> threads_;
|
std::unordered_map<int, SocketThread> threads_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct address_t {
|
|
||||||
sockaddr_storage addr;
|
|
||||||
socklen_t len;
|
|
||||||
|
|
||||||
const sockaddr* get() const {
|
|
||||||
return (struct sockaddr*)&addr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr from an ip and port provided as strings.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip, const std::string& port) {
|
|
||||||
struct addrinfo hints, *res;
|
|
||||||
memset(&hints, 0, sizeof(hints));
|
|
||||||
hints.ai_family = AF_UNSPEC;
|
|
||||||
hints.ai_socktype = SOCK_STREAM;
|
|
||||||
|
|
||||||
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
|
||||||
if (status != 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip << ":" << port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
address_t result;
|
|
||||||
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
|
||||||
result.len = res->ai_addrlen;
|
|
||||||
freeaddrinfo(res);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse a sockaddr provided as an <ip>:<port> string.
|
|
||||||
*/
|
|
||||||
address_t parse_address(const std::string& ip_port) {
|
|
||||||
auto colon = ip_port.find(":");
|
|
||||||
if (colon == std::string::npos) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Can't parse address " << ip_port;
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
|
||||||
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
|
||||||
|
|
||||||
return parse_address(ip, port);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load all addresses from the json hostfile. The hostfile is a list of
|
* Load all addresses from the json hostfile. The hostfile is a list of
|
||||||
* addresses in order of rank. For each rank there can be many addresses so
|
* addresses in order of rank. For each rank there can be many addresses so
|
||||||
@@ -357,15 +308,15 @@ address_t parse_address(const std::string& ip_port) {
|
|||||||
* ["ip3:5000", "ip3:5001"],
|
* ["ip3:5000", "ip3:5001"],
|
||||||
* ]
|
* ]
|
||||||
*/
|
*/
|
||||||
std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
|
||||||
std::vector<std::vector<address_t>> nodes;
|
std::vector<std::vector<detail::address_t>> nodes;
|
||||||
std::ifstream f(hostfile);
|
std::ifstream f(hostfile);
|
||||||
|
|
||||||
json hosts = json::parse(f);
|
json hosts = json::parse(f);
|
||||||
for (auto& h : hosts) {
|
for (auto& h : hosts) {
|
||||||
std::vector<address_t> host;
|
std::vector<detail::address_t> host;
|
||||||
for (auto& ips : h) {
|
for (auto& ips : h) {
|
||||||
host.push_back(parse_address(ips.get<std::string>()));
|
host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
|
||||||
}
|
}
|
||||||
nodes.push_back(std::move(host));
|
nodes.push_back(std::move(host));
|
||||||
}
|
}
|
||||||
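As a concrete illustration of the hostfile layout that load_nodes expects, the sketch below writes a two-rank ring hostfile with two "ip:port" entries per rank. This is an editorial example, not part of the change; the IPs, ports, and file name are placeholders.

# Sketch only: produce a hostfile in the layout parsed by load_nodes above.
# Addresses are in rank order; each rank may expose several "ip:port" entries.
import json

nodes = [
    ["192.168.1.1:5000", "192.168.1.1:5001"],  # rank 0
    ["192.168.1.2:5000", "192.168.1.2:5001"],  # rank 1
]

with open("ring_hostfile.json", "w") as f:
    json.dump(nodes, f)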
@@ -377,73 +328,15 @@ std::vector<std::vector<address_t>> load_nodes(const char* hostfile) {
|
|||||||
* Create a socket and accept one connection for each of the provided
|
* Create a socket and accept one connection for each of the provided
|
||||||
* addresses.
|
* addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
std::vector<int> accept_connections(
|
||||||
|
const std::vector<detail::address_t>& addresses) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
// Create the socket to wait for connections from the peers
|
detail::TCPSocket socket(RING_TAG);
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
socket.listen(RING_TAG, address);
|
||||||
if (sock < 0) {
|
sockets.push_back(socket.accept(RING_TAG).detach());
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure we can launch immediately after shutdown by setting the
|
|
||||||
// reuseaddr option so that we don't get address already in use errors
|
|
||||||
int enable = 1;
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseaddr (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
success = setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't enable reuseport (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind the socket to the address and port
|
|
||||||
success = bind(sock, address.get(), address.len);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't bind socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections
|
|
||||||
success = listen(sock, 0);
|
|
||||||
if (success < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't listen (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
int peer_socket = accept(sock, nullptr, nullptr);
|
|
||||||
if (peer_socket < 0) {
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Accept failed (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the listening socket
|
|
||||||
shutdown(sock, 2);
|
|
||||||
close(sock);
|
|
||||||
|
|
||||||
sockets.push_back(peer_socket);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
@@ -454,93 +347,42 @@ std::vector<int> accept_connections(const std::vector<address_t>& addresses) {
|
|||||||
* provided addresses.
|
* provided addresses.
|
||||||
*/
|
*/
|
||||||
std::vector<int> make_connections(
|
std::vector<int> make_connections(
|
||||||
const std::vector<address_t>& addresses,
|
const std::vector<detail::address_t>& addresses,
|
||||||
bool verbose) {
|
bool verbose) {
|
||||||
std::vector<int> sockets;
|
std::vector<int> sockets;
|
||||||
int success;
|
int success;
|
||||||
|
|
||||||
for (auto& address : addresses) {
|
for (auto& address : addresses) {
|
||||||
int sock;
|
sockets.push_back(detail::TCPSocket::connect(
|
||||||
|
RING_TAG,
|
||||||
// Attempt to connect to the peer CONN_ATTEMPTS times with exponential
|
address,
|
||||||
// backoff. TODO: Do we need that?
|
CONN_ATTEMPTS,
|
||||||
for (int attempt = 0; attempt < CONN_ATTEMPTS; attempt++) {
|
CONN_WAIT,
|
||||||
// Create the socket
|
[verbose](int attempt, int wait) {
|
||||||
sock = socket(AF_INET, SOCK_STREAM, 0);
|
|
||||||
if (sock < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (attempt > 0) {
|
|
||||||
int wait = (1 << (attempt - 1)) * CONN_WAIT;
|
|
||||||
log_info(
|
log_info(
|
||||||
verbose,
|
verbose,
|
||||||
"Attempt",
|
"Attempt",
|
||||||
attempt,
|
attempt,
|
||||||
"wait",
|
"waiting",
|
||||||
wait,
|
wait,
|
||||||
"ms (error:",
|
"ms (error:",
|
||||||
errno,
|
errno,
|
||||||
")");
|
")");
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
})
|
||||||
}
|
.detach());
|
||||||
|
|
||||||
success = connect(sock, address.get(), address.len);
|
|
||||||
if (success == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (success < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[ring] Couldn't connect (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
sockets.push_back(sock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sockets;
|
return sockets;
|
||||||
}
|
}
|
||||||
template <typename T>
|
|
||||||
struct SumOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output += *input;
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct MaxOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output = std::max(*output, *input);
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct MinOp {
|
|
||||||
void operator()(const T* input, T* output, size_t N) {
|
|
||||||
while (N-- > 0) {
|
|
||||||
*output = std::min(*output, *input);
|
|
||||||
input++;
|
|
||||||
output++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class RingGroup : public GroupImpl {
|
class RingGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
RingGroup(int rank, std::vector<std::vector<address_t>> nodes, bool verbose)
|
RingGroup(
|
||||||
|
int rank,
|
||||||
|
std::vector<std::vector<detail::address_t>> nodes,
|
||||||
|
bool verbose)
|
||||||
: rank_(rank), verbose_(verbose), pool_(0) {
|
: rank_(rank), verbose_(verbose), pool_(0) {
|
||||||
if (rank_ > 0 && rank_ >= nodes.size()) {
|
if (rank_ > 0 && rank_ >= nodes.size()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -633,17 +475,17 @@ class RingGroup : public GroupImpl {
|
|||||||
|
|
||||||
void all_sum(const array& input, array& output, Stream stream) override {
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_max(const array& input, array& output, Stream stream) override {
|
void all_max(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_min(const array& input, array& output, Stream stream) override {
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(
|
SWITCH_TYPE(
|
||||||
output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()));
|
output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
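The hunk above routes all_sum, all_max, and all_min through the shared detail:: reduction functors. From Python these reductions are reached through mx.distributed; the following is a hedged sketch, assuming the process was started by a distributed launcher so a group can form and that the all_max binding is exposed as in current MLX.

# Sketch only: exercising the reductions refactored above from Python.
import mlx.core as mx

group = mx.distributed.init()
x = mx.ones((4,)) * group.rank()

total = mx.distributed.all_sum(x, group=group)    # elementwise sum across ranks
largest = mx.distributed.all_max(x, group=group)  # elementwise max across ranks
mx.eval(total, largest)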
|
|||||||
204
mlx/distributed/utils.cpp
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstring>
|
||||||
|
#include <sstream>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "mlx/distributed/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port) {
|
||||||
|
struct addrinfo hints, *res;
|
||||||
|
std::memset(&hints, 0, sizeof(hints));
|
||||||
|
hints.ai_family = AF_UNSPEC;
|
||||||
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
|
||||||
|
int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
|
||||||
|
if (status != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip << ":" << port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
address_t result;
|
||||||
|
memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
|
||||||
|
result.len = res->ai_addrlen;
|
||||||
|
freeaddrinfo(res);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port) {
|
||||||
|
auto colon = ip_port.find(":");
|
||||||
|
if (colon == std::string::npos) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Can't parse address " << ip_port;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
std::string ip(ip_port.begin(), ip_port.begin() + colon);
|
||||||
|
std::string port(ip_port.begin() + colon + 1, ip_port.end());
|
||||||
|
|
||||||
|
return parse_address(ip, port);
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(const char* tag) {
|
||||||
|
sock_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock_ < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(TCPSocket&& s) {
|
||||||
|
sock_ = s.sock_;
|
||||||
|
s.sock_ = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
|
||||||
|
if (this != &s) {
|
||||||
|
sock_ = s.sock_;
|
||||||
|
s.sock_ = -1;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket::TCPSocket(int s) : sock_(s) {}
|
||||||
|
|
||||||
|
TCPSocket::~TCPSocket() {
|
||||||
|
if (sock_ > 0) {
|
||||||
|
shutdown(sock_, 2);
|
||||||
|
close(sock_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int TCPSocket::detach() {
|
||||||
|
int s = sock_;
|
||||||
|
sock_ = -1;
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::listen(const char* tag, const address_t& addr) {
|
||||||
|
int success;
|
||||||
|
|
||||||
|
// Make sure we can launch immediately after shutdown by setting the
|
||||||
|
// reuseaddr option so that we don't get address already in use errors
|
||||||
|
int enable = 1;
|
||||||
|
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind the socket to the address and port
|
||||||
|
success = bind(sock_, addr.get(), addr.len);
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't bind socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare waiting for connections
|
||||||
|
success = ::listen(sock_, 0);
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't listen (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket TCPSocket::accept(const char* tag) {
|
||||||
|
int peer = ::accept(sock_, nullptr, nullptr);
|
||||||
|
if (peer < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Accept failed (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return TCPSocket(peer);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::send(const char* tag, const void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::send(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Send failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<const char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TCPSocket::recv(const char* tag, void* data, size_t len) {
|
||||||
|
while (len > 0) {
|
||||||
|
auto n = ::recv(sock_, data, len, 0);
|
||||||
|
if (n <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Recv failed with errno=" << errno;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
len -= n;
|
||||||
|
data = static_cast<char*>(data) + n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TCPSocket TCPSocket::connect(
|
||||||
|
const char* tag,
|
||||||
|
const address_t& addr,
|
||||||
|
int num_retries,
|
||||||
|
int wait,
|
||||||
|
std::function<void(int, int)> cb) {
|
||||||
|
int sock, success;
|
||||||
|
|
||||||
|
// Attempt to connect `num_retries` times with exponential backoff.
|
||||||
|
for (int attempt = 0; attempt < num_retries; attempt++) {
|
||||||
|
// Create the socket
|
||||||
|
sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't create socket to connect (error: " << errno
|
||||||
|
<< ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
success = ::connect(sock, addr.get(), addr.len);
|
||||||
|
if (success == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(attempt, wait);
|
||||||
|
if (wait > 0) {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(wait));
|
||||||
|
}
|
||||||
|
|
||||||
|
wait <<= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (success < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << tag << " Couldn't connect (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
return TCPSocket(sock);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
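TCPSocket::connect above retries num_retries times, sleeping wait milliseconds after a failed attempt and doubling the wait each round (wait <<= 1). The same retry-with-exponential-backoff pattern, sketched in Python purely for illustration (this is not MLX API):

# Standalone sketch of the exponential-backoff retry used by TCPSocket::connect.
import socket
import time

def connect_with_backoff(addr, num_retries=5, wait_ms=1000):
    last_err = None
    for attempt in range(num_retries):
        try:
            return socket.create_connection(addr)
        except OSError as e:
            last_err = e
            time.sleep(wait_ms / 1000.0)
            wait_ms *= 2  # double the wait, mirroring `wait <<= 1` above
    raise last_err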
67
mlx/distributed/utils.h
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <functional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::detail {
|
||||||
|
|
||||||
|
struct address_t {
|
||||||
|
sockaddr_storage addr;
|
||||||
|
socklen_t len;
|
||||||
|
|
||||||
|
const sockaddr* get() const {
|
||||||
|
return (struct sockaddr*)&addr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr from an ip and port provided as strings.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip, const std::string& port);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse a sockaddr provided as an <ip>:<port> string.
|
||||||
|
*/
|
||||||
|
address_t parse_address(const std::string& ip_port);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Small wrapper over a TCP socket to simplify initiating connections.
|
||||||
|
*/
|
||||||
|
class TCPSocket {
|
||||||
|
public:
|
||||||
|
TCPSocket(const char* tag);
|
||||||
|
TCPSocket(const TCPSocket&) = delete;
|
||||||
|
TCPSocket& operator=(const TCPSocket&) = delete;
|
||||||
|
TCPSocket(TCPSocket&& s);
|
||||||
|
TCPSocket& operator=(TCPSocket&&);
|
||||||
|
~TCPSocket();
|
||||||
|
|
||||||
|
void listen(const char* tag, const address_t& addr);
|
||||||
|
TCPSocket accept(const char* tag);
|
||||||
|
|
||||||
|
void send(const char* tag, const void* data, size_t len);
|
||||||
|
void recv(const char* tag, void* data, size_t len);
|
||||||
|
|
||||||
|
int detach();
|
||||||
|
|
||||||
|
operator int() const {
|
||||||
|
return sock_;
|
||||||
|
}
|
||||||
|
|
||||||
|
static TCPSocket connect(
|
||||||
|
const char* tag,
|
||||||
|
const address_t& addr,
|
||||||
|
int num_retries = 1,
|
||||||
|
int wait = 0,
|
||||||
|
std::function<void(int, int)> cb = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
TCPSocket(int sock);
|
||||||
|
|
||||||
|
int sock_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed::detail
|
||||||
@@ -880,11 +880,6 @@ std::vector<array> ScaledDotProductAttention::vjp(
|
|||||||
|
|
||||||
std::vector<array> returned_vjps;
|
std::vector<array> returned_vjps;
|
||||||
for (int arg : argnums) {
|
for (int arg : argnums) {
|
||||||
if (arg >= 3) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[scale_dot_product_attention] Does not support VJP with respect "
|
|
||||||
" to mask or attention sinks.");
|
|
||||||
}
|
|
||||||
returned_vjps.push_back(std::move(vjps[arg]));
|
returned_vjps.push_back(std::move(vjps[arg]));
|
||||||
}
|
}
|
||||||
return returned_vjps;
|
return returned_vjps;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = [
|
requires = [
|
||||||
"setuptools>=80",
|
"setuptools>=80",
|
||||||
"nanobind==2.10.2",
|
"nanobind==2.4.0",
|
||||||
"cmake>=3.25",
|
"cmake>=3.25",
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|||||||
85
python/mlx/_distributed_utils/common.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Host:
|
||||||
|
rank: int
|
||||||
|
ssh_hostname: str
|
||||||
|
ips: list[str]
|
||||||
|
rdma: list[Optional[str]]
|
||||||
|
|
||||||
|
|
||||||
|
def positive_number(x):
|
||||||
|
x = int(x)
|
||||||
|
if x <= 0:
|
||||||
|
raise ValueError("Number should be positive")
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def log(verbose, *args, **kwargs):
|
||||||
|
if not verbose:
|
||||||
|
return
|
||||||
|
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def log_warning(*args, **kwargs):
|
||||||
|
kwargs["file"] = sys.stderr
|
||||||
|
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def log_error(*args, **kwargs):
|
||||||
|
kwargs["file"] = sys.stderr
|
||||||
|
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hostlist(parser, hostlist, repeats):
|
||||||
|
hosts = []
|
||||||
|
for i, h in enumerate(hostlist.split(",")):
|
||||||
|
if h == "":
|
||||||
|
raise ValueError("Hostname cannot be empty")
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(h)
|
||||||
|
ips = [h]
|
||||||
|
except ValueError:
|
||||||
|
ips = []
|
||||||
|
for i in range(repeats):
|
||||||
|
hosts.append(Host(i, h, ips))
|
||||||
|
return hosts
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hostfile(parser, hostfile):
|
||||||
|
"""Parse the json hostfile that contains both the hostnames to ssh into and
|
||||||
|
the ips to communicate over when using the ring backend.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
[
|
||||||
|
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
|
||||||
|
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
|
||||||
|
...
|
||||||
|
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
|
||||||
|
]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hostfile (str): The path to the json file containing the host
|
||||||
|
information
|
||||||
|
"""
|
||||||
|
hostfile = Path(hostfile)
|
||||||
|
if not hostfile.exists():
|
||||||
|
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
|
||||||
|
|
||||||
|
try:
|
||||||
|
hosts = []
|
||||||
|
with open(hostfile) as f:
|
||||||
|
for i, h in enumerate(json.load(f)):
|
||||||
|
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
|
||||||
|
return hosts
|
||||||
|
except Exception as e:
|
||||||
|
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
|
||||||
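The parse_hostfile docstring above describes the JSON hostfile consumed by the launcher. Below is a hedged sketch of generating such a file; the hostnames, IPs, and RDMA interface names are placeholders.

# Sketch only: write a hostfile in the layout documented by parse_hostfile.
import json

hosts = [
    {"ssh": "node0", "ips": ["10.0.0.1"], "rdma": [None, "rdma_en2"]},
    {"ssh": "node1", "ips": ["10.0.0.2"], "rdma": ["rdma_en2", None]},
]

with open("hosts.json", "w") as f:
    json.dump(hosts, f, indent=2)

# Then, for example: mlx.launch --hostfile hosts.json --backend ring -- train.py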
0
python/mlx/_distributed_utils/config.py
Normal file
@@ -832,7 +832,7 @@ def main():
|
|||||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi", "nccl", "jaccl"],
|
||||||
default="nccl" if mx.cuda.is_available() else "ring",
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
help="Which distributed backend to launch",
|
help="Which distributed backend to launch",
|
||||||
)
|
)
|
||||||
@@ -903,6 +903,8 @@ def main():
|
|||||||
launch_mpi(parser, hosts, args, rest)
|
launch_mpi(parser, hosts, args, rest)
|
||||||
if args.backend == "nccl":
|
if args.backend == "nccl":
|
||||||
launch_nccl(parser, hosts, args, rest)
|
launch_nccl(parser, hosts, args, rest)
|
||||||
|
if args.backend == "jaccl":
|
||||||
|
launch_jaccl(parser, hosts, args, rest)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
540
python/mlx/_distributed_utils/launch.py
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import shlex
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from collections import Counter
|
||||||
|
from itertools import chain
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty as QueueEmpty
|
||||||
|
from queue import Queue
|
||||||
|
from select import select
|
||||||
|
from subprocess import PIPE, Popen, run
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
|
||||||
|
|
||||||
|
|
||||||
|
class CommandProcess:
|
||||||
|
@property
|
||||||
|
def process(self):
|
||||||
|
"""Return the Popen object that refers to the current command."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exit_status(self):
|
||||||
|
"""Return a tuple (returncode, killed) for the command. It should be
|
||||||
|
(None, None) while the command is running normally."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def preprocess_output(self, data: str, is_stdout=False):
|
||||||
|
"""Preprocess the output of the command so that extra data can be
|
||||||
|
captured or the format changed on the fly."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def terminate(self):
|
||||||
|
"""Terminate or return the exit code."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteProcess(CommandProcess):
|
||||||
|
def __init__(self, rank, host, cwd, files, env, command):
|
||||||
|
is_local = host == "127.0.0.1"
|
||||||
|
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
|
||||||
|
script_b64 = base64.b64encode(script.encode()).decode()
|
||||||
|
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
||||||
|
if not is_local:
|
||||||
|
cmd = f"ssh {host} '{cmd}'"
|
||||||
|
|
||||||
|
self._host = host
|
||||||
|
self._pidfile = None
|
||||||
|
self._is_local = is_local
|
||||||
|
self._process = Popen(
|
||||||
|
cmd,
|
||||||
|
shell=True,
|
||||||
|
stdin=PIPE,
|
||||||
|
stdout=PIPE,
|
||||||
|
stderr=PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._killed = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def process(self):
|
||||||
|
return self._process
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exit_status(self):
|
||||||
|
return self._process.poll(), self._killed
|
||||||
|
|
||||||
|
def preprocess_output(self, data, is_stdout=False):
|
||||||
|
if self._pidfile is None:
|
||||||
|
pidfile, *rest = data.split("\n", maxsplit=1)
|
||||||
|
self._pidfile = pidfile
|
||||||
|
return rest[0] if rest else ""
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def terminate(self):
|
||||||
|
if self._killed:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.wait()
|
||||||
|
|
||||||
|
# Kill the remote program if possible
|
||||||
|
cmd = ""
|
||||||
|
cmd += f"pid=$(cat {self._pidfile}); "
|
||||||
|
cmd += "if ps -p $pid >/dev/null; then "
|
||||||
|
cmd += " kill $pid; "
|
||||||
|
cmd += " echo 1; "
|
||||||
|
cmd += "else "
|
||||||
|
cmd += " echo 0; "
|
||||||
|
cmd += "fi; "
|
||||||
|
cmd += f"rm {self._pidfile}"
|
||||||
|
if not self._is_local:
|
||||||
|
cmd = f"ssh {self._host} '{cmd}'"
|
||||||
|
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
||||||
|
|
||||||
|
self._killed = c.stdout.strip() == "1"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_monitor_script(rank, cwd, files, env, command):
|
||||||
|
# Imports that are used throughout
|
||||||
|
script = ""
|
||||||
|
script += "import os\n"
|
||||||
|
script += "import sys\n"
|
||||||
|
script += "import tempfile\n"
|
||||||
|
script += "from pathlib import Path\n"
|
||||||
|
|
||||||
|
# Write the PID to a file so we can kill the process if needed
|
||||||
|
script += "_, pidfile = tempfile.mkstemp() \n"
|
||||||
|
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
|
||||||
|
script += "print(pidfile, flush=True)\n"
|
||||||
|
|
||||||
|
# Change the working directory if one was requested. Otherwise attempt to
|
||||||
|
# change to the current one but don't fail if it wasn't possible.
|
||||||
|
d = cwd or os.getcwd()
|
||||||
|
script += f"if Path({repr(d)}).exists():\n"
|
||||||
|
script += f" os.chdir({repr(d)})\n"
|
||||||
|
if cwd is not None:
|
||||||
|
script += "else:\n"
|
||||||
|
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
|
||||||
|
script += f" sys.exit(1)\n"
|
||||||
|
|
||||||
|
# Add the environment variables that were requested
|
||||||
|
script += "env = dict(os.environ)\n"
|
||||||
|
for e in env:
|
||||||
|
key, *value = e.split("=", maxsplit=1)
|
||||||
|
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
||||||
|
if not all(c.isalnum() or c == "_" for c in key):
|
||||||
|
log_warning(
|
||||||
|
f"'{e}' is an invalid environment variable so it is ignored"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
script += f"env[{repr(key)}] = {repr(value)}\n"
|
||||||
|
|
||||||
|
# Make the temporary files
|
||||||
|
for env_name, content in files.items():
|
||||||
|
script += "_, fname = tempfile.mkstemp()\n"
|
||||||
|
script += "with open(fname, 'w') as f:\n"
|
||||||
|
script += f" f.write({repr(content)})\n"
|
||||||
|
script += f"env[{repr(env_name)}] = fname\n"
|
||||||
|
|
||||||
|
# Finally add the rank
|
||||||
|
script += f"env['MLX_RANK'] = '{rank}'\n"
|
||||||
|
script += "\n"
|
||||||
|
|
||||||
|
# Replace the process with the script
|
||||||
|
script += f"command = [{','.join(map(repr, command))}]\n"
|
||||||
|
script += "os.execve(command[0], command, env)\n"
|
||||||
|
|
||||||
|
return script
|
||||||
|
|
||||||
|
|
||||||
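# Hypothetical usage sketch (editorial, not part of the change): run one local
# rank through RemoteProcess and collect its exit status.
#
#   proc = RemoteProcess(0, "127.0.0.1", None, {}, [], ["/bin/echo", "hello"])
#   proc.process.wait()
#   returncode, killed = proc.exit_status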
|
def _launch_with_io(command_class, arguments, verbose):
|
||||||
|
stop = False
|
||||||
|
exit_codes = [(None, None)] * len(arguments)
|
||||||
|
|
||||||
|
def _thread_fn(rank, *args, **kwargs):
|
||||||
|
stdin_queue = kwargs.pop("stdin_queue")
|
||||||
|
stdout_queue = kwargs.pop("stdout_queue")
|
||||||
|
stderr_queue = kwargs.pop("stderr_queue")
|
||||||
|
|
||||||
|
command = command_class(rank, *args, **kwargs)
|
||||||
|
p = command.process
|
||||||
|
os.set_blocking(p.stdout.fileno(), False)
|
||||||
|
os.set_blocking(p.stderr.fileno(), False)
|
||||||
|
os.set_blocking(p.stdin.fileno(), False)
|
||||||
|
|
||||||
|
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
||||||
|
to_write = [p.stdin.fileno()]
|
||||||
|
|
||||||
|
stdin_buffer = b""
|
||||||
|
while p.poll() is None:
|
||||||
|
try:
|
||||||
|
stdin_buffer += stdin_queue.get_nowait()
|
||||||
|
except QueueEmpty:
|
||||||
|
pass
|
||||||
|
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
||||||
|
for fd in rlist:
|
||||||
|
is_stdout = fd == p.stdout.fileno()
|
||||||
|
msg = os.read(fd, 8192).decode(errors="ignore")
|
||||||
|
msg = command.preprocess_output(msg, is_stdout)
|
||||||
|
if is_stdout:
|
||||||
|
stdout_queue.put(msg.encode())
|
||||||
|
else:
|
||||||
|
stderr_queue.put(msg.encode())
|
||||||
|
for fd in wlist:
|
||||||
|
if len(stdin_buffer) > 0:
|
||||||
|
n = os.write(fd, stdin_buffer)
|
||||||
|
stdin_buffer = stdin_buffer[n:]
|
||||||
|
if stop:
|
||||||
|
command.terminate()
|
||||||
|
break
|
||||||
|
exit_codes[rank] = command.exit_status
|
||||||
|
|
||||||
|
if exit_codes[rank][1]:
|
||||||
|
log_warning(f"Node with rank {rank} was killed")
|
||||||
|
elif exit_codes[rank][0] != 0:
|
||||||
|
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
|
||||||
|
else:
|
||||||
|
log(verbose, f"Node with rank {rank} completed")
|
||||||
|
|
||||||
|
stdin_queues = []
|
||||||
|
stdout_queues = []
|
||||||
|
stderr_queues = []
|
||||||
|
threads = []
|
||||||
|
for i, (args, kwargs) in enumerate(arguments):
|
||||||
|
stdin_queues.append(Queue())
|
||||||
|
stdout_queues.append(Queue())
|
||||||
|
stderr_queues.append(Queue())
|
||||||
|
t = threading.Thread(
|
||||||
|
target=_thread_fn,
|
||||||
|
args=args,
|
||||||
|
kwargs=kwargs
|
||||||
|
| {
|
||||||
|
"stdin_queue": stdin_queues[-1],
|
||||||
|
"stdout_queue": stdout_queues[-1],
|
||||||
|
"stderr_queue": stderr_queues[-1],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
threads.append(t)
|
||||||
|
|
||||||
|
os.set_blocking(sys.stdin.fileno(), False)
|
||||||
|
os.set_blocking(sys.stdout.fileno(), True)
|
||||||
|
os.set_blocking(sys.stderr.fileno(), True)
|
||||||
|
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
|
||||||
|
# Broadcast user input to the jobs
|
||||||
|
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
|
||||||
|
for fd in rlist:
|
||||||
|
stdin_buffer = os.read(fd, 8192)
|
||||||
|
for q in stdin_queues:
|
||||||
|
q.put(stdin_buffer)
|
||||||
|
|
||||||
|
# Gather job output
|
||||||
|
for q in stdout_queues:
|
||||||
|
try:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stdout.buffer.write(q.get_nowait())
|
||||||
|
except QueueEmpty:
|
||||||
|
pass
|
||||||
|
for q in stderr_queues:
|
||||||
|
try:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stderr.buffer.write(q.get_nowait())
|
||||||
|
except QueueEmpty:
|
||||||
|
pass
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
sys.stderr.buffer.flush()
|
||||||
|
|
||||||
|
# Check if all are running and terminate otherwise
|
||||||
|
if any(t.is_alive() for t in threads):
|
||||||
|
for i, t in enumerate(threads):
|
||||||
|
if not t.is_alive():
|
||||||
|
if exit_codes[i][0] != 0:
|
||||||
|
stop = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Wait for the jobs to finish
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Process any remaining outputs
|
||||||
|
for q in stdout_queues:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stdout.buffer.write(q.get())
|
||||||
|
for q in stderr_queues:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stderr.buffer.write(q.get())
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
sys.stderr.buffer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def launch_ring(parser, hosts, args, command):
|
||||||
|
if any(len(h.ips) == 0 for h in hosts):
|
||||||
|
parser.error(
|
||||||
|
"The ring backend requires IPs to be provided instead of hostnames"
|
||||||
|
)
|
||||||
|
|
||||||
|
port = args.starting_port
|
||||||
|
ring_hosts = []
|
||||||
|
for h in hosts:
|
||||||
|
node = []
|
||||||
|
for ip in h.ips:
|
||||||
|
for i in range(args.connections_per_ip):
|
||||||
|
node.append(f"{ip}:{port}")
|
||||||
|
port += 1
|
||||||
|
ring_hosts.append(node)
|
||||||
|
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
||||||
|
|
||||||
|
files = {"MLX_HOSTFILE": hostfile}
|
||||||
|
env = args.env
|
||||||
|
if args.verbose:
|
||||||
|
env.append("MLX_RING_VERBOSE=1")
|
||||||
|
cwd = args.cwd
|
||||||
|
|
||||||
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
|
_launch_with_io(
|
||||||
|
RemoteProcess,
|
||||||
|
[
|
||||||
|
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
||||||
|
for rank, h in enumerate(hosts)
|
||||||
|
],
|
||||||
|
args.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_nccl(parser, hosts, args, command):
|
||||||
|
if not hosts[0].ips:
|
||||||
|
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
|
||||||
|
|
||||||
|
master_host = hosts[0].ips[0]
|
||||||
|
master_port = args.nccl_port
|
||||||
|
world_size = len(hosts)
|
||||||
|
|
||||||
|
env = args.env
|
||||||
|
cwd = args.cwd
|
||||||
|
if args.verbose:
|
||||||
|
env.append("NCCL_DEBUG=INFO")
|
||||||
|
env.append(f"NCCL_HOST_IP={master_host}")
|
||||||
|
env.append(f"NCCL_PORT={master_port}")
|
||||||
|
env.append(f"MLX_WORLD_SIZE={world_size}")
|
||||||
|
|
||||||
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
|
_launch_with_io(
|
||||||
|
RemoteProcess,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
(
|
||||||
|
rank,
|
||||||
|
h.ssh_hostname,
|
||||||
|
cwd,
|
||||||
|
{},
|
||||||
|
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
|
||||||
|
command,
|
||||||
|
),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
for rank, h in enumerate(hosts)
|
||||||
|
],
|
||||||
|
args.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_jaccl(parser, hosts, args, command):
|
||||||
|
if not hosts[0].ips:
|
||||||
|
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
|
||||||
|
|
||||||
|
have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
|
||||||
|
have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
|
||||||
|
if not have_rdmas or not have_nulls:
|
||||||
|
raise ValueError("Malformed hostfile for jaccl backend")
|
||||||
|
|
||||||
|
coordinator = hosts[0].ips[0]
|
||||||
|
env = args.env
|
||||||
|
cwd = args.cwd
|
||||||
|
env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
|
||||||
|
files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}
|
||||||
|
|
||||||
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
|
_launch_with_io(
|
||||||
|
RemoteProcess,
|
||||||
|
[
|
||||||
|
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
||||||
|
for rank, h in enumerate(hosts)
|
||||||
|
],
|
||||||
|
args.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mpi_libname():
|
||||||
|
try:
|
||||||
|
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
|
||||||
|
ompi_info = ompi_info.stdout.strip().decode()
|
||||||
|
|
||||||
|
if platform.system() == "Darwin":
|
||||||
|
otool_output = run(
|
||||||
|
["otool", "-L", ompi_info], check=True, capture_output=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
|
||||||
|
otool_output = otool_output.stdout.decode()
|
||||||
|
|
||||||
|
# StopIteration if not found
|
||||||
|
libmpi_line = next(
|
||||||
|
filter(lambda line: "libmpi" in line, otool_output.splitlines())
|
||||||
|
)
|
||||||
|
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def launch_mpi(parser, hosts, args, command):
|
||||||
|
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
|
||||||
|
mpirun = mpirun.stdout.strip().decode()
|
||||||
|
|
||||||
|
# Compatibility with homebrew and pip installs
|
||||||
|
mpi_libname = get_mpi_libname()
|
||||||
|
if mpi_libname is not None:
|
||||||
|
dyld = Path(mpirun).parent.parent / "lib"
|
||||||
|
args.env = [
|
||||||
|
f"DYLD_LIBRARY_PATH={str(dyld)}",
|
||||||
|
f"MLX_MPI_LIBNAME={mpi_libname}",
|
||||||
|
] + args.env
|
||||||
|
|
||||||
|
log(args.verbose, f"Using '{mpirun}'")
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||||
|
hosts = Counter((h.ssh_hostname for h in hosts))
|
||||||
|
for h, n in hosts.items():
|
||||||
|
print(f"{h} slots={n}", file=f)
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
mpirun,
|
||||||
|
"--output",
|
||||||
|
":raw", # do not line buffer output
|
||||||
|
"--hostfile",
|
||||||
|
f.name,
|
||||||
|
*(["-cwd", args.cwd] if args.cwd else []),
|
||||||
|
*sum((["-x", e] for e in args.env), []),
|
||||||
|
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
||||||
|
"--",
|
||||||
|
*command,
|
||||||
|
]
|
||||||
|
log(args.verbose, "Running", " ".join(cmd))
|
||||||
|
try:
|
||||||
|
run(cmd)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
||||||
|
parser.add_argument(
|
||||||
|
"--print-python",
|
||||||
|
action="store_true",
|
||||||
|
help="Print the path to the current python executable and exit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat-hosts",
|
||||||
|
"-n",
|
||||||
|
type=positive_number,
|
||||||
|
default=1,
|
||||||
|
help="Repeat each host a given number of times",
|
||||||
|
)
|
||||||
|
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["ring", "mpi", "nccl", "jaccl"],
|
||||||
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
|
help="Which distributed backend to launch",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--env",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Set environment variables for the jobs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mpi-arg",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Arguments to pass directly to mpirun",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--connections-per-ip",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="How many connections per ip to use for the ring backend",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--starting-port",
|
||||||
|
"-p",
|
||||||
|
type=int,
|
||||||
|
default=32323,
|
||||||
|
help="For the ring backend listen on this port increasing by 1 per rank and IP",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cwd", help="Set the working directory on each node to the provided one"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nccl-port",
|
||||||
|
type=int,
|
||||||
|
default=12345,
|
||||||
|
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args, rest = parser.parse_known_args()
|
||||||
|
|
||||||
|
if args.print_python:
|
||||||
|
print(sys.executable)
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(rest) == 0:
|
||||||
|
parser.error("No script is provided")
|
||||||
|
if rest[0] == "--":
|
||||||
|
rest.pop(0)
|
||||||
|
|
||||||
|
# Try to extract a list of hosts and corresponding ips
|
||||||
|
if args.hostfile is not None:
|
||||||
|
hosts = parse_hostfile(parser, args.hostfile)
|
||||||
|
else:
|
||||||
|
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
||||||
|
|
||||||
|
# Check if the script is a file and convert it to a full path
|
||||||
|
if (script := Path(rest[0])).exists() and script.is_file():
|
||||||
|
rest[0:1] = [sys.executable, str(script.resolve())]
|
||||||
|
elif (command := shutil.which(rest[0])) is not None:
|
||||||
|
rest[0] = command
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||||
|
|
||||||
|
# Launch
|
||||||
|
if args.backend == "ring":
|
||||||
|
launch_ring(parser, hosts, args, rest)
|
||||||
|
if args.backend == "mpi":
|
||||||
|
launch_mpi(parser, hosts, args, rest)
|
||||||
|
if args.backend == "nccl":
|
||||||
|
launch_nccl(parser, hosts, args, rest)
|
||||||
|
if args.backend == "jaccl":
|
||||||
|
launch_jaccl(parser, hosts, args, rest)
|
||||||
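For completeness, a hedged example of driving this launcher once it is installed as the mlx.launch console script (see the setup.py change further down). The hostfile name and training script are placeholders.

# Sketch only: invoke the launcher from another process.
import subprocess

subprocess.run(
    [
        "mlx.launch",
        "--backend", "ring",
        "--hostfile", "hosts.json",
        "--verbose",
        "--",
        "train.py",
    ],
    check=True,
)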
@@ -52,9 +52,25 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"is_available",
|
"is_available",
|
||||||
&mx::distributed::is_available,
|
[](const std::string& backend) {
|
||||||
|
return mx::distributed::is_available(backend);
|
||||||
|
},
|
||||||
|
"backend"_a = "any",
|
||||||
|
nb::sig("def is_available(backend: str = 'any') -> bool"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Check if a communication backend is available.
|
Check if a communication backend is available.
|
||||||
|
|
||||||
|
Note, this function returns whether MLX has the capability of
|
||||||
|
instantiating that distributed backend not whether it is possible to
|
||||||
|
create a communication group. For that purpose one should use
|
||||||
|
``init(strict=True)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend (str, optional): The name of the backend to check for availability.
|
||||||
|
It takes the same values as ``init()``. Default: ``any``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether the distributed backend is available.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
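A hedged Python usage sketch for the binding reworked above and the init docstring that follows: probe a specific backend, then initialize it explicitly. This assumes the process was started under a launcher so a group can actually form.

# Sketch only: checking for and initializing a specific distributed backend.
import mlx.core as mx

if mx.distributed.is_available("ring"):
    group = mx.distributed.init(backend="ring", strict=True)
    print(group.rank(), group.size())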
@@ -79,10 +95,10 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||||
it throws a runtime error. Default: ``False``
|
it throws a runtime error. Default: ``False``
|
||||||
backend (str, optional): Which distributed backend to initialize.
|
backend (str, optional): Which distributed backend to initialize.
|
||||||
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
|
Possible values ``mpi``, ``ring``, ``nccl``, ``jaccl``, ``any``. If
|
||||||
available backends are tried and the first one that succeeds
|
set to ``any`` all available backends are tried and the first one
|
||||||
becomes the global group which will be returned in subsequent
|
that succeeds becomes the global group which will be returned in
|
||||||
calls. Default: ``any``
|
subsequent calls. Default: ``any``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Group: The group representing all the launched processes.
|
Group: The group representing all the launched processes.
|
||||||
|
|||||||
@@ -89,8 +89,7 @@ static PyType_Spec gc_func_spec = {
|
|||||||
/* .name = */ "mlx.gc_func",
|
/* .name = */ "mlx.gc_func",
|
||||||
/* .basicsize = */ (int)sizeof(gc_func),
|
/* .basicsize = */ (int)sizeof(gc_func),
|
||||||
/* .itemsize = */ 0,
|
/* .itemsize = */ 0,
|
||||||
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
|
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
|
||||||
Py_TPFLAGS_HAVE_VECTORCALL,
|
|
||||||
/* .slots = */ gc_func_slots};
|
/* .slots = */ gc_func_slots};
|
||||||
|
|
||||||
static PyTypeObject* gc_func_tp = nullptr;
|
static PyTypeObject* gc_func_tp = nullptr;
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
|
|||||||
|
|
||||||
NB_TYPE_CASTER(
|
NB_TYPE_CASTER(
|
||||||
List,
|
List,
|
||||||
const_name("tuple[") + make_caster<Type>::Name + const_name(", ...]"))
|
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
|
||||||
|
const_name(", ...]"))
|
||||||
|
|
||||||
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
|
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
|
||||||
size_t size;
|
size_t size;
|
||||||
|
|||||||
@@ -124,53 +124,37 @@ auto py_value_and_grad(
|
|||||||
|
|
||||||
// Collect the arrays
|
// Collect the arrays
|
||||||
std::vector<mx::array> arrays;
|
std::vector<mx::array> arrays;
|
||||||
std::vector<nb::object> array_objects;
|
|
||||||
auto flatten_with_objects = [&arrays, &array_objects](
|
|
||||||
auto tree, bool strict) {
|
|
||||||
tree_visit(tree, [&](nb::handle obj) {
|
|
||||||
if (nb::isinstance<mx::array>(obj)) {
|
|
||||||
arrays.push_back(nb::cast<mx::array>(obj));
|
|
||||||
array_objects.push_back(nb::borrow<nb::object>(obj));
|
|
||||||
} else if (strict) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[tree_flatten] The argument should contain only arrays");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<int> counts(1, 0);
|
std::vector<int> counts(1, 0);
|
||||||
std::vector<int> gradient_indices;
|
std::vector<int> gradient_indices;
|
||||||
for (int i = 0, j = 0; i < args.size(); ++i) {
|
for (int i = 0, j = 0; i < args.size(); ++i) {
|
||||||
bool needs_grad = (j < argnums.size() && argnums[j] == i);
|
bool needs_grad = (j < argnums.size() && argnums[j] == i);
|
||||||
auto pre_size = arrays.size();
|
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
|
||||||
flatten_with_objects(args[i], /* strict = */ needs_grad);
|
|
||||||
if (needs_grad) {
|
if (needs_grad) {
|
||||||
auto old_size = gradient_indices.size();
|
auto old_size = gradient_indices.size();
|
||||||
auto delta_size = arrays.size() - pre_size;
|
gradient_indices.resize(old_size + argsi.size());
|
||||||
gradient_indices.resize(old_size + delta_size);
|
|
||||||
std::iota(
|
std::iota(
|
||||||
gradient_indices.begin() + old_size,
|
gradient_indices.begin() + old_size,
|
||||||
gradient_indices.end(),
|
gradient_indices.end(),
|
||||||
pre_size);
|
arrays.size());
|
||||||
j++;
|
j++;
|
||||||
counts.push_back(delta_size);
|
counts.push_back(argsi.size());
|
||||||
}
|
}
|
||||||
|
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
|
||||||
}
|
}
|
||||||
for (auto item : kwargs) {
|
for (auto item : kwargs) {
|
||||||
bool needs_grad =
|
bool needs_grad =
|
||||||
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
|
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
|
||||||
auto pre_size = arrays.size();
|
auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
|
||||||
flatten_with_objects(item.second, /* strict = */ needs_grad);
|
|
||||||
if (needs_grad) {
|
if (needs_grad) {
|
||||||
auto old_size = gradient_indices.size();
|
auto old_size = gradient_indices.size();
|
||||||
auto delta_size = arrays.size() - pre_size;
|
gradient_indices.resize(old_size + argsk.size());
|
||||||
gradient_indices.resize(old_size + delta_size);
|
|
||||||
std::iota(
|
std::iota(
|
||||||
gradient_indices.begin() + old_size,
|
gradient_indices.begin() + old_size,
|
||||||
gradient_indices.end(),
|
gradient_indices.end(),
|
||||||
pre_size);
|
arrays.size());
|
||||||
counts.push_back(delta_size);
|
counts.push_back(argsk.size());
|
||||||
}
|
}
|
||||||
|
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
|
||||||
}
|
}
|
||||||
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
|
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
|
||||||
|
|
||||||
@@ -179,7 +163,7 @@ auto py_value_and_grad(
|
|||||||
nb::object py_value_out;
|
nb::object py_value_out;
|
||||||
auto value_and_grads = mx::value_and_grad(
|
auto value_and_grads = mx::value_and_grad(
|
||||||
[&fun,
|
[&fun,
|
||||||
&array_objects,
|
&arrays,
|
||||||
&args,
|
&args,
|
||||||
&kwargs,
|
&kwargs,
|
||||||
&py_value_out,
|
&py_value_out,
|
||||||
@@ -199,9 +183,8 @@ auto py_value_and_grad(
|
|||||||
tree_visit_update(tree, [&](nb::handle node) {
|
tree_visit_update(tree, [&](nb::handle node) {
|
||||||
auto replace_arr = nb::cast<mx::array>(node);
|
auto replace_arr = nb::cast<mx::array>(node);
|
||||||
if (replace_arr.id() == a[index].id()) {
|
if (replace_arr.id() == a[index].id()) {
|
||||||
return array_objects[index++];
|
return nb::cast(arrays[index++]);
|
||||||
} else {
|
} else {
|
||||||
index++;
|
|
||||||
return nb::cast(replace_arr);
|
return nb::cast(replace_arr);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
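The gradient bookkeeping above is what backs mx.value_and_grad on the Python side; here is a minimal hedged sketch of the behaviour it implements.

# Sketch only: value_and_grad flattens the arguments selected by argnums /
# argnames and returns gradients with the corresponding structure.
import mlx.core as mx

def loss(w, x, scale=1.0):
    return scale * (w * x).sum()

value_and_grad_fn = mx.value_and_grad(loss, argnums=(0,))
w = mx.ones((3,))
x = mx.arange(3, dtype=mx.float32)
value, grads = value_and_grad_fn(w, x, scale=2.0)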
|
|||||||
@@ -780,21 +780,9 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
return arrs[0]
|
return arrs[0]
|
||||||
|
|
||||||
arrs = [mx.array(1.0)]
|
arrs = [mx.array(1.0)]
|
||||||
arr = arrs[0]
|
init_id = id(arrs[0])
|
||||||
mx.grad(fun)(arrs)
|
mx.grad(fun)(arrs)
|
||||||
self.assertEqual(id(arr), id(arrs[0]))
|
self.assertEqual(init_id, id(arrs[0]))
|
||||||
|
|
||||||
def fun(arrs):
|
|
||||||
arrs[1] = sum(arrs)
|
|
||||||
return arrs[1]
|
|
||||||
|
|
||||||
arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
|
|
||||||
a_0, a_1, a_2 = arrs
|
|
||||||
|
|
||||||
mx.grad(fun)(arrs)
|
|
||||||
self.assertEqual(id(a_0), id(arrs[0]))
|
|
||||||
self.assertNotEqual(id(a_1), id(arrs[1]))
|
|
||||||
self.assertEqual(id(a_2), id(arrs[2]))
|
|
||||||
|
|
||||||
def test_grad_with_inplace_update(self):
|
def test_grad_with_inplace_update(self):
|
||||||
def loss_fn(model):
|
def loss_fn(model):
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import gc
|
|||||||
import inspect
|
import inspect
|
||||||
import io
|
import io
|
||||||
import math
|
import math
|
||||||
|
import unittest
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompile(mlx_tests.MLXTestCase):
|
class TestCompile(mlx_tests.MLXTestCase):
|
||||||
@@ -1252,26 +1252,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
loss, grads = step(emb, w, x)
|
loss, grads = step(emb, w, x)
|
||||||
mx.eval(loss, grads)
|
mx.eval(loss, grads)
|
||||||
|
|
||||||
def test_compile_donates_input_buffer(self):
|
|
||||||
mx.set_default_device(mx.cpu)
|
|
||||||
|
|
||||||
def fun(x):
|
|
||||||
return mx.sin(x) + 1
|
|
||||||
|
|
||||||
compiled_fn = mx.compile(fun)
|
|
||||||
|
|
||||||
input = mx.arange(16, dtype=mx.float32)
|
|
||||||
mx.eval(input)
|
|
||||||
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
|
|
||||||
|
|
||||||
out = compiled_fn(input)
|
|
||||||
del input # Ensure the reference is dropped
|
|
||||||
mx.eval(out)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
@@ -744,6 +744,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
return Vector([t[0] + 10, t[1] * 10])
|
return Vector([t[0] + 10, t[1] * 10])
|
||||||
|
|
||||||
x = State(mx.array(1), mx.array(2))
|
x = State(mx.array(1), mx.array(2))
|
||||||
|
print(f"{transform(x)=}")
|
||||||
|
|
||||||
vmap_transform = mx.vmap(transform)
|
vmap_transform = mx.vmap(transform)
|
||||||
vmap_transform_tuple = mx.vmap(transform_tuple)
|
vmap_transform_tuple = mx.vmap(transform_tuple)
|
||||||
|
|||||||
6
setup.py
@@ -255,7 +255,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
"dev": [
|
"dev": [
|
||||||
"nanobind==2.10.2",
|
"nanobind==2.4.0",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pre-commit",
|
"pre-commit",
|
||||||
"setuptools>=80",
|
"setuptools>=80",
|
||||||
@@ -265,8 +265,8 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
entry_points = {
|
entry_points = {
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
"mlx.launch = mlx.distributed_run:main",
|
"mlx.launch = mlx._distributed_utils.launch:main",
|
||||||
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
install_requires = []
|
install_requires = []
|
||||||
|
|||||||