mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
8 Commits
984cefb14d
...
sdpav-back
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a22d0bf273 | ||
|
|
99d8de8445 | ||
|
|
c66b76a8c8 | ||
|
|
f81edd184f | ||
|
|
7f8ba2a003 | ||
|
|
c28249b81a | ||
|
|
e74bcdc5e3 | ||
|
|
d8ed6c1aa3 |
@@ -1,54 +0,0 @@
|
|||||||
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
|
||||||
# directories.
|
|
||||||
|
|
||||||
set(NCCL_ROOT_DIR
|
|
||||||
$ENV{NCCL_ROOT_DIR}
|
|
||||||
CACHE PATH "Folder contains NVIDIA NCCL")
|
|
||||||
|
|
||||||
find_path(
|
|
||||||
NCCL_INCLUDE_DIRS
|
|
||||||
NAMES nccl.h
|
|
||||||
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
|
||||||
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
|
||||||
|
|
||||||
if($ENV{USE_STATIC_NCCL})
|
|
||||||
message(
|
|
||||||
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
|
||||||
set(NCCL_LIBNAME "libnccl_static.a")
|
|
||||||
else()
|
|
||||||
set(NCCL_LIBNAME "nccl")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_library(
|
|
||||||
NCCL_LIBRARIES
|
|
||||||
NAMES ${NCCL_LIBNAME}
|
|
||||||
HINTS ${NCCL_LIB_DIR}
|
|
||||||
${NCCL_ROOT_DIR}
|
|
||||||
${NCCL_ROOT_DIR}/lib
|
|
||||||
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
|
||||||
${NCCL_ROOT_DIR}/lib64
|
|
||||||
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
|
||||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
|
||||||
|
|
||||||
include(FindPackageHandleStandardArgs)
|
|
||||||
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
|
||||||
NCCL_LIBRARIES)
|
|
||||||
|
|
||||||
if(NCCL_FOUND)
|
|
||||||
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
|
||||||
message(
|
|
||||||
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
|
||||||
file(
|
|
||||||
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
|
||||||
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
|
||||||
LIMIT_COUNT 1)
|
|
||||||
if(NCCL_MAJOR_VERSION_DEFINED)
|
|
||||||
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
|
||||||
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
|
||||||
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
|
||||||
endif()
|
|
||||||
message(
|
|
||||||
STATUS
|
|
||||||
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
|
||||||
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
|
||||||
endif()
|
|
||||||
@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
|
|||||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
apt-get update -y
|
apt-get update -y
|
||||||
apt-get -y install cuda-toolkit-12-9
|
apt-get -y install cuda-toolkit-12-9
|
||||||
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
|
||||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
|||||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
# Save the state
|
# Save the state
|
||||||
state = tree_flatten(optimizer.state, destination={})
|
state = tree_flatten(optimizer.state)
|
||||||
mx.save_safetensors("optimizer.safetensors", state)
|
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||||
|
|
||||||
# Later on, for example when loading from a checkpoint,
|
# Later on, for example when loading from a checkpoint,
|
||||||
# recreate the optimizer and load the state
|
# recreate the optimizer and load the state
|
||||||
optimizer = optim.Adam(learning_rate=1e-2)
|
optimizer = optim.Adam(learning_rate=1e-2)
|
||||||
|
|
||||||
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||||
optimizer.state = state
|
optimizer.state = state
|
||||||
|
|
||||||
Note, not every optimizer configuation parameter is saved in the state. For
|
Note, not every optimizer configuation parameter is saved in the state. For
|
||||||
|
|||||||
@@ -7,17 +7,17 @@ Exporting Functions
|
|||||||
|
|
||||||
MLX has an API to export and import functions to and from a file. This lets you
|
MLX has an API to export and import functions to and from a file. This lets you
|
||||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||||
front-end (e.g. C++).
|
front-end (e.g. C++).
|
||||||
|
|
||||||
This guide walks through the basics of the MLX export API with some examples.
|
This guide walks through the basics of the MLX export API with some examples.
|
||||||
To see the full list of functions check-out the :ref:`API documentation
|
To see the full list of functions check-out the :ref:`API documentation
|
||||||
<export>`.
|
<export>`.
|
||||||
|
|
||||||
Basics of Exporting
|
Basics of Exporting
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
Let's start with a simple example:
|
Let's start with a simple example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def fun(x, y):
|
def fun(x, y):
|
||||||
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
|
|||||||
|
|
||||||
x = mx.array(1.0)
|
x = mx.array(1.0)
|
||||||
y = mx.array(1.0)
|
y = mx.array(1.0)
|
||||||
|
|
||||||
# Both arguments to fun are positional
|
# Both arguments to fun are positional
|
||||||
mx.export_function("add.mlxfn", fun, x, y)
|
mx.export_function("add.mlxfn", fun, x, y)
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
|
|||||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||||
they are evaluated. The computation graph that gets exported will include
|
they are evaluated. The computation graph that gets exported will include
|
||||||
the computation that produces enclosed inputs.
|
the computation that produces enclosed inputs.
|
||||||
|
|
||||||
If the above example was missing ``mx.eval(model.parameters()``, the
|
If the above example was missing ``mx.eval(model.parameters()``, the
|
||||||
exported function would include the random initialization of the
|
exported function would include the random initialization of the
|
||||||
:obj:`mlx.nn.Module` parameters.
|
:obj:`mlx.nn.Module` parameters.
|
||||||
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
|||||||
# Set the model's parameters to the input parameters
|
# Set the model's parameters to the input parameters
|
||||||
model.update(tree_unflatten(list(params.items())))
|
model.update(tree_unflatten(list(params.items())))
|
||||||
return model(x)
|
return model(x)
|
||||||
|
|
||||||
params = tree_flatten(model.parameters(), destination={})
|
params = dict(tree_flatten(model.parameters()))
|
||||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||||
|
|
||||||
|
|
||||||
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
|
|||||||
|
|
||||||
# Ok
|
# Ok
|
||||||
out, = imported_abs(mx.array(-1.0))
|
out, = imported_abs(mx.array(-1.0))
|
||||||
|
|
||||||
# Also ok
|
# Also ok
|
||||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||||
|
|
||||||
With ``shapeless=False`` (which is the default), the second call to
|
With ``shapeless=False`` (which is the default), the second call to
|
||||||
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
|||||||
def fun(x, y=None):
|
def fun(x, y=None):
|
||||||
constant = mx.array(3.0)
|
constant = mx.array(3.0)
|
||||||
if y is not None:
|
if y is not None:
|
||||||
x += y
|
x += y
|
||||||
return x + constant
|
return x + constant
|
||||||
|
|
||||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||||
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
|||||||
print(out)
|
print(out)
|
||||||
|
|
||||||
In the above example the function constant data, (i.e. ``constant``), is only
|
In the above example the function constant data, (i.e. ``constant``), is only
|
||||||
saved once.
|
saved once.
|
||||||
|
|
||||||
Transformations with Imported Functions
|
Transformations with Imported Functions
|
||||||
---------------------------------------
|
---------------------------------------
|
||||||
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
|
|||||||
# Prints: array(1, dtype=float32)
|
# Prints: array(1, dtype=float32)
|
||||||
print(dfdx(x))
|
print(dfdx(x))
|
||||||
|
|
||||||
# Compile the imported function
|
# Compile the imported function
|
||||||
mx.compile(imported_fun)
|
mx.compile(imported_fun)
|
||||||
# Prints: array(0, dtype=float32)
|
# Prints: array(0, dtype=float32)
|
||||||
print(compiled_fun(x)[0])
|
print(compiled_fun(x)[0])
|
||||||
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
|
|||||||
// Prints: array(2, dtype=float32)
|
// Prints: array(2, dtype=float32)
|
||||||
std::cout << outputs[0] << std::endl;
|
std::cout << outputs[0] << std::endl;
|
||||||
|
|
||||||
Imported functions can be transformed in C++ just like in Python. Use
|
Imported functions can be transformed in C++ just like in Python. Use
|
||||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
||||||
#include "mlx/distributed/primitives.h"
|
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
namespace distributed {
|
|
||||||
void AllReduce::eval_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
std::vector<array>& outputs) {
|
|
||||||
assert(inputs.size() == 1);
|
|
||||||
assert(outputs.size() == 1);
|
|
||||||
|
|
||||||
auto& input = inputs[0];
|
|
||||||
auto& output = outputs[0];
|
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(stream());
|
|
||||||
|
|
||||||
if (input.is_donatable()) {
|
|
||||||
output.copy_shared_buffer(input);
|
|
||||||
} else {
|
|
||||||
output.set_data(allocator::malloc(output.nbytes()));
|
|
||||||
}
|
|
||||||
|
|
||||||
encoder.set_input_array(input);
|
|
||||||
encoder.set_output_array(output);
|
|
||||||
|
|
||||||
auto capture = encoder.capture_context();
|
|
||||||
auto& s = stream();
|
|
||||||
|
|
||||||
switch (reduce_type_) {
|
|
||||||
case Sum:
|
|
||||||
distributed::detail::all_sum(group(), input, output, s);
|
|
||||||
break;
|
|
||||||
case Max:
|
|
||||||
distributed::detail::all_max(group(), input, output, s);
|
|
||||||
break;
|
|
||||||
case Min:
|
|
||||||
distributed::detail::all_min(group(), input, output, s);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"Only all reduce sum, max, and min are supported.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace distributed
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -46,6 +46,7 @@ NO_GPU_MULTI(CustomKernel)
|
|||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
|
NO_GPU_MULTI(AllReduce)
|
||||||
NO_GPU_MULTI(AllGather)
|
NO_GPU_MULTI(AllGather)
|
||||||
NO_GPU_MULTI(Send)
|
NO_GPU_MULTI(Send)
|
||||||
NO_GPU_MULTI(Recv)
|
NO_GPU_MULTI(Recv)
|
||||||
|
|||||||
@@ -8,13 +8,19 @@
|
|||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/transforms_impl.h"
|
|
||||||
|
|
||||||
|
// cudnn_frontend.h redefines this macro.
|
||||||
|
#undef CHECK_CUDA_ERROR
|
||||||
|
|
||||||
|
#include <cudnn_frontend.h>
|
||||||
|
#include <fmt/format.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
#include <cooperative_groups/reduce.h>
|
#include <cooperative_groups/reduce.h>
|
||||||
|
|
||||||
|
namespace fe = cudnn_frontend;
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
@@ -639,6 +645,294 @@ void sdpa_vector_fallback(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct SDPACacheKey {
|
||||||
|
int device_id;
|
||||||
|
fe::DataType_t cudnn_type;
|
||||||
|
|
||||||
|
int B;
|
||||||
|
int H;
|
||||||
|
int D;
|
||||||
|
|
||||||
|
int qL;
|
||||||
|
int kL;
|
||||||
|
|
||||||
|
int gqa_factor;
|
||||||
|
float scale;
|
||||||
|
|
||||||
|
int64_t Q_strides[3];
|
||||||
|
int64_t K_strides[3];
|
||||||
|
int64_t V_strides[3];
|
||||||
|
int64_t O_strides[3];
|
||||||
|
|
||||||
|
bool generate_stats;
|
||||||
|
bool causal_mask;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto& sdpa_cache() {
|
||||||
|
static LRUBytesKeyCache<SDPACacheKey, std::shared_ptr<fe::graph::Graph>>
|
||||||
|
cache(
|
||||||
|
/* capacity */ 128);
|
||||||
|
return cache;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define Q_UID 1
|
||||||
|
#define K_UID 2
|
||||||
|
#define V_UID 3
|
||||||
|
#define O_UID 4
|
||||||
|
#define STATS_UID 5
|
||||||
|
|
||||||
|
std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const SDPACacheKey& cache_key) {
|
||||||
|
// Check if graph has already been fully built
|
||||||
|
if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up new graph
|
||||||
|
auto graph = std::make_shared<fe::graph::Graph>();
|
||||||
|
|
||||||
|
graph->set_io_data_type(cache_key.cudnn_type)
|
||||||
|
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||||
|
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||||
|
|
||||||
|
auto Q = graph->tensor(
|
||||||
|
fe::graph::Tensor_attributes()
|
||||||
|
.set_name("Q")
|
||||||
|
.set_uid(Q_UID)
|
||||||
|
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
|
||||||
|
.set_stride(
|
||||||
|
{cache_key.Q_strides[0],
|
||||||
|
cache_key.Q_strides[1],
|
||||||
|
cache_key.Q_strides[2],
|
||||||
|
1}));
|
||||||
|
|
||||||
|
int h_kv = cache_key.H / cache_key.gqa_factor;
|
||||||
|
auto K =
|
||||||
|
graph->tensor(fe::graph::Tensor_attributes()
|
||||||
|
.set_name("K")
|
||||||
|
.set_uid(K_UID)
|
||||||
|
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
|
||||||
|
.set_stride(
|
||||||
|
{cache_key.K_strides[0],
|
||||||
|
cache_key.K_strides[1],
|
||||||
|
cache_key.V_strides[2],
|
||||||
|
1}));
|
||||||
|
|
||||||
|
auto V =
|
||||||
|
graph->tensor(fe::graph::Tensor_attributes()
|
||||||
|
.set_name("V")
|
||||||
|
.set_uid(V_UID)
|
||||||
|
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
|
||||||
|
.set_stride(
|
||||||
|
{cache_key.V_strides[0],
|
||||||
|
cache_key.V_strides[1],
|
||||||
|
cache_key.V_strides[2],
|
||||||
|
1}));
|
||||||
|
|
||||||
|
auto sdpa_options = fe::graph::SDPA_attributes()
|
||||||
|
.set_name("flash_attention")
|
||||||
|
.set_is_inference(!cache_key.generate_stats)
|
||||||
|
.set_attn_scale(cache_key.scale);
|
||||||
|
|
||||||
|
if (cache_key.causal_mask && cache_key.qL > 1) {
|
||||||
|
sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
|
||||||
|
.set_diagonal_band_right_bound(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
|
||||||
|
|
||||||
|
O->set_output(true)
|
||||||
|
.set_uid(O_UID)
|
||||||
|
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
|
||||||
|
.set_stride(
|
||||||
|
{cache_key.O_strides[0],
|
||||||
|
cache_key.O_strides[1],
|
||||||
|
cache_key.O_strides[2],
|
||||||
|
1});
|
||||||
|
|
||||||
|
if (cache_key.generate_stats) {
|
||||||
|
Stats->set_output(true)
|
||||||
|
.set_data_type(fe::DataType_t::FLOAT)
|
||||||
|
.set_uid(STATS_UID);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build and Validate cudnn graph
|
||||||
|
|
||||||
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
|
||||||
|
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
||||||
|
if (cudnnGetVersion() < 90600) {
|
||||||
|
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
|
||||||
|
if (!build_status.is_good()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Unable to build cudnn graph for attention."
|
||||||
|
" Failed with message: " +
|
||||||
|
build_status.get_message());
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
auto val_status = graph->validate();
|
||||||
|
auto op_status = graph->build_operation_graph(handle);
|
||||||
|
|
||||||
|
auto plan_stauts =
|
||||||
|
graph->create_execution_plans({cudnn_frontend::HeurMode_t::A});
|
||||||
|
if (!plan_stauts.is_good()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Unable to create exec plan for cudnn attention."
|
||||||
|
" Failed with message: " +
|
||||||
|
plan_stauts.get_message());
|
||||||
|
}
|
||||||
|
|
||||||
|
graph->select_behavior_notes(
|
||||||
|
{cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
|
||||||
|
|
||||||
|
auto support_status = graph->check_support(handle);
|
||||||
|
if (!support_status.is_good()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"No cuda graph support for cudnn attention."
|
||||||
|
" Failed with message: " +
|
||||||
|
support_status.get_message());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto build_status = graph->build_plans(handle);
|
||||||
|
if (!build_status.is_good()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Unable to build cudnn graph for attention."
|
||||||
|
" Failed with message: " +
|
||||||
|
build_status.get_message());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [it, _] = sdpa_cache().emplace(cache_key, graph);
|
||||||
|
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case int8:
|
||||||
|
return fe::DataType_t::INT8;
|
||||||
|
case int32:
|
||||||
|
return fe::DataType_t::INT32;
|
||||||
|
case uint8:
|
||||||
|
return fe::DataType_t::UINT8;
|
||||||
|
case float16:
|
||||||
|
return fe::DataType_t::HALF;
|
||||||
|
case bfloat16:
|
||||||
|
return fe::DataType_t::BFLOAT16;
|
||||||
|
case float32:
|
||||||
|
return fe::DataType_t::FLOAT;
|
||||||
|
case float64:
|
||||||
|
return fe::DataType_t::DOUBLE;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Unsupported dtype in SDPA: {}.", dtype_to_string(dtype)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void sdpa_cudnn(
|
||||||
|
const Stream& s,
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& v,
|
||||||
|
const float scale,
|
||||||
|
array& o,
|
||||||
|
bool do_causal_ = false) {
|
||||||
|
encoder.set_input_array(q);
|
||||||
|
encoder.set_input_array(k);
|
||||||
|
encoder.set_input_array(v);
|
||||||
|
encoder.set_output_array(o);
|
||||||
|
|
||||||
|
auto cudnn_type = dtype_to_cudnn_type(q.dtype());
|
||||||
|
|
||||||
|
int B = q.shape(0);
|
||||||
|
int H = q.shape(1);
|
||||||
|
int D = q.shape(3);
|
||||||
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
|
|
||||||
|
int qL = q.shape(2);
|
||||||
|
int kL = k.shape(2);
|
||||||
|
|
||||||
|
SDPACacheKey cache_key{
|
||||||
|
/* int device_id = */ encoder.device().cuda_device(),
|
||||||
|
/* fe::DataType_t cudnn_type = */ cudnn_type,
|
||||||
|
|
||||||
|
/* int B = */ B,
|
||||||
|
/* int H = */ H,
|
||||||
|
/* int D = */ D,
|
||||||
|
|
||||||
|
/* int qL = */ qL,
|
||||||
|
/* int kL = */ kL,
|
||||||
|
|
||||||
|
/* int gqa_factor = */ gqa_factor,
|
||||||
|
/* float scale = */ scale,
|
||||||
|
|
||||||
|
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||||
|
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||||
|
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||||
|
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)},
|
||||||
|
|
||||||
|
/* bool generate_stats = */ false,
|
||||||
|
/* bool causal_mask = */ do_causal_};
|
||||||
|
|
||||||
|
auto graph = get_sdpa_forward_graph(encoder, cache_key);
|
||||||
|
|
||||||
|
int64_t workspace_size = 0;
|
||||||
|
auto workspace_status = graph->get_workspace_size(workspace_size);
|
||||||
|
if (!workspace_status.is_good()) {
|
||||||
|
throw std::runtime_error("Unable to get workspace for cudnn attention.");
|
||||||
|
}
|
||||||
|
|
||||||
|
array workspace(
|
||||||
|
allocator::malloc(workspace_size), {int(workspace_size)}, uint8);
|
||||||
|
auto workspace_ptr = workspace.data<void>();
|
||||||
|
|
||||||
|
std::unordered_map<int64_t, void*> variant_pack = {
|
||||||
|
{Q_UID, const_cast<void*>(q.data<void>())},
|
||||||
|
{K_UID, const_cast<void*>(k.data<void>())},
|
||||||
|
{V_UID, const_cast<void*>(v.data<void>())},
|
||||||
|
{O_UID, o.data<void>()}};
|
||||||
|
|
||||||
|
auto handle = encoder.device().cudnn_handle();
|
||||||
|
cudnnSetStream(handle, encoder.stream());
|
||||||
|
|
||||||
|
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
||||||
|
if (cudnnGetVersion() < 90600) {
|
||||||
|
auto capture = encoder.capture_context();
|
||||||
|
auto exec_status = graph->execute(handle, variant_pack, workspace_ptr);
|
||||||
|
|
||||||
|
if (!exec_status.is_good()) {
|
||||||
|
capture.discard = true;
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Unable to execute cudnn attention."
|
||||||
|
" Failed with message: " +
|
||||||
|
exec_status.get_message());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cudaGraph_t cu_graph;
|
||||||
|
cudaGraphCreate(&cu_graph, 0);
|
||||||
|
|
||||||
|
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||||
|
&cu_graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||||
|
|
||||||
|
auto cu_graph_status = graph->populate_cuda_graph(
|
||||||
|
handle, variant_pack, workspace_ptr, cu_graph);
|
||||||
|
|
||||||
|
if (!cu_graph_status.is_good()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Unable to add cuda graph for cudnn attention."
|
||||||
|
" Failed with message: " +
|
||||||
|
cu_graph_status.get_message());
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.add_graph_node(cu_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace fast {
|
namespace fast {
|
||||||
@@ -651,9 +945,6 @@ bool ScaledDotProductAttention::use_fallback(
|
|||||||
bool has_arr_mask,
|
bool has_arr_mask,
|
||||||
bool do_causal,
|
bool do_causal,
|
||||||
Stream s) {
|
Stream s) {
|
||||||
if (detail::in_grad_tracing()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (s.device == Device::cpu) {
|
if (s.device == Device::cpu) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -669,7 +960,15 @@ bool ScaledDotProductAttention::use_fallback(
|
|||||||
const bool supported_vector_config =
|
const bool supported_vector_config =
|
||||||
sdpa_supported_head_dim && query_sequence_length < 4;
|
sdpa_supported_head_dim && query_sequence_length < 4;
|
||||||
|
|
||||||
const bool supported_config = supported_vector_config;
|
auto& cu_device = cu::device(s.device);
|
||||||
|
|
||||||
|
const bool supported_matrix_config = query_sequence_length > 4 &&
|
||||||
|
cu_device.compute_capability_major() >= 8 &&
|
||||||
|
query_sequence_length == key_sequence_length &&
|
||||||
|
(q.dtype() == float16 || q.dtype() == bfloat16);
|
||||||
|
|
||||||
|
const bool supported_config =
|
||||||
|
(supported_matrix_config || supported_vector_config);
|
||||||
|
|
||||||
return has_arr_mask || !supported_config;
|
return has_arr_mask || !supported_config;
|
||||||
}
|
}
|
||||||
@@ -703,6 +1002,10 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto is_matrix_contiguous = [](const array& arr) {
|
||||||
|
return arr.strides(-1) == 1;
|
||||||
|
};
|
||||||
|
|
||||||
// We are in vector mode ie single query
|
// We are in vector mode ie single query
|
||||||
if (q_pre.shape(2) < 4) {
|
if (q_pre.shape(2) < 4) {
|
||||||
auto q_copy_unless = [](const array& arr) {
|
auto q_copy_unless = [](const array& arr) {
|
||||||
@@ -756,7 +1059,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
|
|
||||||
array::Flags flags{
|
array::Flags flags{
|
||||||
/* bool contiguous = */ 1,
|
/* bool contiguous = */ 1,
|
||||||
/* bool row_contiguous = */ o.shape(2) == 1,
|
/* bool row_contiguous = */ 0,
|
||||||
/* bool col_contiguous = */ 0,
|
/* bool col_contiguous = */ 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -770,9 +1073,35 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Full attention mode should never reach here
|
// Full attention mode
|
||||||
else {
|
else {
|
||||||
throw std::runtime_error("Doesn't support matrix yet.");
|
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
|
||||||
|
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||||
|
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||||
|
|
||||||
|
for (const auto& cp : copies) {
|
||||||
|
encoder.add_temporary(cp);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t str_oD = 1;
|
||||||
|
int64_t str_oH = o.shape(3);
|
||||||
|
int64_t str_oL = o.shape(1) * str_oH;
|
||||||
|
int64_t str_oB = o.shape(2) * str_oL;
|
||||||
|
size_t data_size = o.shape(0) * str_oB;
|
||||||
|
|
||||||
|
array::Flags flags{
|
||||||
|
/* bool contiguous = */ 1,
|
||||||
|
/* bool row_contiguous = */ 0,
|
||||||
|
/* bool col_contiguous = */ 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
o.set_data(
|
||||||
|
allocator::malloc(o.nbytes()),
|
||||||
|
data_size,
|
||||||
|
{str_oB, str_oH, str_oL, str_oD},
|
||||||
|
flags);
|
||||||
|
|
||||||
|
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ struct CommandEncoder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Outputs of all kernels in the encoder including temporaries
|
// Outputs of all kernels in the encoder including temporaries
|
||||||
std::unordered_set<const void*>& outputs() {
|
std::unordered_set<const void*> outputs() {
|
||||||
return all_outputs_;
|
return all_outputs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -6,4 +6,3 @@ target_sources(
|
|||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
|
|
||||||
|
|||||||
@@ -2,11 +2,9 @@
|
|||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/mpi/mpi.h"
|
#include "mlx/distributed/mpi/mpi.h"
|
||||||
#include "mlx/distributed/nccl/nccl.h"
|
|
||||||
#include "mlx/distributed/ring/ring.h"
|
#include "mlx/distributed/ring/ring.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
@@ -82,7 +80,7 @@ class EmptyGroup : public GroupImpl {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi::is_available() || ring::is_available() || nccl::is_available();
|
return mpi::is_available() || ring::is_available();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Group::rank() const {
|
int Group::rank() const {
|
||||||
@@ -113,8 +111,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
|
|||||||
group = mpi::init(strict);
|
group = mpi::init(strict);
|
||||||
} else if (bk == "ring") {
|
} else if (bk == "ring") {
|
||||||
group = ring::init(strict);
|
group = ring::init(strict);
|
||||||
} else if (bk == "nccl") {
|
|
||||||
group = nccl::init(strict);
|
|
||||||
} else if (bk == "any") {
|
} else if (bk == "any") {
|
||||||
group = ring::init(false);
|
group = ring::init(false);
|
||||||
bk_ = "ring";
|
bk_ = "ring";
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
if(MLX_BUILD_CUDA)
|
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
|
|
||||||
find_package(NCCL REQUIRED)
|
|
||||||
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
|
|
||||||
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
|
|
||||||
else()
|
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
|
|
||||||
endif()
|
|
||||||
@@ -1,359 +0,0 @@
|
|||||||
#include <arpa/inet.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <nccl.h>
|
|
||||||
#include <netdb.h>
|
|
||||||
#include <sys/socket.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <cstring>
|
|
||||||
#include <iostream>
|
|
||||||
#include <mutex>
|
|
||||||
#include <stdexcept>
|
|
||||||
#include <string>
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/distributed/distributed.h"
|
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
|
||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
|
||||||
|
|
||||||
#define CHECK_CUDA(cmd) \
|
|
||||||
do { \
|
|
||||||
cudaError_t e = cmd; \
|
|
||||||
if (e != cudaSuccess) { \
|
|
||||||
fprintf( \
|
|
||||||
stderr, \
|
|
||||||
"CUDA error %s:%d '%s'\n", \
|
|
||||||
__FILE__, \
|
|
||||||
__LINE__, \
|
|
||||||
cudaGetErrorString(e)); \
|
|
||||||
exit(1); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define CHECK_NCCL(cmd) \
|
|
||||||
do { \
|
|
||||||
ncclResult_t r = cmd; \
|
|
||||||
if (r != ncclSuccess) { \
|
|
||||||
fprintf( \
|
|
||||||
stderr, \
|
|
||||||
"NCCL error %s:%d '%s'\n", \
|
|
||||||
__FILE__, \
|
|
||||||
__LINE__, \
|
|
||||||
ncclGetErrorString(r)); \
|
|
||||||
exit(1); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
inline void sendAll(int sock, const void* buf, size_t len) {
|
|
||||||
const char* ptr = reinterpret_cast<const char*>(buf);
|
|
||||||
while (len > 0) {
|
|
||||||
ssize_t sent = send(sock, ptr, len, 0);
|
|
||||||
if (sent <= 0) {
|
|
||||||
perror("send");
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
ptr += sent;
|
|
||||||
len -= sent;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void recvAll(int sock, void* buf, size_t len) {
|
|
||||||
char* ptr = reinterpret_cast<char*>(buf);
|
|
||||||
while (len > 0) {
|
|
||||||
ssize_t rec = recv(sock, ptr, len, 0);
|
|
||||||
if (rec <= 0) {
|
|
||||||
perror("recv");
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
ptr += rec;
|
|
||||||
len -= rec;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void bootstrap_unique_id(
|
|
||||||
ncclUniqueId& id,
|
|
||||||
int rank,
|
|
||||||
int size,
|
|
||||||
const std::string& initMethod) {
|
|
||||||
// Parse the init method to extract the host and port
|
|
||||||
if (initMethod.rfind("tcp://", 0) != 0)
|
|
||||||
throw;
|
|
||||||
auto hostport = initMethod.substr(6);
|
|
||||||
auto colon = hostport.find(':');
|
|
||||||
std::string host = hostport.substr(0, colon);
|
|
||||||
int port = std::stoi(hostport.substr(colon + 1));
|
|
||||||
|
|
||||||
if (rank == 0) {
|
|
||||||
// create a unique id on the rank 0
|
|
||||||
CHECK_NCCL(ncclGetUniqueId(&id));
|
|
||||||
|
|
||||||
// create a socket to send the unique id to all other ranks
|
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
|
||||||
|
|
||||||
if (sock < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
sockaddr_in serv = {};
|
|
||||||
serv.sin_family = AF_INET;
|
|
||||||
serv.sin_addr.s_addr = htonl(INADDR_ANY);
|
|
||||||
serv.sin_port = htons(port);
|
|
||||||
|
|
||||||
int reuse = 1;
|
|
||||||
// Without this, if rank-0 crashes or restarts process quickly,
|
|
||||||
// the OS might refuse to let binding to the same port, so reuse
|
|
||||||
|
|
||||||
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[nccl] setsockopt() failed: " << strerror(errno);
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[nccl] bind() failed: " << strerror(errno);
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
if (listen(sock, size - 1) < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[nccl] listen() failed: " << strerror(errno);
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int peer = 1; peer < size; ++peer) {
|
|
||||||
int conn = accept(sock, nullptr, nullptr);
|
|
||||||
if (conn < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[nccl] accept() failed: " << strerror(errno);
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
sendAll(conn, &id, sizeof(id));
|
|
||||||
close(conn);
|
|
||||||
}
|
|
||||||
close(sock);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// Here just wanted to make show that rank 0 has enough time to bind
|
|
||||||
// so we will retry to connect until max attempts
|
|
||||||
|
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
|
||||||
if (sock < 0) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[nccl] socket() failed: " << strerror(errno);
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
hostent* he = gethostbyname(host.c_str());
|
|
||||||
if (!he) {
|
|
||||||
throw std::runtime_error("[nccl] lookup failed for host: " + host);
|
|
||||||
}
|
|
||||||
sockaddr_in serv = {};
|
|
||||||
serv.sin_family = AF_INET;
|
|
||||||
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
|
||||||
serv.sin_port = htons(port);
|
|
||||||
|
|
||||||
const int max_retries = 30;
|
|
||||||
int attempt = 0;
|
|
||||||
bool connected = false;
|
|
||||||
|
|
||||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
|
||||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
|
||||||
0) {
|
|
||||||
connected = true;
|
|
||||||
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
|
|
||||||
<< attempt + 1 << std::endl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (errno != ECONNREFUSED) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!connected) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
|
||||||
<< " retries: " << strerror(errno);
|
|
||||||
close(sock);
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
recvAll(sock, &id, sizeof(id));
|
|
||||||
close(sock);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct type_identity {
|
|
||||||
using type = T;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename F>
|
|
||||||
void dispatch_dtype(const array& arr, F&& f) {
|
|
||||||
switch (arr.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
|
||||||
case int8:
|
|
||||||
f(type_identity<int8_t>{}, ncclChar);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
f(type_identity<uint8_t>{}, ncclUint8);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
f(type_identity<int32_t>{}, ncclInt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
f(type_identity<uint32_t>{}, ncclUint32);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
f(type_identity<int64_t>{}, ncclInt64);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
f(type_identity<uint64_t>{}, ncclUint64);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
f(type_identity<float16_t>{}, ncclHalf);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
f(type_identity<float>{}, ncclFloat);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
f(type_identity<double>{}, ncclDouble);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
|
||||||
class NCCLGroup : public GroupImpl {
|
|
||||||
public:
|
|
||||||
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
|
|
||||||
: rank_(worldRank),
|
|
||||||
size_(worldSize),
|
|
||||||
comm_(nullptr),
|
|
||||||
initMethod_(initMethod) {
|
|
||||||
if (initialized_)
|
|
||||||
return;
|
|
||||||
int ndev;
|
|
||||||
CHECK_CUDA(cudaGetDeviceCount(&ndev));
|
|
||||||
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
|
|
||||||
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
|
|
||||||
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
|
|
||||||
initialized_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
~NCCLGroup() {
|
|
||||||
ncclCommDestroy(comm_);
|
|
||||||
ncclGroupEnd();
|
|
||||||
initialized_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
int rank() override {
|
|
||||||
return rank_;
|
|
||||||
}
|
|
||||||
|
|
||||||
int size() override {
|
|
||||||
return size_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void all_sum(const array& input, array& output, Stream stream) override {
|
|
||||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
|
||||||
using T = typename decltype(type_tag)::type;
|
|
||||||
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
|
||||||
throw std::runtime_error("[nccl] Group split not supported.");
|
|
||||||
}
|
|
||||||
|
|
||||||
void all_gather(const array& input, array& output, Stream stream) override {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[nccl] All gather not supported in NCCL backend.");
|
|
||||||
}
|
|
||||||
|
|
||||||
void send(const array& input, int dst, Stream stream) override {
|
|
||||||
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
|
|
||||||
}
|
|
||||||
|
|
||||||
void recv(array& output, int src, Stream stream) override {
|
|
||||||
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
|
|
||||||
}
|
|
||||||
|
|
||||||
void all_max(const array& input, array& output, Stream stream) override {
|
|
||||||
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
|
|
||||||
}
|
|
||||||
|
|
||||||
void all_min(const array& input, array& output, Stream stream) override {
|
|
||||||
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void all_reduce_impl(
|
|
||||||
const array& input,
|
|
||||||
array& output,
|
|
||||||
Stream stream,
|
|
||||||
ncclDataType_t dt,
|
|
||||||
ncclRedOp_t op) {
|
|
||||||
auto& encoder = cu::get_command_encoder(stream);
|
|
||||||
|
|
||||||
CHECK_NCCL(ncclAllReduce(
|
|
||||||
input.data<T>(),
|
|
||||||
output.data<T>(),
|
|
||||||
input.size(),
|
|
||||||
dt,
|
|
||||||
op,
|
|
||||||
comm_,
|
|
||||||
encoder.stream()));
|
|
||||||
}
|
|
||||||
|
|
||||||
int rank_, size_;
|
|
||||||
std::string initMethod_;
|
|
||||||
ncclUniqueId uniqueId_;
|
|
||||||
ncclComm_t comm_;
|
|
||||||
bool initialized_ = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
bool is_available() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace detail {
|
|
||||||
static std::string get_env_var_or_throw(const char* env_var_name) {
|
|
||||||
const char* value = std::getenv(env_var_name);
|
|
||||||
if (value == nullptr) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[nccl] Required environment variable '" << env_var_name
|
|
||||||
<< "' is not set. "
|
|
||||||
<< "Please set it before initializing the distributed backend.";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
return std::string(value);
|
|
||||||
}
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
|
||||||
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
|
|
||||||
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
|
|
||||||
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
|
|
||||||
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
|
|
||||||
|
|
||||||
int rank = std::stoi(rank_str);
|
|
||||||
int n_nodes = std::stoi(n_nodes_str);
|
|
||||||
std::string init_method = "tcp://" + host + ":" + port;
|
|
||||||
|
|
||||||
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
|
|
||||||
}
|
|
||||||
} // namespace mlx::core::distributed::nccl
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/distributed/distributed.h"
|
|
||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
|
||||||
|
|
||||||
bool is_available();
|
|
||||||
std::shared_ptr<GroupImpl> init(bool strict = false);
|
|
||||||
|
|
||||||
} // namespace mlx::core::distributed::nccl
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/distributed/nccl/nccl.h"
|
|
||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
|
||||||
|
|
||||||
bool is_available() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
|
||||||
if (strict) {
|
|
||||||
throw std::runtime_error("Cannot initialize nccl distributed backend.");
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::distributed::nccl
|
|
||||||
@@ -31,7 +31,8 @@ array all_sum(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
std::make_shared<AllReduce>(
|
||||||
|
to_stream(s, Device::cpu), group, AllReduce::Sum),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -975,6 +975,7 @@ class RingGroup : public GroupImpl {
|
|||||||
|
|
||||||
int rank_;
|
int rank_;
|
||||||
int size_;
|
int size_;
|
||||||
|
|
||||||
bool verbose_;
|
bool verbose_;
|
||||||
|
|
||||||
ThreadPool pool_;
|
ThreadPool pool_;
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 28
|
#define MLX_VERSION_MINOR 27
|
||||||
#define MLX_VERSION_PATCH 0
|
#define MLX_VERSION_PATCH 1
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
|||||||
@@ -415,45 +415,6 @@ def launch_mpi(parser, hosts, args, command):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def launch_nccl(parser, hosts, args, command):
|
|
||||||
master_host = hosts[0].ips[0]
|
|
||||||
master_port = args.nccl_port
|
|
||||||
world_size = args.nproc_per_node * len(hosts)
|
|
||||||
|
|
||||||
base_env = os.environ.copy()
|
|
||||||
base_env.update(
|
|
||||||
{
|
|
||||||
"NCCL_DEBUG": "INFO",
|
|
||||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
|
||||||
"NCCL_HOST_IP": master_host,
|
|
||||||
"NCCL_PORT": str(master_port),
|
|
||||||
"MLX_WORLD_SIZE": str(world_size),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
procs = []
|
|
||||||
try:
|
|
||||||
for rank in range(world_size):
|
|
||||||
env = base_env.copy()
|
|
||||||
env["MLX_RANK"] = str(rank)
|
|
||||||
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
|
||||||
p = Popen(command, env=env)
|
|
||||||
procs.append(p)
|
|
||||||
|
|
||||||
for p in procs:
|
|
||||||
ret = p.wait()
|
|
||||||
if ret != 0:
|
|
||||||
raise RuntimeError(f"Rank process exited with {ret}")
|
|
||||||
|
|
||||||
except (RuntimeError, KeyboardInterrupt) as err:
|
|
||||||
for p in procs:
|
|
||||||
if p.poll() is None:
|
|
||||||
try:
|
|
||||||
p.kill()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def check_ssh_connections(hosts):
|
def check_ssh_connections(hosts):
|
||||||
results = [False] * len(hosts)
|
results = [False] * len(hosts)
|
||||||
|
|
||||||
@@ -704,7 +665,7 @@ def distributed_config():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi"],
|
||||||
default="ring",
|
default="ring",
|
||||||
help="Which distributed backend to configure",
|
help="Which distributed backend to configure",
|
||||||
)
|
)
|
||||||
@@ -776,7 +737,7 @@ def main():
|
|||||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi"],
|
||||||
default="ring",
|
default="ring",
|
||||||
help="Which distributed backend to launch",
|
help="Which distributed backend to launch",
|
||||||
)
|
)
|
||||||
@@ -808,19 +769,6 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cwd", help="Set the working directory on each node to the provided one"
|
"--cwd", help="Set the working directory on each node to the provided one"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--nccl-port",
|
|
||||||
type=int,
|
|
||||||
default=12345,
|
|
||||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--nproc-per-node",
|
|
||||||
type=positive_number,
|
|
||||||
default=1,
|
|
||||||
help="How many processes to run per node (only for nccl backend)",
|
|
||||||
)
|
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
if rest[0] == "--":
|
if rest[0] == "--":
|
||||||
rest.pop(0)
|
rest.pop(0)
|
||||||
@@ -851,10 +799,8 @@ def main():
|
|||||||
# Launch
|
# Launch
|
||||||
if args.backend == "ring":
|
if args.backend == "ring":
|
||||||
launch_ring(parser, hosts, args, rest)
|
launch_ring(parser, hosts, args, rest)
|
||||||
if args.backend == "mpi":
|
elif args.backend == "mpi":
|
||||||
launch_mpi(parser, hosts, args, rest)
|
launch_mpi(parser, hosts, args, rest)
|
||||||
if args.backend == "nccl":
|
|
||||||
launch_nccl(parser, hosts, args, rest)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class Module(dict):
|
|||||||
|
|
||||||
if strict:
|
if strict:
|
||||||
new_weights = dict(weights)
|
new_weights = dict(weights)
|
||||||
curr_weights = tree_flatten(self.parameters(), destination={})
|
curr_weights = dict(tree_flatten(self.parameters()))
|
||||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||||
num_extra = len(extras)
|
num_extra = len(extras)
|
||||||
extras = ",\n".join(sorted(extras))
|
extras = ",\n".join(sorted(extras))
|
||||||
@@ -212,7 +212,7 @@ class Module(dict):
|
|||||||
- ``.npz`` will use :func:`mx.savez`
|
- ``.npz`` will use :func:`mx.savez`
|
||||||
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
- ``.safetensors`` will use :func:`mx.save_safetensors`
|
||||||
"""
|
"""
|
||||||
params_dict = tree_flatten(self.parameters(), destination={})
|
params_dict = dict(tree_flatten(self.parameters()))
|
||||||
|
|
||||||
if file.endswith(".npz"):
|
if file.endswith(".npz"):
|
||||||
mx.savez(file, **params_dict)
|
mx.savez(file, **params_dict)
|
||||||
|
|||||||
@@ -76,7 +76,6 @@ def average_gradients(
|
|||||||
group: Optional[mx.distributed.Group] = None,
|
group: Optional[mx.distributed.Group] = None,
|
||||||
all_reduce_size: int = 32 * 1024**2,
|
all_reduce_size: int = 32 * 1024**2,
|
||||||
communication_type: Optional[mx.Dtype] = None,
|
communication_type: Optional[mx.Dtype] = None,
|
||||||
stream: mx.Stream = mx.cpu,
|
|
||||||
):
|
):
|
||||||
"""Average the gradients across the distributed processes in the passed group.
|
"""Average the gradients across the distributed processes in the passed group.
|
||||||
|
|
||||||
@@ -95,7 +94,6 @@ def average_gradients(
|
|||||||
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
|
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
|
||||||
type before performing the communication. Typically cast to a
|
type before performing the communication. Typically cast to a
|
||||||
smaller float to reduce the communication size. Default: ``None``.
|
smaller float to reduce the communication size. Default: ``None``.
|
||||||
stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
|
|
||||||
"""
|
"""
|
||||||
group = group or mx.distributed.init()
|
group = group or mx.distributed.init()
|
||||||
N = group.size()
|
N = group.size()
|
||||||
@@ -106,7 +104,7 @@ def average_gradients(
|
|||||||
def _average(x):
|
def _average(x):
|
||||||
dt = x.dtype
|
dt = x.dtype
|
||||||
x = x.astype(communication_type) if communication_type is not None else x
|
x = x.astype(communication_type) if communication_type is not None else x
|
||||||
return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
|
return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
|
||||||
|
|
||||||
if all_reduce_size <= 0:
|
if all_reduce_size <= 0:
|
||||||
return tree_map(_average, gradients)
|
return tree_map(_average, gradients)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
def tree_map(
|
def tree_map(
|
||||||
@@ -114,11 +114,8 @@ def tree_map_with_path(
|
|||||||
|
|
||||||
|
|
||||||
def tree_flatten(
|
def tree_flatten(
|
||||||
tree: Any,
|
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
|
||||||
prefix: str = "",
|
) -> Any:
|
||||||
is_leaf: Optional[Callable] = None,
|
|
||||||
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
|
|
||||||
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
|
|
||||||
"""Flattens a Python tree to a list of key, value tuples.
|
"""Flattens a Python tree to a list of key, value tuples.
|
||||||
|
|
||||||
The keys are using the dot notation to define trees of arbitrary depth and
|
The keys are using the dot notation to define trees of arbitrary depth and
|
||||||
@@ -131,12 +128,9 @@ def tree_flatten(
|
|||||||
print(tree_flatten([[[0]]]))
|
print(tree_flatten([[[0]]]))
|
||||||
# [("0.0.0", 0)]
|
# [("0.0.0", 0)]
|
||||||
|
|
||||||
print(tree_flatten([[[0]]], prefix=".hello"))
|
print(tree_flatten([[[0]]], ".hello"))
|
||||||
# [("hello.0.0.0", 0)]
|
# [("hello.0.0.0", 0)]
|
||||||
|
|
||||||
tree_flatten({"a": {"b": 1}}, destination={})
|
|
||||||
{"a.b": 1}
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Dictionaries should have keys that are valid Python identifiers.
|
Dictionaries should have keys that are valid Python identifiers.
|
||||||
|
|
||||||
@@ -146,50 +140,26 @@ def tree_flatten(
|
|||||||
always discarded.
|
always discarded.
|
||||||
is_leaf (callable): An optional callable that returns True if the
|
is_leaf (callable): An optional callable that returns True if the
|
||||||
passed object is considered a leaf or False otherwise.
|
passed object is considered a leaf or False otherwise.
|
||||||
destination (list or dict, optional): A list or dictionary to store the
|
|
||||||
flattened tree. If None an empty list will be used. Default: ``None``.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
|
List[Tuple[str, Any]]: The flat representation of the Python tree.
|
||||||
the Python tree.
|
|
||||||
"""
|
"""
|
||||||
if destination is None:
|
flat_tree = []
|
||||||
destination = []
|
|
||||||
|
|
||||||
# Create the function to update the destination. We are taking advantage of
|
if is_leaf is None or not is_leaf(tree):
|
||||||
# the fact that list.extend and dict.update have the same API to simplify
|
if isinstance(tree, (list, tuple)):
|
||||||
# the code a bit.
|
for i, t in enumerate(tree):
|
||||||
if isinstance(destination, list):
|
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
|
||||||
_add_to_destination = destination.extend
|
return flat_tree
|
||||||
elif isinstance(destination, dict):
|
if isinstance(tree, dict):
|
||||||
_add_to_destination = destination.update
|
for k, t in tree.items():
|
||||||
else:
|
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
|
||||||
raise ValueError("Destination should be either a list or a dictionary or None")
|
return flat_tree
|
||||||
|
|
||||||
# Leaf identified by is_leaf so add it and return
|
return [(prefix[1:], tree)]
|
||||||
if is_leaf is not None and is_leaf(tree):
|
|
||||||
_add_to_destination([(prefix[1:], tree)])
|
|
||||||
return destination
|
|
||||||
|
|
||||||
# List or tuple so recursively add each subtree
|
|
||||||
if isinstance(tree, (list, tuple)):
|
|
||||||
for i, item in enumerate(tree):
|
|
||||||
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
|
|
||||||
return destination
|
|
||||||
|
|
||||||
# Dictionary so recursively add each subtree
|
|
||||||
if isinstance(tree, dict):
|
|
||||||
for key, value in tree.items():
|
|
||||||
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
|
|
||||||
return destination
|
|
||||||
|
|
||||||
# Leaf so add it and return
|
|
||||||
_add_to_destination([(prefix[1:], tree)])
|
|
||||||
|
|
||||||
return destination
|
|
||||||
|
|
||||||
|
|
||||||
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
|
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
|
||||||
"""Recreate a Python tree from its flat representation.
|
"""Recreate a Python tree from its flat representation.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@@ -200,34 +170,31 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
|
|||||||
print(d)
|
print(d)
|
||||||
# {"hello": {"world": 42}}
|
# {"hello": {"world": 42}}
|
||||||
|
|
||||||
d = tree_unflatten({"hello.world": 42})
|
|
||||||
print(d)
|
|
||||||
# {"hello": {"world": 42}}
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
|
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
|
||||||
For instance as returned by :meth:`tree_flatten`.
|
For instance as returned by :meth:`tree_flatten`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Python tree.
|
A Python tree.
|
||||||
"""
|
"""
|
||||||
items = tree.items() if isinstance(tree, dict) else tree
|
if len(tree) == 1 and tree[0][0] == "":
|
||||||
|
return tree[0][1]
|
||||||
|
|
||||||
# Special case when we have just one element in the tree ie not a tree
|
try:
|
||||||
if len(items) == 1:
|
int(tree[0][0].split(".", maxsplit=1)[0])
|
||||||
key, value = next(iter(items))
|
is_list = True
|
||||||
if key == "":
|
except ValueError:
|
||||||
return value
|
is_list = False
|
||||||
|
|
||||||
# collect children
|
# collect children
|
||||||
children = defaultdict(list)
|
children = defaultdict(list)
|
||||||
for key, value in items:
|
for key, value in tree:
|
||||||
current_idx, *next_idx = key.split(".", maxsplit=1)
|
current_idx, *next_idx = key.split(".", maxsplit=1)
|
||||||
next_idx = "" if not next_idx else next_idx[0]
|
next_idx = "" if not next_idx else next_idx[0]
|
||||||
children[current_idx].append((next_idx, value))
|
children[current_idx].append((next_idx, value))
|
||||||
|
|
||||||
# Assume they are a list and fail to dict if the keys are not all integers
|
# recursively map them to the original container
|
||||||
try:
|
if is_list:
|
||||||
keys = sorted((int(idx), idx) for idx in children.keys())
|
keys = sorted((int(idx), idx) for idx in children.keys())
|
||||||
l = []
|
l = []
|
||||||
for i, k in keys:
|
for i, k in keys:
|
||||||
@@ -235,7 +202,7 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
|
|||||||
l.extend([{} for _ in range(i - len(l))])
|
l.extend([{} for _ in range(i - len(l))])
|
||||||
l.append(tree_unflatten(children[k]))
|
l.append(tree_unflatten(children[k]))
|
||||||
return l
|
return l
|
||||||
except ValueError:
|
else:
|
||||||
return {k: tree_unflatten(v) for k, v in children.items()}
|
return {k: tree_unflatten(v) for k, v in children.items()}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
in case ``mx.distributed.is_available()`` returns False otherwise
|
in case ``mx.distributed.is_available()`` returns False otherwise
|
||||||
it throws a runtime error. Default: ``False``
|
it throws a runtime error. Default: ``False``
|
||||||
backend (str, optional): Which distributed backend to initialize.
|
backend (str, optional): Which distributed backend to initialize.
|
||||||
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
|
Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
|
||||||
available backends are tried and the first one that succeeds
|
available backends are tried and the first one that succeeds
|
||||||
becomes the global group which will be returned in subsequent
|
becomes the global group which will be returned in subsequent
|
||||||
calls. Default: ``any``
|
calls. Default: ``any``
|
||||||
|
|||||||
@@ -1,284 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
import mlx_tests
|
|
||||||
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
|
||||||
from mlx.nn.utils import average_gradients
|
|
||||||
|
|
||||||
|
|
||||||
class TestNCCLDistributed(mlx_tests.MLXTestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
world = mx.distributed.init(strict=True, backend="nccl")
|
|
||||||
rank = world.rank()
|
|
||||||
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
|
|
||||||
|
|
||||||
def test_all_reduce(self):
|
|
||||||
world = mx.distributed.init()
|
|
||||||
dtypes = [
|
|
||||||
(mx.int8, 0),
|
|
||||||
(mx.uint8, 0),
|
|
||||||
(mx.int32, 0),
|
|
||||||
(mx.uint32, 0),
|
|
||||||
(mx.float32, 1e-6),
|
|
||||||
(mx.float16, 5e-3),
|
|
||||||
(mx.bfloat16, 1e-1),
|
|
||||||
]
|
|
||||||
sizes = [
|
|
||||||
(7,),
|
|
||||||
(10,),
|
|
||||||
(1024,),
|
|
||||||
(1024, 1024),
|
|
||||||
]
|
|
||||||
key = mx.random.key(0)
|
|
||||||
|
|
||||||
for dt, rtol in dtypes:
|
|
||||||
for sh in sizes:
|
|
||||||
x = (
|
|
||||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
|
||||||
).astype(dt)
|
|
||||||
|
|
||||||
# All sum
|
|
||||||
y = mx.distributed.all_sum(x[world.rank()])
|
|
||||||
z = x.sum(0)
|
|
||||||
maxrelerror = (y - z).abs()
|
|
||||||
if rtol > 0:
|
|
||||||
maxrelerror /= z.abs()
|
|
||||||
maxrelerror = maxrelerror.max()
|
|
||||||
self.assertLessEqual(maxrelerror, rtol)
|
|
||||||
|
|
||||||
def test_average_gradients(self):
|
|
||||||
original_all_sum = mx.distributed.all_sum
|
|
||||||
n_calls = 0
|
|
||||||
xtype = None
|
|
||||||
|
|
||||||
def new_all_sum(x, **kwargs):
|
|
||||||
nonlocal n_calls
|
|
||||||
nonlocal xtype
|
|
||||||
|
|
||||||
n_calls += 1
|
|
||||||
if xtype is not None:
|
|
||||||
self.assertEqual(xtype, x.dtype)
|
|
||||||
|
|
||||||
return original_all_sum(x, **kwargs)
|
|
||||||
|
|
||||||
mx.distributed.all_sum = new_all_sum
|
|
||||||
try:
|
|
||||||
grads = [mx.ones(10) for i in range(10)]
|
|
||||||
new_grads = average_gradients(grads, stream=mx.gpu)
|
|
||||||
mx.eval(new_grads)
|
|
||||||
self.assertEqual(len(new_grads), 10)
|
|
||||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
|
||||||
self.assertEqual(n_calls, 1)
|
|
||||||
|
|
||||||
n_calls = 0
|
|
||||||
new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
|
|
||||||
mx.eval(new_grads)
|
|
||||||
self.assertEqual(len(new_grads), 10)
|
|
||||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
|
||||||
self.assertEqual(n_calls, 2)
|
|
||||||
|
|
||||||
n_calls = 0
|
|
||||||
new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
|
|
||||||
mx.eval(new_grads)
|
|
||||||
self.assertEqual(len(new_grads), 10)
|
|
||||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
|
||||||
self.assertEqual(n_calls, 10)
|
|
||||||
|
|
||||||
n_calls = 0
|
|
||||||
xtype = mx.float16
|
|
||||||
new_grads = average_gradients(
|
|
||||||
grads,
|
|
||||||
all_reduce_size=2 * 50,
|
|
||||||
communication_type=mx.float16,
|
|
||||||
stream=mx.gpu,
|
|
||||||
)
|
|
||||||
mx.eval(new_grads)
|
|
||||||
self.assertEqual(len(new_grads), 10)
|
|
||||||
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
|
|
||||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
|
||||||
self.assertEqual(n_calls, 2)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
mx.distributed.all_sum = original_all_sum
|
|
||||||
|
|
||||||
def test_donation(self):
|
|
||||||
x = mx.random.normal((1024,))
|
|
||||||
mx.eval(x)
|
|
||||||
mx.synchronize()
|
|
||||||
|
|
||||||
mx.reset_peak_memory()
|
|
||||||
scale = mx.array(2.0)
|
|
||||||
y = mx.distributed.all_sum(x)
|
|
||||||
mx.eval(y)
|
|
||||||
mx.synchronize()
|
|
||||||
all_sum_only = mx.get_peak_memory()
|
|
||||||
y = mx.distributed.all_sum(x) * scale
|
|
||||||
mx.eval(y)
|
|
||||||
mx.synchronize()
|
|
||||||
all_sum_with_binary = mx.get_peak_memory()
|
|
||||||
|
|
||||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
|
||||||
|
|
||||||
def test_shard_linear(self):
|
|
||||||
# Seed the prng to have the same inputs and weights generated everywhere
|
|
||||||
mx.random.seed(0xF0F0F0F0)
|
|
||||||
|
|
||||||
# Prepare inputs
|
|
||||||
world = mx.distributed.init()
|
|
||||||
part = (
|
|
||||||
slice(None),
|
|
||||||
slice(
|
|
||||||
world.rank() * 1024 // world.size(),
|
|
||||||
(world.rank() + 1) * 1024 // world.size(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
x = mx.random.normal((4, 1024))
|
|
||||||
|
|
||||||
# Create and shard some linear layers
|
|
||||||
lin = nn.Linear(1024, 1024, bias=True)
|
|
||||||
slin1 = shard_linear(lin, "all-to-sharded")
|
|
||||||
slin2 = shard_linear(lin, "sharded-to-all")
|
|
||||||
y = lin(x)
|
|
||||||
y1 = slin1(x)
|
|
||||||
y2 = slin2(x[part])
|
|
||||||
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
|
|
||||||
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
|
|
||||||
|
|
||||||
# Check the backward works as expected
|
|
||||||
def dummy_loss(model, x, y):
|
|
||||||
return (model(x) * y).sum()
|
|
||||||
|
|
||||||
mod = nn.Sequential(
|
|
||||||
nn.Linear(128, 128),
|
|
||||||
nn.Linear(128, 128),
|
|
||||||
nn.Linear(128, 128),
|
|
||||||
nn.Linear(128, 128),
|
|
||||||
)
|
|
||||||
smod = nn.Sequential(
|
|
||||||
shard_linear(mod.layers[0], "all-to-sharded"),
|
|
||||||
shard_linear(mod.layers[1], "sharded-to-all"),
|
|
||||||
shard_linear(mod.layers[2], "all-to-sharded"),
|
|
||||||
shard_linear(mod.layers[3], "sharded-to-all"),
|
|
||||||
)
|
|
||||||
|
|
||||||
grad1 = nn.value_and_grad(mod, dummy_loss)
|
|
||||||
grad2 = nn.value_and_grad(smod, dummy_loss)
|
|
||||||
|
|
||||||
x = mx.random.normal((4, 128))
|
|
||||||
y = mx.random.normal((4, 128))
|
|
||||||
|
|
||||||
l1, g1 = grad1(mod, x, y)
|
|
||||||
l2, g2 = grad2(smod, x, y)
|
|
||||||
mx.eval(l1, g1, l2, g2)
|
|
||||||
|
|
||||||
part = slice(
|
|
||||||
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(mx.allclose(l1, l2))
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][0]["weight"][part],
|
|
||||||
g2["layers"][0]["weight"],
|
|
||||||
atol=1e-6,
|
|
||||||
rtol=1e-4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][2]["weight"][part],
|
|
||||||
g2["layers"][2]["weight"],
|
|
||||||
atol=1e-6,
|
|
||||||
rtol=1e-4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][1]["weight"][:, part],
|
|
||||||
g2["layers"][1]["weight"],
|
|
||||||
atol=1e-6,
|
|
||||||
rtol=1e-4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][3]["weight"][:, part],
|
|
||||||
g2["layers"][3]["weight"],
|
|
||||||
atol=1e-6,
|
|
||||||
rtol=1e-4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][0]["bias"][part],
|
|
||||||
g2["layers"][0]["bias"],
|
|
||||||
atol=1e-6,
|
|
||||||
rtol=1e-4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][2]["bias"][part],
|
|
||||||
g2["layers"][2]["bias"],
|
|
||||||
atol=1e-6,
|
|
||||||
rtol=1e-4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
mx.allclose(
|
|
||||||
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_shard_predicate(self):
|
|
||||||
mx.random.seed(0xF0F0F0F0)
|
|
||||||
|
|
||||||
class MyConv(nn.Module):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.aggregate = kwargs.pop("aggregate", False)
|
|
||||||
self.conv = nn.Conv2d(*args, **kwargs)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
x = self.conv(x)
|
|
||||||
if self.aggregate:
|
|
||||||
x = mx.distributed.all_sum(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def sharding(path, weight):
|
|
||||||
parts = path.split(".")
|
|
||||||
even = int(parts[1]) % 2 == 0
|
|
||||||
if even:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return -1 if parts[-1] != "bias" else None
|
|
||||||
|
|
||||||
mod = nn.Sequential(
|
|
||||||
MyConv(3, 128, kernel_size=3),
|
|
||||||
MyConv(128, 128, kernel_size=3),
|
|
||||||
MyConv(128, 128, kernel_size=3),
|
|
||||||
MyConv(128, 3, kernel_size=3),
|
|
||||||
)
|
|
||||||
smod = nn.Sequential(
|
|
||||||
MyConv(3, 128, kernel_size=3),
|
|
||||||
MyConv(128, 128, kernel_size=3, aggregate=True),
|
|
||||||
MyConv(128, 128, kernel_size=3),
|
|
||||||
MyConv(128, 3, kernel_size=3, aggregate=True),
|
|
||||||
)
|
|
||||||
smod.update(mod.parameters())
|
|
||||||
shard_inplace(smod, sharding)
|
|
||||||
|
|
||||||
x = mx.random.normal((4, 16, 16, 3))
|
|
||||||
y1 = mod(x)
|
|
||||||
y2 = smod(x)
|
|
||||||
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
mlx_tests.MLXTestRunner()
|
|
||||||
@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
|
||||||
|
|
||||||
model = DictModule()
|
model = DictModule()
|
||||||
params = tree_flatten(model.parameters(), destination={})
|
params = dict(tree_flatten(model.parameters()))
|
||||||
self.assertEqual(len(params), 2)
|
self.assertEqual(len(params), 2)
|
||||||
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
|
||||||
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))
|
||||||
|
|||||||
Reference in New Issue
Block a user