Add cache

Cheng
2025-07-20 03:12:57 -07:00
parent 0430a6a74a
commit 85510dae78
2 changed files with 224 additions and 25 deletions

View File

@@ -1,6 +1,8 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/config.h"
+#include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
@@ -20,11 +22,44 @@ namespace mlx::core {
 namespace {
 
+struct ConvCacheKey {
+  int device_id;
+  cudnnBackendDescriptorType_t backend_type;
+  cudnnDataType_t cudnn_type;
+  std::array<int, MAX_NDIM> input_shape;
+  std::array<int, MAX_NDIM> filter_shape;
+  std::array<int, MAX_NDIM> padding_lo;
+  std::array<int, MAX_NDIM> padding_hi;
+  std::array<int, MAX_NDIM> stride;
+  std::array<int, MAX_NDIM> dilation;
+  int groups;
+  uint8_t input_alignment;
+  uint8_t filter_alignment;
+  uint8_t output_alignment;
+};
+
+auto& conv_cache() {
+  static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
+      /* capacity */ 128);
+  return cache;
+}
+
 template <typename T, typename U>
 inline std::vector<T> convert_vector(const std::vector<U>& vec) {
   return std::vector<T>(vec.begin(), vec.end());
 }
 
+template <typename T>
+inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
+  if (vec.size() > MAX_NDIM) {
+    throw std::runtime_error(
+        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
+  }
+  std::array<T, MAX_NDIM> result;
+  std::copy_n(vec.begin(), vec.size(), result.begin());
+  return result;
+}
+
 auto nhwc_to_nchw(const array& x) {
   auto shape = convert_vector<int64_t>(x.shape());
   shape.insert(shape.begin() + 1, shape.back());
@@ -57,8 +92,8 @@ inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
   }
 }
 
-inline int64_t get_alignment(const array& x) {
-  int64_t alignment = 1;
+inline uint8_t get_alignment(const array& x) {
+  uint8_t alignment = 1;
   uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
   for (; alignment < 32; alignment *= 2) {
     if (address % (alignment * 2)) {
@@ -126,22 +161,12 @@ cudnn_frontend::EngineConfigList get_engine_configs(
 bool execute_plan(
     cu::CommandEncoder& encoder,
-    cudnn_frontend::ManagedOpaqueDescriptor& config,
-    const std::string& op_graph_tag,
+    cudnn_frontend::ExecutionPlan& plan,
     const array& in,
     const array& wt,
     array& out) {
-  auto handle = encoder.device().cudnn_handle();
-  auto plan = cudnn_frontend::ExecutionPlanBuilder()
-                  .setHandle(handle)
-                  .setEngineConfig(config, op_graph_tag)
-                  .build();
-  int64_t workspace_size = plan.getWorkspaceSize();
-  array workspace(
-      allocator::malloc(workspace_size),
-      {static_cast<int>(workspace_size)},
-      int8);
+  int workspace_size = plan.getWorkspaceSize();
+  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
 
   int64_t uids[3] = {'x', 'w', 'y'};
   void* data_ptrs[3] = {
@@ -158,28 +183,34 @@ bool execute_plan(
   auto capture = encoder.capture_context();
   if (cudnnBackendExecute(
-          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
-      CUDNN_STATUS_SUCCESS) {
+          encoder.device().cudnn_handle(),
+          plan.get_raw_desc(),
+          variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
     // Discard the captured graph when failed.
     capture.discard = true;
     return false;
   }
 
+  encoder.add_completed_handler([plan = std::move(plan)]() {});
   encoder.add_temporary(workspace);
   return true;
 }
 
-bool execute_plans(
+bool try_engines(
     cu::CommandEncoder& encoder,
     cudnn_frontend::EngineConfigList& configs,
+    const ConvCacheKey& cache_key,
     const std::string& op_graph_tag,
     const array& in,
     const array& wt,
     array& out) {
   for (auto& config : configs) {
     try {
-      if (execute_plan(encoder, config, op_graph_tag, in, wt, out)) {
+      auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                      .setHandle(encoder.device().cudnn_handle())
+                      .setEngineConfig(config, op_graph_tag)
+                      .build();
+      if (execute_plan(encoder, plan, in, wt, out)) {
+        conv_cache().emplace(cache_key, std::move(plan));
         return true;
       }
     } catch (cudnn_frontend::cudnnException&) {
@@ -219,12 +250,36 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(wt);
   encoder.set_output_array(out);
 
-  // TODO: Searching a working execution plan is expensive, add cache.
+  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
+  auto cudnn_type = dtype_to_cudnn_type(in.dtype());
+
+  // Search cache.
+  ConvCacheKey cache_key{
+      encoder.device().cuda_device(),
+      backend_type,
+      cudnn_type,
+      fixed_vector(in.shape()),
+      fixed_vector(wt.shape()),
+      fixed_vector(padding_lo_),
+      fixed_vector(padding_hi_),
+      fixed_vector(kernel_strides_),
+      fixed_vector(kernel_dilation_),
+      groups_,
+      get_alignment(in),
+      get_alignment(wt),
+      get_alignment(out)};
+  auto it = conv_cache().find(cache_key);
+  if (it != conv_cache().end()) {
+    if (!execute_plan(encoder, it->second, in, wt, out)) {
+      throw std::runtime_error("Cached convolution plan failed to execute.");
+    }
+    return;
+  }
 
   // Build operation graph.
   auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
       ? CUDNN_DATA_FLOAT
-      : dtype_to_cudnn_type(in.dtype());
+      : cudnn_type;
 
   auto stride = convert_vector<int64_t>(kernel_strides_);
   auto padding_lo = convert_vector<int64_t>(padding_lo_);
@@ -241,7 +296,6 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
           .setDilation(dilation.size(), dilation.data())
           .build();
 
-  auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
   auto op = cudnn_frontend::OperationBuilder(backend_type)
                 .setxDesc(build_tensor('x', in))
                 .setwDesc(build_tensor('w', wt))
@@ -257,12 +311,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Try to run plans based on heuristics.
   auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
+  auto op_graph_tag = op_graph.getTag();
+  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
     return;
   }
   // Then try fallback plans.
   configs = get_engine_configs(backend_type, in.dtype(), op_graph);
-  if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
+  if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
     return;
   }
   throw std::runtime_error("Unable to find an engine for convolution.");
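
Taken together, the changes above follow a standard lookup-or-build memoization pattern: pack the convolution parameters into a plain struct, probe the LRU cache keyed on the struct's bytes, and only fall back to the expensive cudnn_frontend engine search on a miss, storing the winning plan for the next call. Below is a minimal, self-contained sketch of that pattern built on the LRUBytesKeyCache from the new mlx/backend/cuda/lru_cache.h header (shown next); Key, Plan, build_plan, and get_plan are simplified stand-ins invented for illustration, not the commit's ConvCacheKey or cuDNN types.

// Sketch only: a reduced key and a stand-in Plan type instead of the real
// cuDNN descriptors used above.
#include <array>
#include <string>

#include "mlx/backend/cuda/lru_cache.h"

struct Key { // stand-in for ConvCacheKey; must stay standard-layout
  int device_id;
  std::array<int, 8> input_shape;
  int groups;
};

struct Plan { // stand-in for cudnn_frontend::ExecutionPlan
  std::string tag;
};

// Stand-in for the expensive engine search done by try_engines().
Plan build_plan() {
  return Plan{"engine-0"};
}

auto& plan_cache() {
  // Same capacity as conv_cache() in the diff above.
  static mlx::core::LRUBytesKeyCache<Key, Plan> cache(/* capacity */ 128);
  return cache;
}

const Plan& get_plan(const Key& key) {
  // Hit: the key's bytes are hashed and compared, and the entry is moved to
  // the front of the LRU list.
  if (auto it = plan_cache().find(key); it != plan_cache().end()) {
    return it->second;
  }
  // Miss: build once, cache, and return; the least recently used entry is
  // evicted when the capacity is exceeded.
  return plan_cache().emplace(key, build_plan()).first->second;
}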

View File

@@ -0,0 +1,144 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <list>
+#include <unordered_map>
+#include <utility>
+
+namespace mlx::core {
+
+template <
+    typename K,
+    typename V,
+    template <typename...> typename M = std::unordered_map>
+class LRUCache {
+ public:
+  using value_type = std::pair<K, V>;
+  using list_type = std::list<value_type>;
+  using list_iter = typename list_type::iterator;
+  using map_type = M<K, list_iter>;
+
+  explicit LRUCache(size_t capacity) : capacity_(capacity) {}
+
+  size_t size() const {
+    return map_.size();
+  }
+  size_t capacity() const {
+    return capacity_;
+  }
+  bool empty() const {
+    return vlist_.empty();
+  }
+
+  void resize(size_t new_capacity) {
+    capacity_ = new_capacity;
+    trim();
+  }
+
+  auto begin() {
+    return vlist_.begin();
+  }
+  auto begin() const {
+    return vlist_.begin();
+  }
+  auto end() {
+    return vlist_.end();
+  }
+  auto end() const {
+    return vlist_.end();
+  }
+
+  void clear() {
+    map_.clear();
+    vlist_.clear();
+  }
+
+  list_iter find(const K& key) {
+    auto it = map_.find(key);
+    if (it == map_.end())
+      return end();
+    vlist_.splice(vlist_.begin(), vlist_, it->second);
+    return it->second;
+  }
+
+  std::pair<list_iter, bool> emplace(const K& key, V value) {
+    auto it = map_.find(key);
+    if (it != map_.end()) {
+      vlist_.splice(vlist_.begin(), vlist_, it->second);
+      return {it->second, false};
+    }
+    vlist_.emplace_front(key, std::move(value));
+    map_[key] = vlist_.begin();
+    trim();
+    return {vlist_.begin(), true};
+  }
+
+  list_iter erase(list_iter pos) {
+    map_.erase(pos->first);
+    return vlist_.erase(pos);
+  }
+
+ private:
+  void trim() {
+    while (map_.size() > capacity_) {
+      auto last = std::prev(vlist_.end());
+      map_.erase(last->first);
+      vlist_.pop_back();
+    }
+  }
+
+  list_type vlist_;
+  map_type map_;
+  size_t capacity_;
+};
+
+// Turn a POD struct into a container key by doing bytes compare.
+template <typename T>
+struct BytesKey {
+  T pod;
+  static_assert(std::is_standard_layout_v<T>, "T is not POD");
+
+  BytesKey(T pod) : pod(std::move(pod)) {}
+
+  BytesKey(const BytesKey& other) {
+    memcpy(&pod, &other.pod, sizeof(T));
+  }
+
+  BytesKey(BytesKey&& other) {
+    memcpy(&pod, &other.pod, sizeof(T));
+  }
+
+  bool operator==(const BytesKey& other) const {
+    auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
+    auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
+    return memcmp(ptr1, ptr2, sizeof(T)) == 0;
+  }
+};
+
+// Compute hash according to the bytes value of T.
+template <typename T>
+struct BytesHash {
+  static_assert(std::is_standard_layout_v<T>, "T is not POD");
+
+  size_t operator()(const T& pod) const {
+    auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
+    uint32_t value = 0x811C9DC5;
+    for (int i = 0; i < sizeof(T); ++i) {
+      value ^= ptr[i];
+      value *= 0x01000193;
+    }
+    return value;
+  }
+};
+
+template <typename K, typename V>
+using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
+
+template <typename K, typename V>
+using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
+
+} // namespace mlx::core
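
As a quick sanity check on the header above, here is a small usage sketch (not part of the commit) that exercises LRUCache's move-to-front and eviction behavior with the default std::unordered_map backing; the int keys and string values are arbitrary examples.

#include <cassert>
#include <string>

#include "mlx/backend/cuda/lru_cache.h"

int main() {
  mlx::core::LRUCache<int, std::string> cache(/* capacity */ 2);
  cache.emplace(1, "one");
  cache.emplace(2, "two");

  // find() is also a "touch": key 1 moves to the front of the LRU list.
  assert(cache.find(1) != cache.end());

  // Inserting a third entry trims the least recently used key (2).
  cache.emplace(3, "three");
  assert(cache.find(2) == cache.end());
  assert(cache.size() == 2);

  // Re-inserting an existing key keeps the old value and reports false.
  auto [it, inserted] = cache.emplace(3, "three again");
  assert(!inserted && it->second == "three");
  return 0;
}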