Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: sdpav-back...984cefb14d (10 commits)
Commits:
- 984cefb14d
- dadf8d9c93
- 389276e2b8
- 2e255c8eb4
- 062aa80b84
- f540b1d612
- 56be773610
- a9bdd67baa
- f2adb5638d
- 728d4db582
cmake/FindNCCL.cmake (new file, 54 lines)
@@ -0,0 +1,54 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.

set(NCCL_ROOT_DIR
    $ENV{NCCL_ROOT_DIR}
    CACHE PATH "Folder contains NVIDIA NCCL")

find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)

if($ENV{USE_STATIC_NCCL})
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  file(
    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()
@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
     dpkg -i cuda-keyring_1.1-1_all.deb
     apt-get update -y
     apt-get -y install cuda-toolkit-12-9
-    apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+    apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
 
 
 When building either the Python or C++ APIs make sure to pass the cmake flag
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
     optimizer.update(model, grads)
 
     # Save the state
-    state = tree_flatten(optimizer.state)
-    mx.save_safetensors("optimizer.safetensors", dict(state))
+    state = tree_flatten(optimizer.state, destination={})
+    mx.save_safetensors("optimizer.safetensors", state)
 
     # Later on, for example when loading from a checkpoint,
     # recreate the optimizer and load the state
     optimizer = optim.Adam(learning_rate=1e-2)
 
-    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+    state = tree_unflatten(mx.load("optimizer.safetensors"))
     optimizer.state = state
 
 Note, not every optimizer configuration parameter is saved in the state. For
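For reference, a self-contained sketch of the round trip after this change. The model, loss, and optimizer choices here are illustrative; the ``destination={}`` keyword is the new API shown in the hunk above:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   import mlx.optimizers as optim
   from mlx.utils import tree_flatten, tree_unflatten

   model = nn.Linear(4, 4)
   optimizer = optim.Adam(learning_rate=1e-2)

   def loss_fn(model, x):
       return model(x).sum()

   # Run one step so the optimizer state is populated
   _, grads = nn.value_and_grad(model, loss_fn)(model, mx.zeros((1, 4)))
   optimizer.update(model, grads)

   # destination={} makes tree_flatten return a flat dict instead of a
   # list of (key, value) pairs, so it can be saved directly
   state = tree_flatten(optimizer.state, destination={})
   mx.save_safetensors("optimizer.safetensors", state)

   # mx.load returns a flat dict, which tree_unflatten now accepts
   optimizer = optim.Adam(learning_rate=1e-2)
   optimizer.state = tree_unflatten(mx.load("optimizer.safetensors"))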
@@ -7,17 +7,17 @@ Exporting Functions
 MLX has an API to export and import functions to and from a file. This lets you
 run computations written in one MLX front-end (e.g. Python) in another MLX
 front-end (e.g. C++).
 
 This guide walks through the basics of the MLX export API with some examples.
 To see the full list of functions check-out the :ref:`API documentation
 <export>`.
 
 Basics of Exporting
 -------------------
 
 Let's start with a simple example:
 
 .. code-block:: python
 
    def fun(x, y):
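The hunk cuts off at the start of the example. As a hedged sketch of the full round trip (assuming the ``mx.export_function`` / ``mx.import_function`` pair used later in this guide), it presumably continues along these lines:

.. code-block:: python

   import mlx.core as mx

   def fun(x, y):
       return x + y

   x, y = mx.array(1.0), mx.array(1.0)

   # Trace fun on example inputs and write the graph to disk
   mx.export_function("fun.mlxfn", fun, x, y)

   # Import it back; imported functions return a tuple of arrays
   imported_fun = mx.import_function("fun.mlxfn")
   (out,) = imported_fun(x, y)
   print(out)  # array(2, dtype=float32)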
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
    x = mx.array(1.0)
    y = mx.array(1.0)
 
    # Both arguments to fun are positional
    mx.export_function("add.mlxfn", fun, x, y)
 
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
 For enclosed arrays inside an exported function, be extra careful to ensure
 they are evaluated. The computation graph that gets exported will include
 the computation that produces enclosed inputs.
 
 If the above example was missing ``mx.eval(model.parameters())``, the
 exported function would include the random initialization of the
 :obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
         # Set the model's parameters to the input parameters
         model.update(tree_unflatten(list(params.items())))
         return model(x)
 
-    params = dict(tree_flatten(model.parameters()))
+    params = tree_flatten(model.parameters(), destination={})
     mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
 
 
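A hedged sketch of using the exported model (the keyword-argument call mirrors how ``params`` was passed to ``mx.export_function`` above; the exact import-side pattern is an assumption here):

.. code-block:: python

   imported_call = mx.import_function("model.mlxfn")

   # Positional inputs as before, parameter arrays by keyword
   (out,) = imported_call(mx.zeros(4), **params)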
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
 
    # Ok
    out, = imported_abs(mx.array(-1.0))
 
    # Also ok
    out, = imported_abs(mx.array([-1.0, -2.0]))
 
 With ``shapeless=False`` (which is the default), the second call to
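For reference, a minimal sketch of producing such a shape-agnostic export; the ``shapeless`` flag is named in the surrounding text, while the example values are illustrative:

.. code-block:: python

   import mlx.core as mx

   # shapeless=True records a graph that can be run on inputs whose
   # shapes differ from the example input used at export time
   mx.export_function("abs.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
   imported_abs = mx.import_function("abs.mlxfn")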
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
    def fun(x, y=None):
        constant = mx.array(3.0)
        if y is not None:
            x += y
        return x + constant
 
    with mx.exporter("fun.mlxfn", fun) as exporter:
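The calls on the exporter are elided between this hunk and the next; a hedged sketch of that usage, inferred from the two-signature ``fun`` above:

.. code-block:: python

   with mx.exporter("fun.mlxfn", fun) as exporter:
       # Each call traces and saves one signature of fun into the same
       # file; constants shared across signatures are stored only once
       exporter(mx.array(1.0))
       exporter(mx.array(1.0), y=mx.array(2.0))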
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
         print(out)
 
 In the above example the function constant data (i.e. ``constant``) is only
 saved once.
 
 Transformations with Imported Functions
 ---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
    # Prints: array(1, dtype=float32)
    print(dfdx(x))
 
    # Compile the imported function
    compiled_fun = mx.compile(imported_fun)
    # Prints: array(0, dtype=float32)
    print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
    // Prints: array(2, dtype=float32)
    std::cout << outputs[0] << std::endl;
 
 Imported functions can be transformed in C++ just like in Python. Use
 ``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
 mx::array>`` for keyword arguments when calling imported functions in C++.
 
@@ -19,6 +19,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
@@ -39,6 +40,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
mlx/backend/cuda/distributed.cu (new file, 51 lines)
@@ -0,0 +1,51 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"

#include <cassert>

namespace mlx::core {
namespace distributed {

void AllReduce::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& input = inputs[0];
  auto& output = outputs[0];

  auto& encoder = cu::get_command_encoder(stream());

  if (input.is_donatable()) {
    output.copy_shared_buffer(input);
  } else {
    output.set_data(allocator::malloc(output.nbytes()));
  }

  encoder.set_input_array(input);
  encoder.set_output_array(output);

  auto capture = encoder.capture_context();
  auto& s = stream();

  switch (reduce_type_) {
    case Sum:
      distributed::detail::all_sum(group(), input, output, s);
      break;
    case Max:
      distributed::detail::all_max(group(), input, output, s);
      break;
    case Min:
      distributed::detail::all_min(group(), input, output, s);
      break;
    default:
      throw std::runtime_error(
          "Only all reduce sum, max, and min are supported.");
  }
}

} // namespace distributed
} // namespace mlx::core
@@ -6,17 +6,6 @@
 namespace mlx::core {
 
-bool fast::ScaledDotProductAttention::use_fallback(
-    const array& q,
-    const array& k,
-    const array& v,
-    bool has_mask,
-    bool has_arr_mask,
-    bool do_causal,
-    Stream s) {
-  return true;
-}
-
 #define NO_GPU_MULTI(func)                                               \
   void func::eval_gpu(                                                   \
       const std::vector<array>& inputs, std::vector<array>& outputs) {  \
@@ -53,12 +42,10 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast
 
 namespace distributed {
-NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
 NO_GPU_MULTI(Send)
 NO_GPU_MULTI(Recv)
mlx/backend/cuda/scaled_dot_product_attention.cu (new file, 781 lines)
@@ -0,0 +1,781 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"

#include <nvtx3/nvtx3.hpp>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

#define PRAGMA_LOOP_UNROLL _Pragma("unroll")

struct AttnParams {
  int B;
  int H;
  int D;

  int qL;
  int kL;

  int gqa_factor;
  float scale;

  int64_t Q_strides[3];
  int64_t K_strides[3];
  int64_t V_strides[3];
  int64_t O_strides[3];
};

template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_1pass(
    const T* Q,
    const T* K,
    const T* V,
    T* O,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 32;
  constexpr int BD = 32;

  constexpr int v_per_thread = D / BD;

  const int inner_k_stride = BN * int(params.K_strides[2]);
  const int inner_v_stride = BN * int(params.V_strides[2]);

  typedef float U;

  U q[v_per_thread];
  U k[v_per_thread];
  U o[v_per_thread];

  __shared__ U outputs[BN][BD + 1];
  __shared__ U max_scores[BN];
  __shared__ U sum_exp_scores[BN];

  const U scale_log2 = params.scale * 1.44269504089f;

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  // Adjust to thread block and thread
  const int batch_idx = blockIdx.z;
  const int head_idx = blockIdx.x;
  const int kv_head_idx = head_idx / params.gqa_factor;

  const int q_seq_idx = blockIdx.y;
  const int kv_seq_idx = warp_idx;

  Q += batch_idx * params.Q_strides[0] + // Batch
      head_idx * params.Q_strides[1] + // Head
      q_seq_idx * params.Q_strides[2]; // Sequence

  K += batch_idx * params.K_strides[0] + // Batch
      kv_head_idx * params.K_strides[1] + // Head
      kv_seq_idx * params.K_strides[2]; // Sequence

  V += batch_idx * params.V_strides[0] + // Batch
      kv_head_idx * params.V_strides[1] + // Head
      kv_seq_idx * params.V_strides[2]; // Sequence

  O += batch_idx * params.O_strides[0] + // Batch
      head_idx * params.O_strides[1] + // Head
      q_seq_idx * params.O_strides[2]; // Sequence

  // Read the query and 0 the output accumulator
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
  }

  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0.f;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0.f;

  // For each key
  for (int i = kv_seq_idx; i < params.kL; i += BN) {
    bool use_key = true;
    if constexpr (do_causal) {
      use_key = i <= (params.kL - params.qL + q_seq_idx);
    }

    if (use_key) {
      // Read the key
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        k[j] = K[v_per_thread * lane_idx + j];
      }

      // Compute the i-th score
      U score = 0.f;
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        score += q[j] * k[j];
      }

      // Warp sum
      score = cg::reduce(warp, score, cg::plus<U>());

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = exp2f(max_score - new_max);
      U exp_score = exp2f(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor +
            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
      }
    }

    // Move the pointers to the next kv
    K += inner_k_stride;
    V += inner_v_stride;
  }

  if (lane_idx == 0) {
    max_scores[warp_idx] = max_score;
    sum_exp_scores[warp_idx] = sum_exp_score;
  }
  block.sync();

  max_score = max_scores[lane_idx];
  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
  U factor = exp2f(max_score - new_max);
  sum_exp_score =
      cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
  sum_exp_score = __frcp_rn(sum_exp_score);

  // Now we need to aggregate all the outputs
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    outputs[lane_idx][warp_idx] = o[i];
    block.sync();
    U ot = outputs[warp_idx][lane_idx] * factor;
    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
    block.sync();
  }

  // And write the output
  if (lane_idx == 0) {
    PRAGMA_LOOP_UNROLL
    for (int i = 0; i < v_per_thread; i++) {
      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
    }
  }
}

template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_1(
    const T* Q,
    const T* K,
    const T* V,
    float* partials,
    float* sums,
    float* maxs,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int blocks = 32;

  constexpr int v_per_thread = D / BD;

  const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
  const int inner_v_stride = blocks * BN * int(params.V_strides[2]);

  typedef float U;

  U q[v_per_thread];
  U k[v_per_thread];
  U o[v_per_thread];

  __shared__ U outputs[BN][BD + 1];
  __shared__ U max_scores[BN];
  __shared__ U sum_exp_scores[BN];

  const U scale_log2 = params.scale * 1.44269504089f;

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  // Adjust to thread block and thread
  const int batch_idx = blockIdx.z / blocks;
  const int block_idx = blockIdx.z % blocks;
  const int head_idx = blockIdx.x;
  const int kv_head_idx = head_idx / params.gqa_factor;

  const int q_seq_idx = blockIdx.y;
  const int kv_seq_idx = block_idx * BN + warp_idx;

  Q += batch_idx * params.Q_strides[0] + // Batch
      head_idx * params.Q_strides[1] + // Head
      q_seq_idx * params.Q_strides[2]; // Sequence

  K += batch_idx * params.K_strides[0] + // Batch
      kv_head_idx * params.K_strides[1] + // Head
      kv_seq_idx * params.K_strides[2]; // Sequence

  V += batch_idx * params.V_strides[0] + // Batch
      kv_head_idx * params.V_strides[1] + // Head
      kv_seq_idx * params.V_strides[2]; // Sequence

  const int p_stride_s = blocks;
  const int p_stride_h = params.qL * p_stride_s;
  const int p_stride_b = params.H * p_stride_h;
  const int p_offset = batch_idx * p_stride_b + // Batch
      head_idx * p_stride_h + // Head
      q_seq_idx * p_stride_s + // Sequence
      block_idx; // Block

  partials += p_offset * D;
  sums += p_offset;
  maxs += p_offset;

  // Read the query and 0 the output accumulator
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
  }

  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0.f;
  }

  U max_score = -1e9;
  U sum_exp_score = 0.f;

  // For each key
  for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
    bool use_key = true;
    if constexpr (do_causal) {
      use_key = i <= (params.kL - params.qL + q_seq_idx);
    }

    if (use_key) {
      // Read the key
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        k[j] = K[v_per_thread * lane_idx + j];
      }

      // Compute the i-th score
      U score = 0.f;
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        score += q[j] * k[j];
      }

      // Warp sum
      score = cg::reduce(warp, score, cg::plus<U>());

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = exp2f(max_score - new_max);
      U exp_score = exp2f(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor +
            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
      }
    }

    // Move the pointers to the next kv
    K += inner_k_stride;
    V += inner_v_stride;
  }

  if (lane_idx == 0) {
    max_scores[warp_idx] = max_score;
    sum_exp_scores[warp_idx] = sum_exp_score;
  }

  block.sync();

  max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
  U factor = exp2f(max_score - new_max);
  sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
  sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());

  // Write the sum and new max
  if (warp_idx == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now we need to aggregate all the outputs
  auto ff = exp2f(max_scores[warp_idx] - new_max);
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    outputs[warp_idx][lane_idx] = o[i] * ff;
    block.sync();

    if (warp_idx == 0) {
      U ot = outputs[0][lane_idx];
      PRAGMA_LOOP_UNROLL
      for (int j = 1; j < BN; j++) {
        ot += outputs[j][lane_idx];
        warp.sync();
      }
      o[i] = ot;
    }
    block.sync();
  }

  if (warp_idx == 0) {
    PRAGMA_LOOP_UNROLL
    for (int i = 0; i < v_per_thread; i++) {
      partials[v_per_thread * lane_idx + i] = o[i];
    }
  }
}

template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
    const float* partials,
    const float* sums,
    const float* maxs,
    T* O,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int blocks = 32;

  constexpr int v_per_thread = D / BD;

  typedef float U;

  U o[v_per_thread];
  __shared__ U outputs[BN][BD + 1];

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  // Adjust to thread block and thread
  const int batch_idx = blockIdx.z;
  const int head_idx = blockIdx.x;
  const int q_seq_idx = blockIdx.y;

  const int p_stride_s = blocks;
  const int p_stride_h = params.qL * p_stride_s;
  const int p_stride_b = params.H * p_stride_h;
  const int p_offset = batch_idx * p_stride_b + // Batch
      head_idx * p_stride_h + // Head
      q_seq_idx * p_stride_s; // Sequence

  partials += p_offset * D + warp_idx * D;
  sums += p_offset;
  maxs += p_offset;

  O += batch_idx * params.O_strides[0] + // Batch
      head_idx * params.O_strides[1] + // Head
      q_seq_idx * params.O_strides[2]; // Sequence

  U max_score = maxs[lane_idx];
  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
  U factor = exp2f(max_score - new_max);
  U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
  sum_exp_score = __frcp_rn(sum_exp_score);

  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = partials[v_per_thread * lane_idx + i];
  }

  // Now we need to aggregate all the outputs
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    outputs[lane_idx][warp_idx] = o[i];
    block.sync();
    U ot = outputs[warp_idx][lane_idx] * factor;
    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
    block.sync();
  }

  // And write the output
  if (lane_idx == 0) {
    PRAGMA_LOOP_UNROLL
    for (int i = 0; i < v_per_thread; i++) {
      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
    }
  }
}

} // namespace cu

namespace {

template <typename F>
void dispatch_headdim(int n, F&& f) {
  switch (n) {
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 96:
      f(std::integral_constant<int, 96>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
  }
}

void sdpa_vector_1pass_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  encoder.set_input_array(q);
  encoder.set_input_array(k);
  encoder.set_input_array(v);
  encoder.set_output_array(o);

  cu::AttnParams params{
      /* int B = */ q.shape(0),
      /* int H = */ q.shape(1),
      /* int D = */ q.shape(3),

      /* int qL = */ q.shape(2),
      /* int kL = */ k.shape(2),

      /* int gqa_factor = */ q.shape(1) / k.shape(1),
      /* float scale = */ scale,

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  dim3 grid_dim(params.H, params.qL, params.B);
  dim3 block_dim(1024, 1, 1);

  dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

        auto kernel =
            cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
        encoder.add_kernel_node(
            kernel,
            grid_dim,
            block_dim,
            0,
            q.data<DataType>(),
            k.data<DataType>(),
            v.data<DataType>(),
            o.data<DataType>(),
            params);
      });
    });
  });
}

void sdpa_vector_2pass_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  cu::AttnParams params{
      /* int B = */ q.shape(0),
      /* int H = */ q.shape(1),
      /* int D = */ q.shape(3),

      /* int qL = */ q.shape(2),
      /* int kL = */ k.shape(2),

      /* int gqa_factor = */ q.shape(1) / k.shape(1),
      /* float scale = */ scale,

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  // Allocate the intermediates
  int blocks = 32;

  Shape intermediate_shape;
  intermediate_shape.reserve(o.ndim() + 1);
  intermediate_shape.insert(
      intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
  intermediate_shape.push_back(blocks);
  intermediate_shape.push_back(o.shape().back());

  array intermediate(intermediate_shape, float32, nullptr, {});
  intermediate_shape.pop_back();
  array sums(intermediate_shape, float32, nullptr, {});
  array maxs(std::move(intermediate_shape), float32, nullptr, {});

  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
  sums.set_data(allocator::malloc(sums.nbytes()));
  maxs.set_data(allocator::malloc(maxs.nbytes()));

  encoder.add_temporary(intermediate);
  encoder.add_temporary(sums);
  encoder.add_temporary(maxs);

  dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

        {
          auto kernel = cu::
              kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;

          encoder.set_input_array(q);
          encoder.set_input_array(k);
          encoder.set_input_array(v);
          encoder.set_output_array(intermediate);
          encoder.set_output_array(sums);
          encoder.set_output_array(maxs);

          dim3 grid_dim(params.H, params.qL, params.B * 32);
          dim3 block_dim(8 * 32, 1, 1);

          encoder.add_kernel_node(
              kernel,
              grid_dim,
              block_dim,
              0,
              q.data<DataType>(),
              k.data<DataType>(),
              v.data<DataType>(),
              intermediate.data<float>(),
              sums.data<float>(),
              maxs.data<float>(),
              params);
        }

        {
          auto kernel = cu::
              kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;

          encoder.set_input_array(intermediate);
          encoder.set_input_array(sums);
          encoder.set_input_array(maxs);
          encoder.set_output_array(o);

          dim3 grid_dim(params.H, params.qL, params.B);
          dim3 block_dim(1024, 1, 1);

          encoder.add_kernel_node(
              kernel,
              grid_dim,
              block_dim,
              0,
              intermediate.data<float>(),
              sums.data<float>(),
              maxs.data<float>(),
              o.data<DataType>(),
              params);
        }
      });
    });
  });
}

void sdpa_vector_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  int kL = k.shape(2);

  if (kL > 1024) {
    return sdpa_vector_2pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
  } else {
    return sdpa_vector_1pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
  }
}

} // namespace

namespace fast {

bool ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    Stream s) {
  if (detail::in_grad_tracing()) {
    return true;
  }
  if (s.device == Device::cpu) {
    return true;
  }

  const int value_head_dim = v.shape(-1);
  const int query_head_dim = q.shape(-1);
  const int query_sequence_length = q.shape(2);
  const int key_sequence_length = k.shape(2);

  const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);

  const bool supported_vector_config =
      sdpa_supported_head_dim && query_sequence_length < 4;

  const bool supported_config = supported_vector_config;

  return has_arr_mask || !supported_config;
}

void ScaledDotProductAttention::eval_gpu(
    const std::vector<array>& inputs,
    array& out) {
  nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  auto& q_pre = inputs[0];
  auto& k_pre = inputs[1];
  auto& v_pre = inputs[2];
  auto& o = out;

  std::vector<array> copies;

  // Define some copy functions to ensure the layout of the inputs is as
  // expected.
  copies.reserve(3);
  auto copy_unless = [&copies, &s](
                         auto predicate, const array& arr) -> const array& {
    if (!predicate(arr)) {
      array arr_copy = contiguous_copy_gpu(arr, s);
      copies.push_back(std::move(arr_copy));
      return copies.back();
    } else {
      return arr;
    }
  };

  // We are in vector mode ie single query
  if (q_pre.shape(2) < 4) {
    auto q_copy_unless = [](const array& arr) {
      if (arr.flags().row_contiguous) {
        return true;
      }
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (shape[0] == 1 || shape[1] == 1) {
        // If either the batch or head dimension is a singleton, the other can
        // be transposed with the sequence dimension
        auto bidx = shape[0] == 1 ? 1 : 0;
        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
            (strides[bidx] == shape[3]);
      }
      return false;
    };

    auto kv_copy_unless = [](const array& arr) {
      // keys and values should be copied if:
      // - the last dimension is not contiguous
      // - the batch and head dim are not contiguous
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (strides.back() != 1) {
        return false;
      }
      if (shape[0] == 1 || shape[1] == 1) {
        return true;
      }
      return (strides[0] == strides[1] * shape[1]);
    };

    const auto& q = copy_unless(q_copy_unless, q_pre);
    const auto& k = copy_unless(kv_copy_unless, k_pre);
    const auto& v = copy_unless(kv_copy_unless, v_pre);

    for (const auto& cp : copies) {
      encoder.add_temporary(cp);
    }

    // Donate the query if possible
    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
      o.copy_shared_buffer(q);
    } else {
      int64_t str_oD = 1;
      int64_t str_oH = o.shape(3);
      int64_t str_oL = o.shape(1) * str_oH;
      int64_t str_oB = o.shape(2) * str_oL;
      size_t data_size = o.shape(0) * str_oB;

      array::Flags flags{
          /* bool contiguous = */ 1,
          /* bool row_contiguous = */ o.shape(2) == 1,
          /* bool col_contiguous = */ 0,
      };

      o.set_data(
          allocator::malloc(o.nbytes()),
          data_size,
          {str_oB, str_oH, str_oL, str_oD},
          flags);
    }

    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
  }

  // Full attention mode should never reach here
  else {
    throw std::runtime_error("Doesn't support matrix yet.");
  }
}

} // namespace fast

} // namespace mlx::core
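Both kernel families above keep the attention softmax in streaming form: a running row maximum and a running sum of exponentials, rescaling the partial output whenever the maximum grows, with sdpa_vector_fallback switching to the two-pass variant once the key sequence exceeds 1024. A short NumPy sketch of the accumulator update (the function and variable names are mine, and it omits the base-2 exp2f scaling the kernels use for speed):

    import numpy as np

    def online_softmax_attention(q, K, V):
        # q: (D,), K and V: (L, D); one query row, one key/value at a time
        m = -np.inf  # running max of the scores
        s = 0.0      # running sum of exp(score - m)
        o = np.zeros(V.shape[1])  # running weighted sum of values
        for k, v in zip(K, V):
            score = q @ k
            m_new = max(m, score)
            factor = np.exp(m - m_new)  # rescale the old accumulators
            e = np.exp(score - m_new)
            s = s * factor + e
            o = o * factor + e * v
            m = m_new
        return o / s

    # Agrees with the reference softmax(q K^T) V for a single query
    rng = np.random.default_rng(0)
    q, K, V = rng.normal(size=4), rng.normal(size=(8, 4)), rng.normal(size=(8, 4))
    w = np.exp(q @ K.T - np.max(q @ K.T))
    w /= w.sum()
    assert np.allclose(online_softmax_attention(q, K, V), w @ V)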
@@ -104,7 +104,7 @@ struct CommandEncoder {
   };
 
   // Outputs of all kernels in the encoder including temporaries
-  std::unordered_set<const void*> outputs() {
+  std::unordered_set<const void*>& outputs() {
     return all_outputs_;
   };
 
@@ -6,3 +6,4 @@ target_sources(
 
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
@@ -2,9 +2,11 @@
 #include <unordered_map>
 
+#include <iostream>
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/distributed/mpi/mpi.h"
+#include "mlx/distributed/nccl/nccl.h"
 #include "mlx/distributed/ring/ring.h"
 
 namespace mlx::core::distributed {
@@ -80,7 +82,7 @@ class EmptyGroup : public GroupImpl {
 } // namespace detail
 
 bool is_available() {
-  return mpi::is_available() || ring::is_available();
+  return mpi::is_available() || ring::is_available() || nccl::is_available();
 }
 
 int Group::rank() const {
@@ -111,6 +113,8 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
     group = mpi::init(strict);
   } else if (bk == "ring") {
     group = ring::init(strict);
+  } else if (bk == "nccl") {
+    group = nccl::init(strict);
   } else if (bk == "any") {
     group = ring::init(false);
     bk_ = "ring";
@@ -3,7 +3,6 @@
 #pragma once
 
 #include <memory>
 
 #include "mlx/array.h"
 
 namespace mlx::core::distributed {
mlx/distributed/nccl/CMakeLists.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
if(MLX_BUILD_CUDA)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
  find_package(NCCL REQUIRED)
  target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
  target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
endif()
mlx/distributed/nccl/nccl.cpp (new file, 359 lines)
@@ -0,0 +1,359 @@
|
|||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <mutex>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
#define CHECK_CUDA(cmd) \
|
||||||
|
do { \
|
||||||
|
cudaError_t e = cmd; \
|
||||||
|
if (e != cudaSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"CUDA error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
cudaGetErrorString(e)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define CHECK_NCCL(cmd) \
|
||||||
|
do { \
|
||||||
|
ncclResult_t r = cmd; \
|
||||||
|
if (r != ncclSuccess) { \
|
||||||
|
fprintf( \
|
||||||
|
stderr, \
|
||||||
|
"NCCL error %s:%d '%s'\n", \
|
||||||
|
__FILE__, \
|
||||||
|
__LINE__, \
|
||||||
|
ncclGetErrorString(r)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||||
|
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t sent = send(sock, ptr, len, 0);
|
||||||
|
if (sent <= 0) {
|
||||||
|
perror("send");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += sent;
|
||||||
|
len -= sent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void recvAll(int sock, void* buf, size_t len) {
|
||||||
|
char* ptr = reinterpret_cast<char*>(buf);
|
||||||
|
while (len > 0) {
|
||||||
|
ssize_t rec = recv(sock, ptr, len, 0);
|
||||||
|
if (rec <= 0) {
|
||||||
|
perror("recv");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
ptr += rec;
|
||||||
|
len -= rec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void bootstrap_unique_id(
|
||||||
|
ncclUniqueId& id,
|
||||||
|
int rank,
|
||||||
|
int size,
|
||||||
|
const std::string& initMethod) {
|
||||||
|
// Parse the init method to extract the host and port
|
||||||
|
if (initMethod.rfind("tcp://", 0) != 0)
|
||||||
|
throw;
|
||||||
|
auto hostport = initMethod.substr(6);
|
||||||
|
auto colon = hostport.find(':');
|
||||||
|
std::string host = hostport.substr(0, colon);
|
||||||
|
int port = std::stoi(hostport.substr(colon + 1));
|
||||||
|
|
||||||
|
if (rank == 0) {
|
||||||
|
// create a unique id on the rank 0
|
||||||
|
CHECK_NCCL(ncclGetUniqueId(&id));
|
||||||
|
|
||||||
|
// create a socket to send the unique id to all other ranks
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
int reuse = 1;
|
||||||
|
// Without this, if rank-0 crashes or restarts process quickly,
|
||||||
|
// the OS might refuse to let binding to the same port, so reuse
|
||||||
|
|
||||||
|
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] bind() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
if (listen(sock, size - 1) < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] listen() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int peer = 1; peer < size; ++peer) {
|
||||||
|
int conn = accept(sock, nullptr, nullptr);
|
||||||
|
if (conn < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] accept() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
sendAll(conn, &id, sizeof(id));
|
||||||
|
close(conn);
|
||||||
|
}
|
||||||
|
close(sock);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Here just wanted to make show that rank 0 has enough time to bind
|
||||||
|
// so we will retry to connect until max attempts
|
||||||
|
|
||||||
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sock < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[nccl] socket() failed: " << strerror(errno);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
hostent* he = gethostbyname(host.c_str());
|
||||||
|
if (!he) {
|
||||||
|
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
||||||
|
}
|
||||||
|
sockaddr_in serv = {};
|
||||||
|
serv.sin_family = AF_INET;
|
||||||
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
|
const int max_retries = 30;
|
||||||
|
int attempt = 0;
|
||||||
|
bool connected = false;
|
||||||
|
|
||||||
|
for (attempt = 0; attempt < max_retries; ++attempt) {
|
||||||
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
|
0) {
|
||||||
|
connected = true;
|
||||||
|
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
||||||
|
<< attempt + 1 << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (errno != ECONNREFUSED) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!connected) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
||||||
|
<< " retries: " << strerror(errno);
|
||||||
|
close(sock);
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
recvAll(sock, &id, sizeof(id));
|
||||||
|
close(sock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct type_identity {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_dtype(const array& arr, F&& f) {
|
||||||
|
switch (arr.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
||||||
|
case int8:
|
||||||
|
f(type_identity<int8_t>{}, ncclChar);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
f(type_identity<uint8_t>{}, ncclUint8);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
f(type_identity<int32_t>{}, ncclInt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
f(type_identity<uint32_t>{}, ncclUint32);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
f(type_identity<int64_t>{}, ncclInt64);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
f(type_identity<uint64_t>{}, ncclUint64);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
f(type_identity<float16_t>{}, ncclHalf);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
f(type_identity<float>{}, ncclFloat);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
f(type_identity<double>{}, ncclDouble);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
class NCCLGroup : public GroupImpl {
|
||||||
|
public:
|
||||||
|
  NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
      : rank_(worldRank),
        size_(worldSize),
        comm_(nullptr),
        initMethod_(initMethod) {
    if (initialized_)
      return;
    int ndev;
    CHECK_CUDA(cudaGetDeviceCount(&ndev));
    CHECK_CUDA(cudaSetDevice(rank_ % ndev));
    detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
    CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
    initialized_ = true;
  }

  ~NCCLGroup() {
    ncclCommDestroy(comm_);
    ncclGroupEnd();
    initialized_ = false;
  }

  int rank() override {
    return rank_;
  }

  int size() override {
    return size_;
  }

  void all_sum(const array& input, array& output, Stream stream) override {
    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
      using T = typename decltype(type_tag)::type;
      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
    });
  }

  virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
    throw std::runtime_error("[nccl] Group split not supported.");
  }

  void all_gather(const array& input, array& output, Stream stream) override {
    throw std::runtime_error(
        "[nccl] All gather not supported in NCCL backend.");
  }

  void send(const array& input, int dst, Stream stream) override {
    throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
  }

  void recv(array& output, int src, Stream stream) override {
    throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
  }

  void all_max(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
  }

  void all_min(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
  }

  template <typename T>
  void all_reduce_impl(
      const array& input,
      array& output,
      Stream stream,
      ncclDataType_t dt,
      ncclRedOp_t op) {
    auto& encoder = cu::get_command_encoder(stream);

    CHECK_NCCL(ncclAllReduce(
        input.data<T>(),
        output.data<T>(),
        input.size(),
        dt,
        op,
        comm_,
        encoder.stream()));
  }

  int rank_, size_;
  std::string initMethod_;
  ncclUniqueId uniqueId_;
  ncclComm_t comm_;
  bool initialized_ = false;
};

bool is_available() {
  return true;
}

namespace detail {
static std::string get_env_var_or_throw(const char* env_var_name) {
  const char* value = std::getenv(env_var_name);
  if (value == nullptr) {
    std::ostringstream msg;
    msg << "[nccl] Required environment variable '" << env_var_name
        << "' is not set. "
        << "Please set it before initializing the distributed backend.";
    throw std::runtime_error(msg.str());
  }
  return std::string(value);
}
} // namespace detail

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
  std::string port = detail::get_env_var_or_throw("NCCL_PORT");
  std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
  std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");

  int rank = std::stoi(rank_str);
  int n_nodes = std::stoi(n_nodes_str);
  std::string init_method = "tcp://" + host + ":" + port;

  return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
}
} // namespace mlx::core::distributed::nccl
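Taken together, init() above only needs four environment variables to bootstrap a communicator. A minimal sketch of driving it by hand from Python, assuming a two-process run on one machine (the host, port, and rank values are illustrative; in practice the launcher further below sets them):

    # Sketch only: these are the variables init() reads via
    # get_env_var_or_throw above; values are illustrative.
    import os

    os.environ["NCCL_HOST_IP"] = "127.0.0.1"
    os.environ["NCCL_PORT"] = "12345"
    os.environ["MLX_RANK"] = "0"          # unique per process
    os.environ["MLX_WORLD_SIZE"] = "2"

    import mlx.core as mx

    world = mx.distributed.init(strict=True, backend="nccl")
    print(world.rank(), world.size())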
12 mlx/distributed/nccl/nccl.h Normal file
@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::nccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);

} // namespace mlx::core::distributed::nccl
20 mlx/distributed/nccl/no_nccl.cpp Normal file
@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/nccl/nccl.h"

namespace mlx::core::distributed::nccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize nccl distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::nccl
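For readers following along in Python: this stub means that on a build without NCCL, init with strict=True raises while strict=False returns a null group for the caller to replace. A minimal sketch of the resulting behavior, assuming mlx's usual fallback to a size-1 group (that fallback is not shown in this diff):

    import mlx.core as mx

    try:
        mx.distributed.init(strict=True, backend="nccl")  # raises on a no-NCCL build
    except RuntimeError as e:
        print(e)  # "Cannot initialize nccl distributed backend."

    world = mx.distributed.init(strict=False, backend="nccl")
    print(world.size())  # assumed: a singleton group of size 1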
@@ -31,8 +31,7 @@ array all_sum(
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(
-          to_stream(s, Device::cpu), group, AllReduce::Sum),
+      std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
       {x});
 }
 
@@ -975,7 +975,6 @@ class RingGroup : public GroupImpl {
 
   int rank_;
   int size_;
-
   bool verbose_;
 
   ThreadPool pool_;
@@ -3,8 +3,8 @@
 #pragma once
 
 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 27
-#define MLX_VERSION_PATCH 1
+#define MLX_VERSION_MINOR 28
+#define MLX_VERSION_PATCH 0
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
 
@@ -415,6 +415,45 @@ def launch_mpi(parser, hosts, args, command):
         pass
 
 
+def launch_nccl(parser, hosts, args, command):
+    master_host = hosts[0].ips[0]
+    master_port = args.nccl_port
+    world_size = args.nproc_per_node * len(hosts)
+
+    base_env = os.environ.copy()
+    base_env.update(
+        {
+            "NCCL_DEBUG": "INFO",
+            "NCCL_SOCKET_IFNAME": "lo",  # Use loopback for local communication
+            "NCCL_HOST_IP": master_host,
+            "NCCL_PORT": str(master_port),
+            "MLX_WORLD_SIZE": str(world_size),
+        }
+    )
+    procs = []
+    try:
+        for rank in range(world_size):
+            env = base_env.copy()
+            env["MLX_RANK"] = str(rank)
+            env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
+            p = Popen(command, env=env)
+            procs.append(p)
+
+        for p in procs:
+            ret = p.wait()
+            if ret != 0:
+                raise RuntimeError(f"Rank process exited with {ret}")
+
+    except (RuntimeError, KeyboardInterrupt) as err:
+        for p in procs:
+            if p.poll() is None:
+                try:
+                    p.kill()
+                except Exception:
+                    pass
+        raise
+
+
 def check_ssh_connections(hosts):
     results = [False] * len(hosts)
 
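To make the launcher's bookkeeping concrete, this is roughly the environment launch_nccl above hands to the rank-0 child for a single host with --nproc-per-node 2; all values are illustrative:

    # Hypothetical values for one host running two processes.
    rank0_env = {
        "NCCL_DEBUG": "INFO",
        "NCCL_SOCKET_IFNAME": "lo",
        "NCCL_HOST_IP": "10.0.0.1",    # hosts[0].ips[0]
        "NCCL_PORT": "12345",          # --nccl-port default
        "MLX_WORLD_SIZE": "2",         # nproc_per_node * len(hosts)
        "MLX_RANK": "0",               # unique per child
        "CUDA_VISIBLE_DEVICES": "0",   # rank % nproc_per_node
    }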
@@ -665,7 +704,7 @@ def distributed_config():
     )
     parser.add_argument(
         "--backend",
-        choices=["ring", "mpi"],
+        choices=["ring", "mpi", "nccl"],
         default="ring",
         help="Which distributed backend to configure",
     )
@@ -737,7 +776,7 @@ def main():
     parser.add_argument("--hostfile", help="The file containing the hosts")
     parser.add_argument(
         "--backend",
-        choices=["ring", "mpi"],
+        choices=["ring", "mpi", "nccl"],
         default="ring",
         help="Which distributed backend to launch",
     )
@@ -769,6 +808,19 @@ def main():
     parser.add_argument(
         "--cwd", help="Set the working directory on each node to the provided one"
     )
+    parser.add_argument(
+        "--nccl-port",
+        type=int,
+        default=12345,
+        help="The port to use for the NCCL communication (only for nccl backend)",
+    )
+    parser.add_argument(
+        "--nproc-per-node",
+        type=positive_number,
+        default=1,
+        help="How many processes to run per node (only for nccl backend)",
+    )
+
     args, rest = parser.parse_known_args()
     if rest[0] == "--":
         rest.pop(0)
@@ -799,8 +851,10 @@ def main():
     # Launch
     if args.backend == "ring":
         launch_ring(parser, hosts, args, rest)
-    elif args.backend == "mpi":
+    if args.backend == "mpi":
         launch_mpi(parser, hosts, args, rest)
+    if args.backend == "nccl":
+        launch_nccl(parser, hosts, args, rest)
 
 
 if __name__ == "__main__":
@@ -178,7 +178,7 @@ class Module(dict):
 
         if strict:
             new_weights = dict(weights)
-            curr_weights = dict(tree_flatten(self.parameters()))
+            curr_weights = tree_flatten(self.parameters(), destination={})
             if extras := (new_weights.keys() - curr_weights.keys()):
                 num_extra = len(extras)
                 extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
         - ``.npz`` will use :func:`mx.savez`
         - ``.safetensors`` will use :func:`mx.save_safetensors`
         """
-        params_dict = dict(tree_flatten(self.parameters()))
+        params_dict = tree_flatten(self.parameters(), destination={})
 
         if file.endswith(".npz"):
             mx.savez(file, **params_dict)
@@ -76,6 +76,7 @@ def average_gradients(
     group: Optional[mx.distributed.Group] = None,
     all_reduce_size: int = 32 * 1024**2,
     communication_type: Optional[mx.Dtype] = None,
+    stream: mx.Stream = mx.cpu,
 ):
     """Average the gradients across the distributed processes in the passed group.
 
@@ -94,6 +95,7 @@ def average_gradients(
         communication_type (Optional[mlx.core.Dtype]): If provided cast to this
             type before performing the communication. Typically cast to a
             smaller float to reduce the communication size. Default: ``None``.
+        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
     """
     group = group or mx.distributed.init()
     N = group.size()
@@ -104,7 +106,7 @@ def average_gradients(
     def _average(x):
         dt = x.dtype
         x = x.astype(communication_type) if communication_type is not None else x
-        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
 
     if all_reduce_size <= 0:
         return tree_map(_average, gradients)
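With the new stream argument, gradient averaging can keep the reduction on the GPU, which the NCCL backend needs since its all_sum is enqueued on a CUDA stream (see all_reduce_impl above). A small usage sketch with an illustrative gradient tree:

    import mlx.core as mx
    from mlx.nn.utils import average_gradients

    grads = {"w": mx.ones((8, 8)), "b": mx.ones((8,))}
    # The default stays mx.cpu; pass mx.gpu when running over the nccl backend.
    avg = average_gradients(grads, stream=mx.gpu)
    mx.eval(avg)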
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 from collections import defaultdict
 from itertools import zip_longest
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 
 def tree_map(
@@ -114,8 +114,11 @@ def tree_map_with_path(
 
 
 def tree_flatten(
-    tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
-) -> Any:
+    tree: Any,
+    prefix: str = "",
+    is_leaf: Optional[Callable] = None,
+    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
+) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
     """Flattens a Python tree to a list of key, value tuples.
 
     The keys are using the dot notation to define trees of arbitrary depth and
@@ -128,9 +131,12 @@ def tree_flatten(
         print(tree_flatten([[[0]]]))
         # [("0.0.0", 0)]
 
-        print(tree_flatten([[[0]]], ".hello"))
+        print(tree_flatten([[[0]]], prefix=".hello"))
         # [("hello.0.0.0", 0)]
 
+        tree_flatten({"a": {"b": 1}}, destination={})
+        # {"a.b": 1}
+
     .. note::
         Dictionaries should have keys that are valid Python identifiers.
 
@@ -140,26 +146,50 @@ def tree_flatten(
         always discarded.
         is_leaf (callable): An optional callable that returns True if the
             passed object is considered a leaf or False otherwise.
+        destination (list or dict, optional): A list or dictionary to store the
+            flattened tree. If None an empty list will be used. Default: ``None``.
 
     Returns:
-        List[Tuple[str, Any]]: The flat representation of the Python tree.
+        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
+        the Python tree.
     """
-    flat_tree = []
-
-    if is_leaf is None or not is_leaf(tree):
-        if isinstance(tree, (list, tuple)):
-            for i, t in enumerate(tree):
-                flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
-            return flat_tree
-        if isinstance(tree, dict):
-            for k, t in tree.items():
-                flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
-            return flat_tree
-
-    return [(prefix[1:], tree)]
+    if destination is None:
+        destination = []
+
+    # Create the function to update the destination. We are taking advantage of
+    # the fact that list.extend and dict.update have the same API to simplify
+    # the code a bit.
+    if isinstance(destination, list):
+        _add_to_destination = destination.extend
+    elif isinstance(destination, dict):
+        _add_to_destination = destination.update
+    else:
+        raise ValueError("Destination should be either a list or a dictionary or None")
+
+    # Leaf identified by is_leaf so add it and return
+    if is_leaf is not None and is_leaf(tree):
+        _add_to_destination([(prefix[1:], tree)])
+        return destination
+
+    # List or tuple so recursively add each subtree
+    if isinstance(tree, (list, tuple)):
+        for i, item in enumerate(tree):
+            tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
+        return destination
+
+    # Dictionary so recursively add each subtree
+    if isinstance(tree, dict):
+        for key, value in tree.items():
+            tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
+        return destination
+
+    # Leaf so add it and return
+    _add_to_destination([(prefix[1:], tree)])
+
+    return destination
 
 
-def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
+def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
     """Recreate a Python tree from its flat representation.
 
     .. code-block:: python
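The comment in the new body leans on list.extend and dict.update accepting the same iterable of key/value pairs; a quick standalone illustration of that equivalence (plain Python, no mlx needed):

    pairs = [("a.b", 1), ("a.c", 2)]

    dest_list = []
    dest_list.extend(pairs)  # [("a.b", 1), ("a.c", 2)]

    dest_dict = {}
    dest_dict.update(pairs)  # {"a.b": 1, "a.c": 2}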
@@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
         print(d)
         # {"hello": {"world": 42}}
 
+        d = tree_unflatten({"hello.world": 42})
+        print(d)
+        # {"hello": {"world": 42}}
+
     Args:
-        tree (list[tuple[str, Any]]): The flat representation of a Python tree.
+        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
             For instance as returned by :meth:`tree_flatten`.
 
     Returns:
         A Python tree.
     """
-    if len(tree) == 1 and tree[0][0] == "":
-        return tree[0][1]
+    items = tree.items() if isinstance(tree, dict) else tree
 
-    try:
-        int(tree[0][0].split(".", maxsplit=1)[0])
-        is_list = True
-    except ValueError:
-        is_list = False
+    # Special case when we have just one element in the tree ie not a tree
+    if len(items) == 1:
+        key, value = next(iter(items))
+        if key == "":
+            return value
 
     # collect children
     children = defaultdict(list)
-    for key, value in tree:
+    for key, value in items:
         current_idx, *next_idx = key.split(".", maxsplit=1)
         next_idx = "" if not next_idx else next_idx[0]
         children[current_idx].append((next_idx, value))
 
-    # recursively map them to the original container
-    if is_list:
+    # Assume they are a list and fail to dict if the keys are not all integers
+    try:
         keys = sorted((int(idx), idx) for idx in children.keys())
         l = []
         for i, k in keys:
@@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
             l.extend([{} for _ in range(i - len(l))])
             l.append(tree_unflatten(children[k]))
         return l
-    else:
+    except ValueError:
         return {k: tree_unflatten(v) for k, v in children.items()}
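With both halves in place, the flat dict form round-trips through tree_unflatten just like the list form; a short sketch:

    from mlx.utils import tree_flatten, tree_unflatten

    params = {"layers": [{"w": 0}, {"w": 1}]}
    flat = tree_flatten(params, destination={})
    # {"layers.0.w": 0, "layers.1.w": 1}
    assert tree_unflatten(flat) == params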
@@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
           in case ``mx.distributed.is_available()`` returns False otherwise
           it throws a runtime error. Default: ``False``
         backend (str, optional): Which distributed backend to initialize.
-          Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
+          Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
           available backends are tried and the first one that succeeds
           becomes the global group which will be returned in subsequent
           calls. Default: ``any``
284 python/tests/nccl_test_distributed.py Normal file
@@ -0,0 +1,284 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients


class TestNCCLDistributed(mlx_tests.MLXTestCase):
    @classmethod
    def setUpClass(cls):
        world = mx.distributed.init(strict=True, backend="nccl")
        rank = world.rank()
        mx.set_default_device(mx.Device(mx.gpu, rank % 8))

    def test_all_reduce(self):
        world = mx.distributed.init()
        dtypes = [
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(0)

        for dt, rtol in dtypes:
            for sh in sizes:
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)

                # All sum
                y = mx.distributed.all_sum(x[world.rank()])
                z = x.sum(0)
                maxrelerror = (y - z).abs()
                if rtol > 0:
                    maxrelerror /= z.abs()
                maxrelerror = maxrelerror.max()
                self.assertLessEqual(maxrelerror, rtol)

    def test_average_gradients(self):
        original_all_sum = mx.distributed.all_sum
        n_calls = 0
        xtype = None

        def new_all_sum(x, **kwargs):
            nonlocal n_calls
            nonlocal xtype

            n_calls += 1
            if xtype is not None:
                self.assertEqual(xtype, x.dtype)

            return original_all_sum(x, **kwargs)

        mx.distributed.all_sum = new_all_sum
        try:
            grads = [mx.ones(10) for i in range(10)]
            new_grads = average_gradients(grads, stream=mx.gpu)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 1)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 10)

            n_calls = 0
            xtype = mx.float16
            new_grads = average_gradients(
                grads,
                all_reduce_size=2 * 50,
                communication_type=mx.float16,
                stream=mx.gpu,
            )
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

        finally:
            mx.distributed.all_sum = original_all_sum

    def test_donation(self):
        x = mx.random.normal((1024,))
        mx.eval(x)
        mx.synchronize()

        mx.reset_peak_memory()
        scale = mx.array(2.0)
        y = mx.distributed.all_sum(x)
        mx.eval(y)
        mx.synchronize()
        all_sum_only = mx.get_peak_memory()
        y = mx.distributed.all_sum(x) * scale
        mx.eval(y)
        mx.synchronize()
        all_sum_with_binary = mx.get_peak_memory()

        self.assertEqual(all_sum_only, all_sum_with_binary)

    def test_shard_linear(self):
        # Seed the prng to have the same inputs and weights generated everywhere
        mx.random.seed(0xF0F0F0F0)

        # Prepare inputs
        world = mx.distributed.init()
        part = (
            slice(None),
            slice(
                world.rank() * 1024 // world.size(),
                (world.rank() + 1) * 1024 // world.size(),
            ),
        )
        x = mx.random.normal((4, 1024))

        # Create and shard some linear layers
        lin = nn.Linear(1024, 1024, bias=True)
        slin1 = shard_linear(lin, "all-to-sharded")
        slin2 = shard_linear(lin, "sharded-to-all")
        y = lin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))

        # Check the backward works as expected
        def dummy_loss(model, x, y):
            return (model(x) * y).sum()

        mod = nn.Sequential(
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
        )
        smod = nn.Sequential(
            shard_linear(mod.layers[0], "all-to-sharded"),
            shard_linear(mod.layers[1], "sharded-to-all"),
            shard_linear(mod.layers[2], "all-to-sharded"),
            shard_linear(mod.layers[3], "sharded-to-all"),
        )

        grad1 = nn.value_and_grad(mod, dummy_loss)
        grad2 = nn.value_and_grad(smod, dummy_loss)

        x = mx.random.normal((4, 128))
        y = mx.random.normal((4, 128))

        l1, g1 = grad1(mod, x, y)
        l2, g2 = grad2(smod, x, y)
        mx.eval(l1, g1, l2, g2)

        part = slice(
            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
        )

        self.assertTrue(mx.allclose(l1, l2))
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["weight"][part],
                g2["layers"][0]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["weight"][part],
                g2["layers"][2]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["weight"][:, part],
                g2["layers"][1]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["weight"][:, part],
                g2["layers"][3]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["bias"][part],
                g2["layers"][0]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["bias"][part],
                g2["layers"][2]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
            )
        )

    def test_shard_predicate(self):
        mx.random.seed(0xF0F0F0F0)

        class MyConv(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.aggregate = kwargs.pop("aggregate", False)
                self.conv = nn.Conv2d(*args, **kwargs)

            def __call__(self, x):
                x = self.conv(x)
                if self.aggregate:
                    x = mx.distributed.all_sum(x)
                return x

        def sharding(path, weight):
            parts = path.split(".")
            even = int(parts[1]) % 2 == 0
            if even:
                return 0
            else:
                return -1 if parts[-1] != "bias" else None

        mod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3),
        )
        smod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3, aggregate=True),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3, aggregate=True),
        )
        smod.update(mod.parameters())
        shard_inplace(smod, sharding)

        x = mx.random.normal((4, 16, 16, 3))
        y1 = mod(x)
        y2 = smod(x)
        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
                 self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
 
         model = DictModule()
-        params = dict(tree_flatten(model.parameters()))
+        params = tree_flatten(model.parameters(), destination={})
         self.assertEqual(len(params), 2)
         self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
         self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))