diff --git a/.circleci/config.yml b/.circleci/config.yml
index 60907079cf..f33e80f0cd 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -216,6 +216,7 @@ jobs:
           name: Install Python package
           command: |
             sudo apt-get update
+            sudo apt-get install libcudnn9-dev-cuda-12
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
             python3 -m venv env
             source env/bin/activate
@@ -385,7 +386,7 @@ jobs:
             wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
             sudo dpkg -i cuda-keyring_1.1-1_all.deb
             sudo apt-get update
-            sudo apt install cuda-toolkit-12-9
+            sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
             sudo apt-get install zip
             pip install auditwheel
diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 0be8e9f863..49658dcd82 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -15,6 +15,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
@@ -131,6 +132,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
 # Use NVRTC and driver APIs.
 target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
 
+# Use the cuDNN frontend APIs.
+FetchContent_Declare(
+  cudnn
+  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
+  GIT_TAG v1.12.1
+  GIT_SHALLOW TRUE
+  EXCLUDE_FROM_ALL)
+set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
+set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
+set(CUDNN_FRONTEND_BUILD_TESTS OFF)
+set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
+FetchContent_MakeAvailable(cudnn)
+target_link_libraries(mlx PRIVATE cudnn_frontend)
+# Link with the actual cuDNN libraries.
+include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
+target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
+
 # Suppress nvcc warnings on MLX headers.
 target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe --diag_suppress=997>)
diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
new file mode 100644
index 0000000000..578b520a04
--- /dev/null
+++ b/mlx/backend/cuda/conv.cpp
@@ -0,0 +1,340 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/lru_cache.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+// cudnn_frontend.h redefines this macro.
+#undef CHECK_CUDA_ERROR
+
+#include <cudnn.h>
+#include <cudnn_frontend.h>
+#include <fmt/format.h>
+#include <nvtx3/nvtx3.hpp>
+
+#include <algorithm>
+#include <cassert>
+
+namespace mlx::core {
+
+namespace {
+
+// Not all cuDNN engines support the native CUDA graph API, so it can not be
+// used for now.
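+// When the flag below is enabled, cudnnBackendPopulateCudaGraph() adds the
+// convolution to the encoder's CUDA graph directly instead of going through
+// stream capture.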
+#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
+
+struct ConvCacheKey {
+  int device_id;
+  cudnnBackendDescriptorType_t backend_type;
+  cudnnDataType_t cudnn_type;
+  std::array<int, MAX_NDIM> input_shape;
+  std::array<int, MAX_NDIM> filter_shape;
+  std::array<int, MAX_NDIM> padding_lo;
+  std::array<int, MAX_NDIM> padding_hi;
+  std::array<int, MAX_NDIM> stride;
+  std::array<int, MAX_NDIM> dilation;
+  int groups;
+  uint8_t input_alignment;
+  uint8_t filter_alignment;
+  uint8_t output_alignment;
+};
+
+auto& conv_cache() {
+  static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
+      /* capacity */ 128);
+  return cache;
+}
+
+template <typename T, typename U>
+inline std::vector<T> convert_vector(const std::vector<U>& vec) {
+  return std::vector<T>(vec.begin(), vec.end());
+}
+
+template <typename T>
+inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
+  if (vec.size() > MAX_NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
+  }
+  std::array<T, MAX_NDIM> result = {};
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
+auto nhwc_to_nchw(const array& x) {
+  auto shape = convert_vector<int64_t>(x.shape());
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  auto strides = convert_vector<int64_t>(x.strides());
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(shape, strides);
+}
+
+inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return CUDNN_DATA_INT8;
+    case int32:
+      return CUDNN_DATA_INT32;
+    case uint8:
+      return CUDNN_DATA_UINT8;
+    case float16:
+      return CUDNN_DATA_HALF;
+    case bfloat16:
+      return CUDNN_DATA_BFLOAT16;
+    case float32:
+      return CUDNN_DATA_FLOAT;
+    case float64:
+      return CUDNN_DATA_DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
+  }
+}
+
+inline uint8_t get_alignment(const array& x) {
+  uint8_t alignment = 1;
+  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
+  for (; alignment < 32; alignment *= 2) {
+    if (address % (alignment * 2)) {
+      return alignment;
+    }
+  }
+  return alignment;
+}
+
+inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
+  auto [shape, strides] = nhwc_to_nchw(x);
+  return cudnn_frontend::TensorBuilder()
+      .setDim(shape.size(), shape.data())
+      .setStrides(strides.size(), strides.data())
+      .setId(id)
+      .setAlignment(get_alignment(x))
+      .setDataType(dtype_to_cudnn_type(x.dtype()))
+      .build();
+}
+
+cudnn_frontend::EngineConfigList get_engine_configs(
+    cudnnBackendDescriptorType_t backend_type,
+    Dtype dtype,
+    cudnn_frontend::OperationGraph& op_graph,
+    bool use_fallback = false) {
+  cudnn_frontend::GeneratorSource source;
+  if (use_fallback) {
+    source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
+      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                          .setOperationGraph(op_graph)
+                          .setOperation(backend_type)
+                          .build();
+      return fallback.getFallbackList();
+    };
+  } else {
+    source = [](cudnn_frontend::OperationGraph& op_graph) {
+      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                            .setOperationGraph(op_graph)
+                            .setHeurMode(CUDNN_HEUR_MODE_A)
+                            .build();
+      return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
+    };
+  }
+
+  cudnn_frontend::EngineConfigGenerator generator(1, &source);
+  auto configs = generator.generate_engine_config(op_graph);
+
+  cudnn_frontend::EngineConfigList filtered_configs;
+  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
+    if (cudnn_frontend::hasNumericalNote<
+            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
+      return true;
+    }
+    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(
+            c) &&
+        dtype == float32 && !env::enable_tf32()) {
+      return true;
+    }
+    return false;
+  });
+  return filtered_configs;
+}
+
+bool execute_plan(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::ExecutionPlan& plan,
+    const array& in,
+    const array& wt,
+    array& out) {
+  int workspace_size = plan.getWorkspaceSize();
+  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
+
+  int64_t uids[3] = {'x', 'w', 'y'};
+  void* data_ptrs[3] = {
+      const_cast<void*>(in.data<void>()),
+      const_cast<void*>(wt.data<void>()),
+      out.data<void>(),
+  };
+
+  auto variantPack = cudnn_frontend::VariantPackBuilder()
+                         .setWorkspacePointer(workspace.data<void>())
+                         .setDataPointers(3, data_ptrs)
+                         .setUids(3, uids)
+                         .build();
+
+  auto handle = encoder.device().cudnn_handle();
+  cudnnSetStream(handle, encoder.stream());
+
+#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
+  cudaGraph_t graph;
+  cudaGraphCreate(&graph, 0);
+  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
+      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
+  if (cudnnBackendPopulateCudaGraph(
+          handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
+      CUDNN_STATUS_SUCCESS) {
+    return false;
+  }
+  encoder.add_graph_node(graph);
+#else
+  auto capture = encoder.capture_context();
+  if (cudnnBackendExecute(
+          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
+      CUDNN_STATUS_SUCCESS) {
+    // Discard the captured graph when execution fails.
+    capture.discard = true;
+    return false;
+  }
+#endif
+
+  encoder.add_temporary(workspace);
+  return true;
+}
+
+bool try_engines(
+    cu::CommandEncoder& encoder,
+    cudnn_frontend::EngineConfigList& configs,
+    const ConvCacheKey& cache_key,
+    const std::string& op_graph_tag,
+    const array& in,
+    const array& wt,
+    array& out) {
+  for (auto& config : configs) {
+    try {
+      auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                      .setHandle(encoder.device().cudnn_handle())
+                      .setEngineConfig(config, op_graph_tag)
+                      .build();
+      if (execute_plan(encoder, plan, in, wt, out)) {
+        conv_cache().emplace(cache_key, std::move(plan));
+        return true;
+      }
+    } catch (cudnn_frontend::cudnnException&) {
+    }
+  }
+  return false;
+}
+
+} // namespace
+
+void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Convolution::eval_gpu");
+  if (out.size() == 0) {
+    return;
+  }
+
+  assert(inputs.size() == 2);
+  array in = inputs[0];
+  array wt = inputs[1];
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  // cuDNN requires contiguous input.
+  // TODO: Handle NCHW format specially.
+  if (!in.flags().row_contiguous) {
+    in = contiguous_copy_gpu(in, s);
+    encoder.add_temporary(in);
+  }
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    encoder.add_temporary(wt);
+  }
+
+  encoder.set_input_array(in);
+  encoder.set_input_array(wt);
+  encoder.set_output_array(out);
+
+  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
+  auto cudnn_type = dtype_to_cudnn_type(in.dtype());
+
+  // Search the cache.
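+  // The key is a plain struct that is hashed and compared byte-wise (see
+  // lru_cache.h), so shapes, strides, padding, dilation, groups, dtype and
+  // pointer alignment all participate in plan lookup.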
+  ConvCacheKey cache_key{
+      encoder.device().cuda_device(),
+      backend_type,
+      cudnn_type,
+      fixed_vector(in.shape()),
+      fixed_vector(wt.shape()),
+      fixed_vector(padding_lo_),
+      fixed_vector(padding_hi_),
+      fixed_vector(kernel_strides_),
+      fixed_vector(kernel_dilation_),
+      groups_,
+      get_alignment(in),
+      get_alignment(wt),
+      get_alignment(out)};
+  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
+    if (!execute_plan(encoder, it->second, in, wt, out)) {
+      throw std::runtime_error("Cached convolution plan failed to execute.");
+    }
+    return;
+  }
+
+  // Build the operation graph.
+  auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
+      ? CUDNN_DATA_FLOAT
+      : cudnn_type;
+
+  auto stride = convert_vector<int64_t>(kernel_strides_);
+  auto padding_lo = convert_vector<int64_t>(padding_lo_);
+  auto padding_hi = convert_vector<int64_t>(padding_hi_);
+  auto dilation = convert_vector<int64_t>(kernel_dilation_);
+
+  auto conv_desc = cudnn_frontend::ConvDescBuilder()
+                       .setDataType(compute_data_type)
+                       .setMathMode(CUDNN_CROSS_CORRELATION)
+                       .setNDims(stride.size())
+                       .setStrides(stride.size(), stride.data())
+                       .setPrePadding(padding_lo.size(), padding_lo.data())
+                       .setPostPadding(padding_hi.size(), padding_hi.data())
+                       .setDilation(dilation.size(), dilation.data())
+                       .build();
+
+  auto op = cudnn_frontend::OperationBuilder(backend_type)
+                .setxDesc(build_tensor('x', in))
+                .setwDesc(build_tensor('w', wt))
+                .setyDesc(build_tensor('y', out))
+                .setcDesc(conv_desc)
+                .build();
+
+  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
+  auto op_graph = cudnn_frontend::OperationGraphBuilder()
+                      .setHandle(encoder.device().cudnn_handle())
+                      .setOperationGraph(ops.size(), ops.data())
+                      .build();
+
+  // Try to run plans based on heuristics.
+  auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
+  auto op_graph_tag = op_graph.getTag();
+  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+    return;
+  }
+  // Then try fallback plans.
+  configs = get_engine_configs(
+      backend_type, in.dtype(), op_graph, /* use_fallback */ true);
+  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
+    return;
+  }
+  throw std::runtime_error("Unable to find an engine for convolution.");
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 366cdf8269..41a583996e 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -9,12 +9,23 @@
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 
-namespace mlx::core {
+namespace mlx::core::cu {
+
+namespace {
 
 // Can be tuned with MLX_MAX_OPS_PER_BUFFER
 // This should be less than 255
 constexpr int default_max_nodes_per_graph = 20;
 
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
 int cuda_graph_cache_size() {
   static int cache_size = []() {
     return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
@@ -22,7 +33,7 @@ int cuda_graph_cache_size() {
   return cache_size;
 }
 
-namespace cu {
+} // namespace
 
 Device::Device(int device) : device_(device) {
   CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
@@ -40,11 +51,14 @@ Device::Device(int device) : device_(device) {
   }
   // The cublasLt handle is used by matmul.
   make_current();
-  cublasLtCreate(&lt_);
+  CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
+  // The cudnn handle is used by Convolution.
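+  // One handle is created per device and shared by all command encoders on
+  // that device.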
+  CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
 }
 
 Device::~Device() {
-  cublasLtDestroy(lt_);
+  CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
+  CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
 }
 
 void Device::make_current() {
@@ -66,29 +80,36 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 }
 
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
+  enc.device().make_current();
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
   CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
+  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
+      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  if (discard) {
+    return;
+  }
+
+  // Extract and add as a single kernel node when possible.
   size_t num_nodes;
   CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
   if (num_nodes == 1) {
     cudaGraphNode_t captured_node;
     CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
-    CUDA_KERNEL_NODE_PARAMS params;
-    CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
-    enc.insert_graph_dependencies(GraphNode{node, 'K'});
-  } else {
-    cudaGraphNode_t node;
-    CHECK_CUDA_ERROR(
-        cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
-    enc.insert_graph_dependencies(GraphNode{node, 'G'});
+    cudaGraphNodeType type;
+    CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
+    if (type == cudaGraphNodeTypeKernel) {
+      CUDA_KERNEL_NODE_PARAMS params;
+      CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
+      enc.add_kernel_node(params);
+      return;
+    }
   }
-  CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
+
+  // Otherwise add the captured graph as a subgraph.
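+  // The child graph node keeps everything captured on the stream behind a
+  // single dependency edge in the encoder's graph.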
+  enc.add_graph_node(graph);
 }
 
 CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
@@ -221,10 +242,7 @@ void CommandEncoder::add_kernel_node(
   kernel_params.gridDim = grid_dim;
   kernel_params.blockDim = block_dim;
   kernel_params.kernelParams = params;
-  cudaGraphNode_t node;
-  CHECK_CUDA_ERROR(
-      cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  add_kernel_node(kernel_params);
 }
 
 void CommandEncoder::add_kernel_node(
@@ -241,12 +259,27 @@ void CommandEncoder::add_kernel_node(
   kernel_params.blockDimY = block_dim.y;
   kernel_params.blockDimZ = block_dim.z;
   kernel_params.kernelParams = params;
-  CUgraphNode node;
-  CHECK_CUDA_ERROR(
-      cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
+  add_kernel_node(kernel_params);
+}
+
+void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
   insert_graph_dependencies(GraphNode{node, 'K'});
 }
 
+void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
+  CUgraphNode node;
+  CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
+  insert_graph_dependencies(GraphNode{node, 'K'});
+}
+
+void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  cudaGraphNode_t node;
+  CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
+  insert_graph_dependencies(GraphNode{node, 'G'});
+}
+
 void CommandEncoder::commit() {
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
@@ -331,6 +364,4 @@ CommandEncoder& get_command_encoder(Stream s) {
   return device(s.device).get_command_encoder(s);
 }
 
-} // namespace cu
-
-} // namespace mlx::core
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 8ac840cbb4..25c26fb0d2 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -8,6 +8,7 @@
 
 #include <cublasLt.h>
 #include <cuda.h>
+#include <cudnn.h>
 #include <cuda_runtime.h>
 #include <thrust/execution_policy.h>
 
@@ -21,6 +22,7 @@ class CommandEncoder {
     ~CaptureContext();
     cudaGraph_t graph;
     CommandEncoder& enc;
+    bool discard{false};
   };
   struct ConcurrentContext {
     ConcurrentContext(CommandEncoder& enc);
@@ -65,6 +67,11 @@ class CommandEncoder {
   void add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
 
+  // Low-level graph helpers.
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+  void add_graph_node(cudaGraph_t child);
+
   void add_temporary(const array& arr) {
     temporaries_.push_back(arr.data_shared_ptr());
   }
@@ -73,6 +80,10 @@ class CommandEncoder {
   void maybe_commit();
   void commit();
 
+  Device& device() {
+    return device_;
+  }
+
   CudaStream& stream() {
     return stream_;
   }
@@ -137,12 +148,16 @@ class Device {
   cublasLtHandle_t lt_handle() const {
     return lt_;
   }
+  cudnnHandle_t cudnn_handle() const {
+    return cudnn_;
+  }
 
  private:
  int device_;
  int compute_capability_major_;
  int compute_capability_minor_;
  cublasLtHandle_t lt_;
+  cudnnHandle_t cudnn_;
  std::unordered_map<int, CommandEncoder> encoders_;
 };
diff --git a/mlx/backend/cuda/lru_cache.h b/mlx/backend/cuda/lru_cache.h
new file mode 100644
index 0000000000..c8df2fa93c
--- /dev/null
+++ b/mlx/backend/cuda/lru_cache.h
@@ -0,0 +1,146 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cstring>
+#include <list>
+#include <unordered_map>
+
+namespace mlx::core {
+
+template <
+    typename K,
+    typename V,
+    template <typename...> typename M = std::unordered_map>
+class LRUCache {
+ public:
+  using value_type = std::pair<K, V>;
+  using list_type = std::list<value_type>;
+  using iterator = typename list_type::iterator;
+  using const_iterator = typename list_type::const_iterator;
+  using map_type = M<K, iterator>;
+
+  explicit LRUCache(size_t capacity) : capacity_(capacity) {}
+
+  size_t size() const {
+    return map_.size();
+  }
+  size_t capacity() const {
+    return capacity_;
+  }
+  bool empty() const {
+    return vlist_.empty();
+  }
+
+  void resize(size_t new_capacity) {
+    capacity_ = new_capacity;
+    trim();
+  }
+
+  iterator begin() {
+    return vlist_.begin();
+  }
+  const_iterator begin() const {
+    return vlist_.begin();
+  }
+  iterator end() {
+    return vlist_.end();
+  }
+  const_iterator end() const {
+    return vlist_.end();
+  }
+
+  void clear() {
+    map_.clear();
+    vlist_.clear();
+  }
+
+  iterator find(const K& key) {
+    auto it = map_.find(key);
+    if (it == map_.end())
+      return end();
+    vlist_.splice(vlist_.begin(), vlist_, it->second);
+    return it->second;
+  }
+
+  template <typename U>
+  std::pair<iterator, bool> emplace(const K& key, U&& value) {
+    auto it = map_.find(key);
+    if (it != map_.end()) {
+      vlist_.splice(vlist_.begin(), vlist_, it->second);
+      return {it->second, false};
+    }
+
+    vlist_.emplace_front(key, std::forward<U>(value));
+    map_[key] = vlist_.begin();
+
+    trim();
+
+    return {vlist_.begin(), true};
+  }
+
+  iterator erase(iterator pos) {
+    map_.erase(pos->first);
+    return vlist_.erase(pos);
+  }
+
+ private:
+  void trim() {
+    while (map_.size() > capacity_) {
+      auto last = std::prev(vlist_.end());
+      map_.erase(last->first);
+      vlist_.pop_back();
+    }
+  }
+
+  list_type vlist_;
+  map_type map_;
+  size_t capacity_;
+};
+
+// Turn a POD struct into a container key by comparing it byte-wise.
+template <typename T>
+struct BytesKey {
+  T pod;
+  static_assert(std::is_standard_layout_v<T>, "T is not POD");
+
+  BytesKey(T pod) : pod(std::move(pod)) {}
+
+  BytesKey(const BytesKey& other) {
+    memcpy(&pod, &other.pod, sizeof(T));
+  }
+
+  BytesKey(BytesKey&& other) {
+    memcpy(&pod, &other.pod, sizeof(T));
+  }
+
+  bool operator==(const BytesKey& other) const {
+    auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
+    auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
+    return memcmp(ptr1, ptr2, sizeof(T)) == 0;
+  }
+};
+
+// Compute a hash from the byte representation of T.
+template <typename T>
+struct BytesHash {
+  static_assert(std::is_standard_layout_v<T>, "T is not POD");
+
+  size_t operator()(const T& pod) const {
+    auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
+    uint32_t value = 0x811C9DC5;
+    for (size_t i = 0; i < sizeof(T); ++i) {
+      value ^= ptr[i];
+      value *= 0x01000193;
+    }
+    return value;
+  }
+};
+
+template <typename K, typename V>
+using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
+
+template <typename K, typename V>
+using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 1bca7c7301..efddf25067 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -8,7 +8,6 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
-#include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 #include <numeric>
 
@@ -18,16 +17,6 @@ namespace mlx::core {
 
 namespace cu {
 
-#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
-
-void check_cublas_error(const char* name, cublasStatus_t err) {
-  if (err != CUBLAS_STATUS_SUCCESS) {
-    // TODO: Use cublasGetStatusString when it is widely available.
-    throw std::runtime_error(
-        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
-  }
-}
-
 struct CublasPreference {
   CublasPreference(Device& device) {
     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index a7f4e8f666..0451c9e546 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
 }
 
 NO_GPU(BlockMaskedMM)
-NO_GPU(Convolution)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)
diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp
index 1c12fa4df2..baab3b2a53 100644
--- a/mlx/backend/cuda/utils.cpp
+++ b/mlx/backend/cuda/utils.cpp
@@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
   CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
 }
 
+void check_cublas_error(const char* name, cublasStatus_t err) {
+  if (err != CUBLAS_STATUS_SUCCESS) {
+    // TODO: Use cublasGetStatusString when it is widely available.
+    throw std::runtime_error(
+        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
+  }
+}
+
 void check_cuda_error(const char* name, cudaError_t err) {
   if (err != cudaSuccess) {
     throw std::runtime_error(
diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h
index bfb02c5b65..f2b6b16cb8 100644
--- a/mlx/backend/cuda/utils.h
+++ b/mlx/backend/cuda/utils.h
@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include <cublasLt.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 
@@ -33,10 +34,12 @@ class CudaStream {
 };
 
 // Throw exception if the cuda API does not succeed.
+void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
 void check_cuda_error(const char* name, CUresult err);
 
 // The macro version that prints the command that failed.
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
 #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
 
 // Convert Dtype to CUDA C++ types.
diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh
index 60586a0c2a..60e2ab11e6 100644
--- a/python/scripts/repair_cuda.sh
+++ b/python/scripts/repair_cuda.sh
@@ -16,7 +16,7 @@ rm "${repaired_wheel}"
 
 mlx_so="mlx/lib/libmlx.so"
 rpath=$(patchelf --print-rpath "${mlx_so}")
 base="\$ORIGIN/../../nvidia"
-rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
+rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
 patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
 python ../python/scripts/repair_record.py ${mlx_so}
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 50cb8dcbe3..5bb465e1eb 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -15,19 +15,12 @@ cuda_skip = {
     "TestOps.test_hadamard_grad_vmap",
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
-    "TestConv.test_asymmetric_padding",
-    "TestConv.test_basic_grad_shapes",
-    "TestConv.test_conv2d_unaligned_channels",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
     "TestConv.test_conv_groups_grad",
-    "TestConv.test_numpy_conv",
-    "TestConv.test_repeated_conv",
-    "TestConv.test_torch_conv_1D",
     "TestConv.test_torch_conv_1D_grad",
     "TestConv.test_torch_conv_2D",
     "TestConv.test_torch_conv_2D_grad",
-    "TestConv.test_torch_conv_3D",
     "TestConv.test_torch_conv_3D_grad",
     "TestConv.test_torch_conv_depthwise",
     "TestConv.test_torch_conv_general",
@@ -40,10 +33,6 @@ cuda_skip = {
     "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
     "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
-    "TestExportImport.test_export_conv",
-    "TestLayers.test_conv1d",
-    "TestLayers.test_conv2d",
-    "TestVmap.test_vmap_conv",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
diff --git a/setup.py b/setup.py
index f626277490..8a55171126 100644
--- a/setup.py
+++ b/setup.py
@@ -289,6 +289,7 @@ if __name__ == "__main__":
         install_requires += [
             "nvidia-cublas-cu12==12.9.*",
             "nvidia-cuda-nvrtc-cu12==12.9.*",
+            "nvidia-cudnn-cu12==9.*",
         ]
     else:
         name = "mlx-cpu"