mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-02 00:36:49 +08:00
[CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN * Initial implementation * Remove backend apis * Fix recording cudnn conv * More unused backend apis * Fix C++ conv tests * include cudnn as python dep * Install libcudnn9-dev-cuda-12 in CI * cudnn only accepts contiguous inputs * Switch to backend apis * Plan needs to be kept alive * Turn off tf32 * Add cache * Test the native cuda graph api * Set cudnn stream before execution * Make LRUCache more like a normal container * Do error check for cublas handle * Zero-initilizing array * Use tf32 for conv * Skip TestConv.test_torch_conv_2D test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
70dc336785
commit
6f5874a2f2
@ -216,6 +216,7 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python3 -m venv env
|
||||
source env/bin/activate
|
||||
@ -385,7 +386,7 @@ jobs:
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
sudo apt-get update
|
||||
sudo apt install cuda-toolkit-12-9
|
||||
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install zip
|
||||
pip install auditwheel
|
||||
|
@ -15,6 +15,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
@ -131,6 +132,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||
# Use NVRTC and driver APIs.
|
||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
|
||||
# Use the frontend APIs of cuDNN.
|
||||
FetchContent_Declare(
|
||||
cudnn
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||
GIT_TAG v1.12.1
|
||||
GIT_SHALLOW TRUE
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
|
||||
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
|
||||
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
|
||||
FetchContent_MakeAvailable(cudnn)
|
||||
target_link_libraries(mlx PRIVATE cudnn_frontend)
|
||||
# Link with the actual cuDNN libraries.
|
||||
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
|
||||
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
|
340
mlx/backend/cuda/conv.cpp
Normal file
340
mlx/backend/cuda/conv.cpp
Normal file
@ -0,0 +1,340 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
// cudnn_frontend.h redefines this macro.
|
||||
#undef CHECK_CUDA_ERROR
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Not all engines support it so can not use this API now.
|
||||
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||
|
||||
struct ConvCacheKey {
|
||||
int device_id;
|
||||
cudnnBackendDescriptorType_t backend_type;
|
||||
cudnnDataType_t cudnn_type;
|
||||
std::array<int, MAX_NDIM> input_shape;
|
||||
std::array<int, MAX_NDIM> filter_shape;
|
||||
std::array<int, MAX_NDIM> padding_lo;
|
||||
std::array<int, MAX_NDIM> padding_hi;
|
||||
std::array<int, MAX_NDIM> stride;
|
||||
std::array<int, MAX_NDIM> dilation;
|
||||
int groups;
|
||||
uint8_t input_alignment;
|
||||
uint8_t filter_alignment;
|
||||
uint8_t output_alignment;
|
||||
};
|
||||
|
||||
auto& conv_cache() {
|
||||
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
|
||||
/* capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
|
||||
return std::vector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
auto nhwc_to_nchw(const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
shape.erase(shape.end() - 1);
|
||||
auto strides = convert_vector<int64_t>(x.strides());
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(shape, strides);
|
||||
}
|
||||
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return CUDNN_DATA_INT8;
|
||||
case int32:
|
||||
return CUDNN_DATA_INT32;
|
||||
case uint8:
|
||||
return CUDNN_DATA_UINT8;
|
||||
case float16:
|
||||
return CUDNN_DATA_HALF;
|
||||
case bfloat16:
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
case float32:
|
||||
return CUDNN_DATA_FLOAT;
|
||||
case float64:
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
return alignment;
|
||||
}
|
||||
}
|
||||
return alignment;
|
||||
}
|
||||
|
||||
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
|
||||
auto [shape, strides] = nhwc_to_nchw(x);
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(shape.size(), shape.data())
|
||||
.setStrides(strides.size(), strides.data())
|
||||
.setId(id)
|
||||
.setAlignment(get_alignment(x))
|
||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||
.build();
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = false) {
|
||||
cudnn_frontend::GeneratorSource source;
|
||||
if (use_fallback) {
|
||||
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setOperation(backend_type)
|
||||
.build();
|
||||
return fallback.getFallbackList();
|
||||
};
|
||||
} else {
|
||||
source = [](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||
.build();
|
||||
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||
};
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
||||
auto configs = generator.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||
dtype == float32 && !env::enable_tf32()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
bool execute_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||
|
||||
int64_t uids[3] = {'x', 'w', 'y'};
|
||||
void* data_ptrs[3] = {
|
||||
const_cast<void*>(in.data<void>()),
|
||||
const_cast<void*>(wt.data<void>()),
|
||||
out.data<void>(),
|
||||
};
|
||||
|
||||
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace.data<void>())
|
||||
.setDataPointers(3, data_ptrs)
|
||||
.setUids(3, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||
cudaGraph_t graph;
|
||||
cudaGraphCreate(&graph, 0);
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||
if (cudnnBackendPopulateCudaGraph(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
#else
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool try_engines(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::EngineConfigList& configs,
|
||||
const ConvCacheKey& cache_key,
|
||||
const std::string& op_graph_tag,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out) {
|
||||
for (auto& config : configs) {
|
||||
try {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
if (execute_plan(encoder, plan, in, wt, out)) {
|
||||
conv_cache().emplace(cache_key, std::move(plan));
|
||||
return true;
|
||||
}
|
||||
} catch (cudnn_frontend::cudnnException&) {
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(inputs.size() == 2);
|
||||
array in = inputs[0];
|
||||
array wt = inputs[1];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// cuDNN requires contiguous input.
|
||||
// TODO: Handle NCHW format specially.
|
||||
if (!in.flags().row_contiguous) {
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
encoder.add_temporary(wt);
|
||||
}
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(wt);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
|
||||
auto cudnn_type = dtype_to_cudnn_type(in.dtype());
|
||||
|
||||
// Search cache.
|
||||
ConvCacheKey cache_key{
|
||||
encoder.device().cuda_device(),
|
||||
backend_type,
|
||||
cudnn_type,
|
||||
fixed_vector(in.shape()),
|
||||
fixed_vector(wt.shape()),
|
||||
fixed_vector(padding_lo_),
|
||||
fixed_vector(padding_hi_),
|
||||
fixed_vector(kernel_strides_),
|
||||
fixed_vector(kernel_dilation_),
|
||||
groups_,
|
||||
get_alignment(in),
|
||||
get_alignment(wt),
|
||||
get_alignment(out)};
|
||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||
if (!execute_plan(encoder, it->second, in, wt, out)) {
|
||||
throw std::runtime_error("Cached convolution plan failed to execute.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Build operation graph.
|
||||
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
: cudnn_type;
|
||||
|
||||
auto stride = convert_vector<int64_t>(kernel_strides_);
|
||||
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
||||
auto dilation = convert_vector<int64_t>(kernel_dilation_);
|
||||
|
||||
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||
.setDataType(compute_data_type)
|
||||
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||
.setNDims(stride.size())
|
||||
.setStrides(stride.size(), stride.data())
|
||||
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||
.setDilation(dilation.size(), dilation.data())
|
||||
.build();
|
||||
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_tensor('x', in))
|
||||
.setwDesc(build_tensor('w', wt))
|
||||
.setyDesc(build_tensor('y', out))
|
||||
.setcDesc(conv_desc)
|
||||
.build();
|
||||
|
||||
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||
auto op_graph = cudnn_frontend::OperationGraphBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setOperationGraph(ops.size(), ops.data())
|
||||
.build();
|
||||
|
||||
// Try to run plans based on heuristics.
|
||||
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
auto op_graph_tag = op_graph.getTag();
|
||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||
return;
|
||||
}
|
||||
// Then try fallback plans.
|
||||
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("Unable to find an engine for convolution.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -9,12 +9,23 @@
|
||||
#include <future>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
// This should be less than 255
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||
|
||||
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
if (err != CUDNN_STATUS_SUCCESS) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
int cuda_graph_cache_size() {
|
||||
static int cache_size = []() {
|
||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||
@ -22,7 +33,7 @@ int cuda_graph_cache_size() {
|
||||
return cache_size;
|
||||
}
|
||||
|
||||
namespace cu {
|
||||
} // namespace
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
@ -40,11 +51,14 @@ Device::Device(int device) : device_(device) {
|
||||
}
|
||||
// The cublasLt handle is used by matmul.
|
||||
make_current();
|
||||
cublasLtCreate(<_);
|
||||
CHECK_CUBLAS_ERROR(cublasLtCreate(<_));
|
||||
// The cudnn handle is used by Convolution.
|
||||
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
cublasLtDestroy(lt_);
|
||||
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
@ -66,29 +80,36 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
enc.device().make_current();
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
|
||||
if (discard) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract and add as single kernel node when possible.
|
||||
size_t num_nodes;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
|
||||
if (num_nodes == 1) {
|
||||
cudaGraphNode_t captured_node;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
|
||||
CUDA_KERNEL_NODE_PARAMS params;
|
||||
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms));
|
||||
enc.insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
} else {
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
|
||||
enc.insert_graph_dependencies(GraphNode{node, 'G'});
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
|
||||
if (type == cudaGraphNodeTypeKernel) {
|
||||
CUDA_KERNEL_NODE_PARAMS params;
|
||||
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
||||
enc.add_kernel_node(params);
|
||||
return;
|
||||
}
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
|
||||
// Otherwise add the captured graph as subgraph.
|
||||
enc.add_graph_node(graph);
|
||||
}
|
||||
|
||||
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
||||
@ -221,10 +242,7 @@ void CommandEncoder::add_kernel_node(
|
||||
kernel_params.gridDim = grid_dim;
|
||||
kernel_params.blockDim = block_dim;
|
||||
kernel_params.kernelParams = params;
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
add_kernel_node(kernel_params);
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(
|
||||
@ -241,12 +259,27 @@ void CommandEncoder::add_kernel_node(
|
||||
kernel_params.blockDimY = block_dim.y;
|
||||
kernel_params.blockDimZ = block_dim.z;
|
||||
kernel_params.kernelParams = params;
|
||||
CUgraphNode node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||
add_kernel_node(kernel_params);
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||
CUgraphNode node;
|
||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
}
|
||||
|
||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
@ -331,6 +364,4 @@ CommandEncoder& get_command_encoder(Stream s) {
|
||||
return device(s.device).get_command_encoder(s);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core::cu
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <cudnn.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
@ -21,6 +22,7 @@ class CommandEncoder {
|
||||
~CaptureContext();
|
||||
cudaGraph_t graph;
|
||||
CommandEncoder& enc;
|
||||
bool discard{false};
|
||||
};
|
||||
struct ConcurrentContext {
|
||||
ConcurrentContext(CommandEncoder& enc);
|
||||
@ -65,6 +67,11 @@ class CommandEncoder {
|
||||
void
|
||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
||||
|
||||
// Low-level graph helpers.
|
||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||
void add_graph_node(cudaGraph_t child);
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
@ -73,6 +80,10 @@ class CommandEncoder {
|
||||
void maybe_commit();
|
||||
void commit();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
CudaStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
@ -137,12 +148,16 @@ class Device {
|
||||
cublasLtHandle_t lt_handle() const {
|
||||
return lt_;
|
||||
}
|
||||
cudnnHandle_t cudnn_handle() const {
|
||||
return cudnn_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
cublasLtHandle_t lt_;
|
||||
cudnnHandle_t cudnn_;
|
||||
std::unordered_map<int, CommandEncoder> encoders_;
|
||||
};
|
||||
|
||||
|
146
mlx/backend/cuda/lru_cache.h
Normal file
146
mlx/backend/cuda/lru_cache.h
Normal file
@ -0,0 +1,146 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <
|
||||
typename K,
|
||||
typename V,
|
||||
template <typename...> typename M = std::unordered_map>
|
||||
class LRUCache {
|
||||
public:
|
||||
using value_type = std::pair<K, V>;
|
||||
using list_type = std::list<value_type>;
|
||||
using iterator = typename list_type::iterator;
|
||||
using const_iterator = typename list_type::const_iterator;
|
||||
using map_type = M<K, iterator>;
|
||||
|
||||
explicit LRUCache(size_t capacity) : capacity_(capacity) {}
|
||||
|
||||
size_t size() const {
|
||||
return map_.size();
|
||||
}
|
||||
size_t capacity() const {
|
||||
return capacity_;
|
||||
}
|
||||
bool empty() const {
|
||||
return vlist_.empty();
|
||||
}
|
||||
|
||||
void resize(size_t new_capacity) {
|
||||
capacity_ = new_capacity;
|
||||
trim();
|
||||
}
|
||||
|
||||
iterator begin() {
|
||||
return vlist_.begin();
|
||||
}
|
||||
const_iterator begin() const {
|
||||
return vlist_.begin();
|
||||
}
|
||||
iterator end() {
|
||||
return vlist_.end();
|
||||
}
|
||||
const_iterator end() const {
|
||||
return vlist_.end();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
map_.clear();
|
||||
vlist_.clear();
|
||||
}
|
||||
|
||||
iterator find(const K& key) {
|
||||
auto it = map_.find(key);
|
||||
if (it == map_.end())
|
||||
return end();
|
||||
vlist_.splice(vlist_.begin(), vlist_, it->second);
|
||||
return it->second;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
std::pair<iterator, bool> emplace(const K& key, U&& value) {
|
||||
auto it = map_.find(key);
|
||||
if (it != map_.end()) {
|
||||
vlist_.splice(vlist_.begin(), vlist_, it->second);
|
||||
return {it->second, false};
|
||||
}
|
||||
|
||||
vlist_.emplace_front(key, std::forward<U>(value));
|
||||
map_[key] = vlist_.begin();
|
||||
|
||||
trim();
|
||||
|
||||
return {vlist_.begin(), true};
|
||||
}
|
||||
|
||||
iterator erase(iterator pos) {
|
||||
map_.erase(pos->first);
|
||||
return vlist_.erase(pos);
|
||||
}
|
||||
|
||||
private:
|
||||
void trim() {
|
||||
while (map_.size() > capacity_) {
|
||||
auto last = std::prev(vlist_.end());
|
||||
map_.erase(last->first);
|
||||
vlist_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
list_type vlist_;
|
||||
map_type map_;
|
||||
size_t capacity_;
|
||||
};
|
||||
|
||||
// Turn a POD struct into a container key by doing bytes compare.
|
||||
template <typename T>
|
||||
struct BytesKey {
|
||||
T pod;
|
||||
static_assert(std::is_standard_layout_v<T>, "T is not POD");
|
||||
|
||||
BytesKey(T pod) : pod(std::move(pod)) {}
|
||||
|
||||
BytesKey(const BytesKey& other) {
|
||||
memcpy(&pod, &other.pod, sizeof(T));
|
||||
}
|
||||
|
||||
BytesKey(BytesKey&& other) {
|
||||
memcpy(&pod, &other.pod, sizeof(T));
|
||||
}
|
||||
|
||||
bool operator==(const BytesKey& other) const {
|
||||
auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
|
||||
auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
|
||||
return memcmp(ptr1, ptr2, sizeof(T)) == 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Compute hash according to the bytes value of T.
|
||||
template <typename T>
|
||||
struct BytesHash {
|
||||
static_assert(std::is_standard_layout_v<T>, "T is not POD");
|
||||
|
||||
size_t operator()(const T& pod) const {
|
||||
auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
|
||||
uint32_t value = 0x811C9DC5;
|
||||
for (int i = 0; i < sizeof(T); ++i) {
|
||||
value ^= ptr[i];
|
||||
value *= 0x01000193;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename K, typename V>
|
||||
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
|
||||
|
||||
template <typename K, typename V>
|
||||
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
|
||||
|
||||
} // namespace mlx::core
|
@ -8,7 +8,6 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
@ -18,16 +17,6 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||
|
||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||
// TODO: Use cublasGetStatusString when it is widely available.
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
||||
}
|
||||
}
|
||||
|
||||
struct CublasPreference {
|
||||
CublasPreference(Device& device) {
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
|
@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
||||
}
|
||||
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
|
@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
|
||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||
}
|
||||
|
||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||
// TODO: Use cublasGetStatusString when it is widely available.
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
||||
}
|
||||
}
|
||||
|
||||
void check_cuda_error(const char* name, cudaError_t err) {
|
||||
if (err != cudaSuccess) {
|
||||
throw std::runtime_error(
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
@ -33,10 +34,12 @@ class CudaStream {
|
||||
};
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
void check_cuda_error(const char* name, CUresult err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
|
||||
// Convert Dtype to CUDA C++ types.
|
||||
|
@ -16,7 +16,7 @@ rm "${repaired_wheel}"
|
||||
mlx_so="mlx/lib/libmlx.so"
|
||||
rpath=$(patchelf --print-rpath "${mlx_so}")
|
||||
base="\$ORIGIN/../../nvidia"
|
||||
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
|
||||
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
|
||||
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
||||
python ../python/scripts/repair_record.py ${mlx_so}
|
||||
|
||||
|
@ -15,19 +15,12 @@ cuda_skip = {
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
# Convolutions NYI
|
||||
"TestConv.test_1d_conv_with_2d",
|
||||
"TestConv.test_asymmetric_padding",
|
||||
"TestConv.test_basic_grad_shapes",
|
||||
"TestConv.test_conv2d_unaligned_channels",
|
||||
"TestConv.test_conv_1d_groups_flipped",
|
||||
"TestConv.test_conv_general_flip_grad",
|
||||
"TestConv.test_conv_groups_grad",
|
||||
"TestConv.test_numpy_conv",
|
||||
"TestConv.test_repeated_conv",
|
||||
"TestConv.test_torch_conv_1D",
|
||||
"TestConv.test_torch_conv_1D_grad",
|
||||
"TestConv.test_torch_conv_2D",
|
||||
"TestConv.test_torch_conv_2D_grad",
|
||||
"TestConv.test_torch_conv_3D",
|
||||
"TestConv.test_torch_conv_3D_grad",
|
||||
"TestConv.test_torch_conv_depthwise",
|
||||
"TestConv.test_torch_conv_general",
|
||||
@ -40,10 +33,6 @@ cuda_skip = {
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3D_grad",
|
||||
"TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
|
||||
"TestExportImport.test_export_conv",
|
||||
"TestLayers.test_conv1d",
|
||||
"TestLayers.test_conv2d",
|
||||
"TestVmap.test_vmap_conv",
|
||||
# FFTs NYI
|
||||
"TestFFT.test_fft",
|
||||
"TestFFT.test_fft_big_powers_of_two",
|
||||
|
Loading…
Reference in New Issue
Block a user