From 85510dae780078327483640f7b4e43d99f365df4 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Sun, 20 Jul 2025 03:12:57 -0700
Subject: [PATCH] Add cache
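
Searching for a working cuDNN execution plan on every convolution call
is expensive, so cache the successfully built
cudnn_frontend::ExecutionPlan objects in a process-wide LRU cache with
a capacity of 128 entries.

The cache key is a standard-layout struct holding everything the plan
depends on: device id, backend descriptor type, cuDNN data type, input
and filter shapes, padding, strides, dilation, groups, and the pointer
alignments of the input, filter, and output tensors. Keys are hashed
and compared byte-wise, so the variable-length shape vectors are first
copied into fixed-size, zero-initialized std::array<int, MAX_NDIM>
fields.

On a cache hit the operation graph construction and engine search are
skipped entirely; on a miss, the first engine config that executes
successfully is inserted into the cache. A cached plan that fails to
execute is reported as a hard error rather than falling back to a new
search. Since a successful plan now lives in the cache, the
completed-handler that used to keep it alive is no longer needed.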
---
 mlx/backend/cuda/conv.cpp    | 105 +++++++++++++++++++------
 mlx/backend/cuda/lru_cache.h | 144 +++++++++++++++++++++++++++++++++++
 2 files changed, 224 insertions(+), 25 deletions(-)
 create mode 100644 mlx/backend/cuda/lru_cache.h
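
Notes (review aid, not part of the commit message):

The LRU semantics of the new cache in a minimal self-contained sketch.
This is illustrative only and not part of the patch; it uses nothing
beyond the LRUCache template added below, which is plain C++ and needs
no CUDA to compile.

  #include "mlx/backend/cuda/lru_cache.h"

  #include <cassert>
  #include <string>

  int main() {
    mlx::core::LRUCache<int, std::string> cache(/* capacity */ 2);

    cache.emplace(1, "one");
    cache.emplace(2, "two");

    // find() refreshes an entry: key 1 moves to the front of the list.
    assert(cache.find(1) != cache.end());

    // A third entry evicts the least recently used key, which is now 2.
    cache.emplace(3, "three");
    assert(cache.find(2) == cache.end());
    assert(cache.size() == 2);

    // emplace() on an existing key keeps the old value and reports
    // {iterator, false}, only refreshing the entry's LRU position.
    auto [it, inserted] = cache.emplace(1, "uno");
    assert(!inserted && it->second == "one");
    return 0;
  }
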
diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index ca26d69ef..eced11ec1 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -1,6 +1,8 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
@@ -20,11 +22,44 @@ namespace mlx::core {
 
 namespace {
 
+struct ConvCacheKey {
+  int device_id;
+  cudnnBackendDescriptorType_t backend_type;
+  cudnnDataType_t cudnn_type;
+  std::array<int, MAX_NDIM> input_shape;
+  std::array<int, MAX_NDIM> filter_shape;
+  std::array<int, MAX_NDIM> padding_lo;
+  std::array<int, MAX_NDIM> padding_hi;
+  std::array<int, MAX_NDIM> stride;
+  std::array<int, MAX_NDIM> dilation;
+  int groups;
+  uint8_t input_alignment;
+  uint8_t filter_alignment;
+  uint8_t output_alignment;
+};
+
+auto& conv_cache() {
+  static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
+      /* capacity */ 128);
+  return cache;
+}
+
 template <typename T, typename U>
 inline std::vector<T> convert_vector(const std::vector<U>& vec) {
   return std::vector<T>(vec.begin(), vec.end());
 }
 
+template <typename T, typename U>
+inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<U>& vec) {
+  if (vec.size() > MAX_NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
+  }
+  std::array<T, MAX_NDIM> result = {};
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
 auto nhwc_to_nchw(const array& x) {
   auto shape = convert_vector<int64_t>(x.shape());
   shape.insert(shape.begin() + 1, shape.back());
@@ -57,8 +92,8 @@ inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
   }
 }
 
-inline int64_t get_alignment(const array& x) {
-  int64_t alignment = 1;
+inline uint8_t get_alignment(const array& x) {
+  uint8_t alignment = 1;
   uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
   for (; alignment < 32; alignment *= 2) {
     if (address % (alignment * 2)) {
@@ -126,22 +161,12 @@ cudnn_frontend::EngineConfigList get_engine_configs(
 
 bool execute_plan(
     cu::CommandEncoder& encoder,
-    cudnn_frontend::ManagedOpaqueDescriptor& config,
-    const std::string& op_graph_tag,
+    cudnn_frontend::ExecutionPlan& plan,
     const array& in,
     const array& wt,
     array& out) {
-  auto handle = encoder.device().cudnn_handle();
-  auto plan = cudnn_frontend::ExecutionPlanBuilder()
-                  .setHandle(handle)
-                  .setEngineConfig(config, op_graph_tag)
-                  .build();
-
-  int64_t workspace_size = plan.getWorkspaceSize();
-  array workspace(
-      allocator::malloc(workspace_size),
-      {static_cast<int>(workspace_size)},
-      int8);
+  int workspace_size = plan.getWorkspaceSize();
+  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
 
   int64_t uids[3] = {'x', 'w', 'y'};
   void* data_ptrs[3] = {
@@ -158,28 +183,34 @@ bool execute_plan(
 
   auto capture = encoder.capture_context();
   if (cudnnBackendExecute(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
-      CUDNN_STATUS_SUCCESS) {
+          encoder.device().cudnn_handle(),
+          plan.get_raw_desc(),
+          variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
     // Discard the captured graph when failed.
     capture.discard = true;
     return false;
   }
 
-  encoder.add_completed_handler([plan = std::move(plan)]() {});
   encoder.add_temporary(workspace);
   return true;
 }
 
-bool execute_plans(
+bool try_engines(
    cu::CommandEncoder& encoder,
     cudnn_frontend::EngineConfigList& configs,
+    const ConvCacheKey& cache_key,
     const std::string& op_graph_tag,
     const array& in,
     const array& wt,
     array& out) {
   for (auto& config : configs) {
     try {
-      if (execute_plan(encoder, config, op_graph_tag, in, wt, out)) {
+      auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                      .setHandle(encoder.device().cudnn_handle())
+                      .setEngineConfig(config, op_graph_tag)
+                      .build();
+      if (execute_plan(encoder, plan, in, wt, out)) {
+        conv_cache().emplace(cache_key, std::move(plan));
         return true;
       }
     } catch (cudnn_frontend::cudnnException&) {
@@ -219,12 +250,36 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(wt);
   encoder.set_output_array(out);
 
-  // TODO: Searching a working execution plan is expensive, add cache.
+  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
+  auto cudnn_type = dtype_to_cudnn_type(in.dtype());
+
+  // Search cache.
+  ConvCacheKey cache_key{
+      encoder.device().cuda_device(),
+      backend_type,
+      cudnn_type,
+      fixed_vector<int>(in.shape()),
+      fixed_vector<int>(wt.shape()),
+      fixed_vector<int>(padding_lo_),
+      fixed_vector<int>(padding_hi_),
+      fixed_vector<int>(kernel_strides_),
+      fixed_vector<int>(kernel_dilation_),
+      groups_,
+      get_alignment(in),
+      get_alignment(wt),
+      get_alignment(out)};
+  auto it = conv_cache().find(cache_key);
+  if (it != conv_cache().end()) {
+    if (!execute_plan(encoder, it->second, in, wt, out)) {
+      throw std::runtime_error("Cached convolution plan failed to execute.");
+    }
+    return;
+  }
 
   // Build operation graph.
   auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
       ? CUDNN_DATA_FLOAT
-      : dtype_to_cudnn_type(in.dtype());
+      : cudnn_type;
 
   auto stride = convert_vector<int64_t>(kernel_strides_);
   auto padding_lo = convert_vector<int64_t>(padding_lo_);
@@ -241,7 +296,6 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
           .setDilation(dilation.size(), dilation.data())
           .build();
 
-  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
   auto op = cudnn_frontend::OperationBuilder(backend_type)
                 .setxDesc(build_tensor('x', in))
                 .setwDesc(build_tensor('w', wt))
@@ -257,12 +311,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Try to run plans based on heuristics.
   auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
+  auto op_graph_tag = op_graph.getTag();
+  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
     return;
   }
   // Then try fallback plans.
   configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
+  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
     return;
   }
   throw std::runtime_error("Unable to find an engine for convolution.");
diff --git a/mlx/backend/cuda/lru_cache.h b/mlx/backend/cuda/lru_cache.h
new file mode 100644
index 000000000..7294f2477
--- /dev/null
+++ b/mlx/backend/cuda/lru_cache.h
@@ -0,0 +1,144 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <list>
+#include <unordered_map>
+#include <utility>
+
+namespace mlx::core {
+
+template <
+    typename K,
+    typename V,
+    template <typename...> typename M = std::unordered_map>
+class LRUCache {
+ public:
+  using value_type = std::pair<K, V>;
+  using list_type = std::list<value_type>;
+  using list_iter = typename list_type::iterator;
+  using map_type = M<K, list_iter>;
+
+  explicit LRUCache(size_t capacity) : capacity_(capacity) {}
+
+  size_t size() const {
+    return map_.size();
+  }
+  size_t capacity() const {
+    return capacity_;
+  }
+  bool empty() const {
+    return vlist_.empty();
+  }
+
+  void resize(size_t new_capacity) {
+    capacity_ = new_capacity;
+    trim();
+  }
+
+  auto begin() {
+    return vlist_.begin();
+  }
+  auto begin() const {
+    return vlist_.begin();
+  }
+  auto end() {
+    return vlist_.end();
+  }
+  auto end() const {
+    return vlist_.end();
+  }
+
+  void clear() {
+    map_.clear();
+    vlist_.clear();
+  }
+
+  list_iter find(const K& key) {
+    auto it = map_.find(key);
+    if (it == map_.end())
+      return end();
+    vlist_.splice(vlist_.begin(), vlist_, it->second);
+    return it->second;
+  }
+
+  std::pair<list_iter, bool> emplace(const K& key, V value) {
+    auto it = map_.find(key);
+    if (it != map_.end()) {
+      vlist_.splice(vlist_.begin(), vlist_, it->second);
+      return {it->second, false};
+    }
+
+    vlist_.emplace_front(key, std::move(value));
+    map_[key] = vlist_.begin();
+
+    trim();
+
+    return {vlist_.begin(), true};
+  }
+
+  list_iter erase(list_iter pos) {
+    map_.erase(pos->first);
+    return vlist_.erase(pos);
+  }
+
+ private:
+  void trim() {
+    while (map_.size() > capacity_) {
+      auto last = std::prev(vlist_.end());
+      map_.erase(last->first);
+      vlist_.pop_back();
+    }
+  }
+
+  list_type vlist_;
+  map_type map_;
+  size_t capacity_;
+};
+
+// Turn a POD struct into a container key by doing bytes compare.
+template <typename T>
+struct BytesKey {
+  T pod;
+  static_assert(std::is_standard_layout_v<T>, "T is not POD");
+
+  BytesKey(T pod) : pod(std::move(pod)) {}
+
+  BytesKey(const BytesKey& other) {
+    memcpy(&pod, &other.pod, sizeof(T));
+  }
+
+  BytesKey(BytesKey&& other) {
+    memcpy(&pod, &other.pod, sizeof(T));
+  }
+
+  bool operator==(const BytesKey& other) const {
+    auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
+    auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
+    return memcmp(ptr1, ptr2, sizeof(T)) == 0;
+  }
+};
+
+// Compute hash according to the bytes value of T.
+template <typename T>
+struct BytesHash {
+  static_assert(std::is_standard_layout_v<T>, "T is not POD");
+
+  size_t operator()(const T& pod) const {
+    auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
+    uint32_t value = 0x811C9DC5;
+    for (size_t i = 0; i < sizeof(T); ++i) {
+      value ^= ptr[i];
+      value *= 0x01000193;
+    }
+    return value;
+  }
+};
+
+template <typename K, typename V>
+using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
+
+template <typename K, typename V>
+using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
+
+} // namespace mlx::core
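
A closing note on the byte-wise keys (illustrative, not part of the
patch): BytesKey and BytesHash operate on the raw bytes of the key
struct, so a key type must be standard-layout and must not carry
indeterminate bytes. That is why fixed_vector() zero-initializes its
std::array before copying, and why any padding inside ConvCacheKey
(for example, after the three uint8_t alignment fields) deserves care,
since aggregate initialization does not formally guarantee zeroed
padding. The hash itself is 32-bit FNV-1a over the key bytes; a
standalone equivalent of what BytesHash computes:

  #include <cstddef>
  #include <cstdint>
  #include <cstdio>

  // 32-bit FNV-1a, using the same constants as BytesHash above.
  uint32_t fnv1a(const uint8_t* data, size_t size) {
    uint32_t value = 0x811C9DC5; // FNV offset basis
    for (size_t i = 0; i < size; ++i) {
      value ^= data[i]; // fold in one byte
      value *= 0x01000193; // multiply by the FNV prime
    }
    return value;
  }

  int main() {
    const uint8_t key[] = {'m', 'l', 'x'};
    std::printf("%08x\n", fnv1a(key, sizeof(key)));
    return 0;
  }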