mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN * Initial implementation * Remove backend apis * Fix recording cudnn conv * More unused backend apis * Fix C++ conv tests * include cudnn as python dep * Install libcudnn9-dev-cuda-12 in CI * cudnn only accepts contiguous inputs * Switch to backend apis * Plan needs to be kept alive * Turn off tf32 * Add cache * Test the native cuda graph api * Set cudnn stream before execution * Make LRUCache more like a normal container * Do error check for cublas handle * Zero-initializing array * Use tf32 for conv * Skip TestConv.test_torch_conv_2D test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <cudnn.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
@@ -21,6 +22,7 @@ class CommandEncoder {
|
||||
~CaptureContext();
|
||||
cudaGraph_t graph;
|
||||
CommandEncoder& enc;
|
||||
bool discard{false};
|
||||
};
|
||||
struct ConcurrentContext {
|
||||
ConcurrentContext(CommandEncoder& enc);
|
||||
@@ -65,6 +67,11 @@ class CommandEncoder {
|
||||
void
|
||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
||||
|
||||
// Low-level graph helpers.
|
||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||
void add_graph_node(cudaGraph_t child);
|
||||
|
||||
// Appends the shared data pointer of `arr` to temporaries_.
// NOTE(review): presumably this keeps the underlying buffer alive until the
// recorded work has completed — confirm where temporaries_ is released.
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
@@ -73,6 +80,10 @@ class CommandEncoder {
|
||||
void maybe_commit();
|
||||
void commit();
|
||||
|
||||
// Returns a mutable reference to the device_ member of this encoder.
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
// Returns a mutable reference to the stream_ member of this encoder.
CudaStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
@@ -137,12 +148,16 @@ class Device {
|
||||
// Returns the cuBLASLt handle stored in lt_ (returned by value).
cublasLtHandle_t lt_handle() const {
|
||||
return lt_;
|
||||
}
|
||||
// Returns the cuDNN handle stored in cudnn_ (returned by value).
// NOTE(review): this accessor does not bind a stream to the handle; the
// commit notes mention setting the cuDNN stream before execution, so callers
// presumably do that themselves — confirm at the call sites.
cudnnHandle_t cudnn_handle() const {
|
||||
return cudnn_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
cublasLtHandle_t lt_;
|
||||
cudnnHandle_t cudnn_;
|
||||
std::unordered_map<int, CommandEncoder> encoders_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user