mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add cache
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -20,11 +22,44 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ConvCacheKey {
|
||||
int device_id;
|
||||
cudnnBackendDescriptorType_t backend_type;
|
||||
cudnnDataType_t cudnn_type;
|
||||
std::array<int, MAX_NDIM> input_shape;
|
||||
std::array<int, MAX_NDIM> filter_shape;
|
||||
std::array<int, MAX_NDIM> padding_lo;
|
||||
std::array<int, MAX_NDIM> padding_hi;
|
||||
std::array<int, MAX_NDIM> stride;
|
||||
std::array<int, MAX_NDIM> dilation;
|
||||
int groups;
|
||||
uint8_t input_alignment;
|
||||
uint8_t filter_alignment;
|
||||
uint8_t output_alignment;
|
||||
};
|
||||
|
||||
auto& conv_cache() {
|
||||
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
|
||||
/* capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
// Returns a new vector holding the elements of `vec`, each converted from U
// to T. The input is left untouched.
template <typename T, typename U>
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
  std::vector<T> converted;
  converted.reserve(vec.size());
  for (const U& element : vec) {
    converted.emplace_back(element);
  }
  return converted;
}
|
||||
|
||||
template <typename T>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result;
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
auto nhwc_to_nchw(const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
@@ -57,8 +92,8 @@ inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
inline int64_t get_alignment(const array& x) {
|
||||
int64_t alignment = 1;
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
@@ -126,22 +161,12 @@ cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
|
||||
bool execute_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ManagedOpaqueDescriptor& config,
|
||||
const std::string& op_graph_tag,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out) {
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(handle)
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
|
||||
int64_t workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(
|
||||
allocator::malloc(workspace_size),
|
||||
{static_cast<int>(workspace_size)},
|
||||
int8);
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||
|
||||
int64_t uids[3] = {'x', 'w', 'y'};
|
||||
void* data_ptrs[3] = {
|
||||
@@ -158,28 +183,34 @@ bool execute_plan(
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
encoder.device().cudnn_handle(),
|
||||
plan.get_raw_desc(),
|
||||
variantPack.get_raw_desc()) != CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
|
||||
encoder.add_completed_handler([plan = std::move(plan)]() {});
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool execute_plans(
|
||||
bool try_engines(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::EngineConfigList& configs,
|
||||
const ConvCacheKey& cache_key,
|
||||
const std::string& op_graph_tag,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out) {
|
||||
for (auto& config : configs) {
|
||||
try {
|
||||
if (execute_plan(encoder, config, op_graph_tag, in, wt, out)) {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
if (execute_plan(encoder, plan, in, wt, out)) {
|
||||
conv_cache().emplace(cache_key, std::move(plan));
|
||||
return true;
|
||||
}
|
||||
} catch (cudnn_frontend::cudnnException&) {
|
||||
@@ -219,12 +250,36 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(wt);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// TODO: Searching a working execution plan is expensive, add cache.
|
||||
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
|
||||
auto cudnn_type = dtype_to_cudnn_type(in.dtype());
|
||||
|
||||
// Search cache.
|
||||
ConvCacheKey cache_key{
|
||||
encoder.device().cuda_device(),
|
||||
backend_type,
|
||||
cudnn_type,
|
||||
fixed_vector(in.shape()),
|
||||
fixed_vector(wt.shape()),
|
||||
fixed_vector(padding_lo_),
|
||||
fixed_vector(padding_hi_),
|
||||
fixed_vector(kernel_strides_),
|
||||
fixed_vector(kernel_dilation_),
|
||||
groups_,
|
||||
get_alignment(in),
|
||||
get_alignment(wt),
|
||||
get_alignment(out)};
|
||||
auto it = conv_cache().find(cache_key);
|
||||
if (it != conv_cache().end()) {
|
||||
if (!execute_plan(encoder, it->second, in, wt, out)) {
|
||||
throw std::runtime_error("Cached convolution plan failed to execute.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Build operation graph.
|
||||
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
: dtype_to_cudnn_type(in.dtype());
|
||||
: cudnn_type;
|
||||
|
||||
auto stride = convert_vector<int64_t>(kernel_strides_);
|
||||
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||
@@ -241,7 +296,6 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
.setDilation(dilation.size(), dilation.data())
|
||||
.build();
|
||||
|
||||
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_tensor('x', in))
|
||||
.setwDesc(build_tensor('w', wt))
|
||||
@@ -257,12 +311,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Try to run plans based on heuristics.
|
||||
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
||||
auto op_graph_tag = op_graph.getTag();
|
||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||
return;
|
||||
}
|
||||
// Then try fallback plans.
|
||||
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
|
||||
if (execute_plans(encoder, configs, op_graph.getTag(), in, wt, out)) {
|
||||
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("Unable to find an engine for convolution.");
|
||||
|
||||
144
mlx/backend/cuda/lru_cache.h
Normal file
144
mlx/backend/cuda/lru_cache.h
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// A fixed-capacity cache that evicts the least recently used entry.
//
// Entries live in a linked list ordered from most to least recently used,
// and a map from key to list iterator provides lookup. `M` selects the map
// template, which lets callers plug in a custom hash.
template <
    typename K,
    typename V,
    template <typename...> typename M = std::unordered_map>
class LRUCache {
 public:
  using value_type = std::pair<K, V>;
  using list_type = std::list<value_type>;
  using list_iter = typename list_type::iterator;
  using map_type = M<K, list_iter>;

  explicit LRUCache(size_t capacity) : capacity_(capacity) {}

  // Number of entries currently stored.
  size_t size() const {
    return map_.size();
  }

  // Maximum number of entries kept before eviction starts.
  size_t capacity() const {
    return capacity_;
  }

  bool empty() const {
    return vlist_.empty();
  }

  // Changes the capacity, evicting least recently used entries if needed.
  void resize(size_t new_capacity) {
    capacity_ = new_capacity;
    trim();
  }

  // Iteration goes from most recently used to least recently used.
  auto begin() {
    return vlist_.begin();
  }

  auto begin() const {
    return vlist_.begin();
  }

  auto end() {
    return vlist_.end();
  }

  auto end() const {
    return vlist_.end();
  }

  // Drops every entry; the capacity is unchanged.
  void clear() {
    map_.clear();
    vlist_.clear();
  }

  // Looks up `key`. On a hit the entry becomes the most recently used one and
  // an iterator to it is returned; on a miss end() is returned.
  list_iter find(const K& key) {
    auto entry = map_.find(key);
    if (entry != map_.end()) {
      touch(entry->second);
      return entry->second;
    }
    return end();
  }

  // Inserts `value` under `key` unless the key is already present.
  // Returns the entry's iterator and whether an insertion happened. An
  // existing entry keeps its value and is only marked as most recently used.
  std::pair<list_iter, bool> emplace(const K& key, V value) {
    if (auto entry = map_.find(key); entry != map_.end()) {
      touch(entry->second);
      return {entry->second, false};
    }

    vlist_.emplace_front(key, std::move(value));
    map_.emplace(key, vlist_.begin());

    trim();

    return {vlist_.begin(), true};
  }

  // Removes the entry at `pos` and returns the iterator following it.
  list_iter erase(list_iter pos) {
    map_.erase(pos->first);
    return vlist_.erase(pos);
  }

 private:
  // Moves the entry at `pos` to the front of the usage list.
  void touch(list_iter pos) {
    vlist_.splice(vlist_.begin(), vlist_, pos);
  }

  // Evicts from the back of the usage list until the size fits the capacity.
  void trim() {
    while (map_.size() > capacity_) {
      map_.erase(vlist_.back().first);
      vlist_.pop_back();
    }
  }

  list_type vlist_;
  map_type map_;
  size_t capacity_;
};
|
||||
|
||||
// Turn a POD struct into a container key by doing bytes compare.
//
// Two keys are equal only when all sizeof(T) bytes match, padding included.
// Copies are done with memcpy so padding bytes travel with the value. To get
// a key whose padding is deterministic, default-construct it (all bytes are
// zeroed) and then assign the fields of `pod`.
template <typename T>
struct BytesKey {
  T pod;
  static_assert(std::is_standard_layout_v<T>, "T is not POD");
  static_assert(
      std::is_trivially_copyable_v<T>,
      "T must be trivially copyable to be copied and compared as bytes");

  // Creates a key whose bytes, including padding, are all zero.
  BytesKey() {
    memset(static_cast<void*>(&pod), 0, sizeof(T));
  }

  // Intentionally implicit so a plain T can be passed where a key is expected.
  BytesKey(T pod) : pod(std::move(pod)) {}

  BytesKey(const BytesKey& other) noexcept {
    memcpy(static_cast<void*>(&pod), &other.pod, sizeof(T));
  }

  BytesKey(BytesKey&& other) noexcept {
    memcpy(static_cast<void*>(&pod), &other.pod, sizeof(T));
  }

  // Assignment mirrors the constructors; without these the user-declared move
  // constructor would leave copy assignment implicitly deleted.
  BytesKey& operator=(const BytesKey& other) noexcept {
    if (this != &other) {
      memcpy(static_cast<void*>(&pod), &other.pod, sizeof(T));
    }
    return *this;
  }

  BytesKey& operator=(BytesKey&& other) noexcept {
    if (this != &other) {
      memcpy(static_cast<void*>(&pod), &other.pod, sizeof(T));
    }
    return *this;
  }

  // Byte-wise equality over the whole object representation of T.
  bool operator==(const BytesKey& other) const {
    auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
    auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
    return memcmp(ptr1, ptr2, sizeof(T)) == 0;
  }
};
|
||||
|
||||
// Compute hash according to the bytes value of T.
//
// Runs 32-bit FNV-1a over all sizeof(T) bytes of the object: start from the
// offset basis, then for each byte XOR it in and multiply by the FNV prime.
// Padding bytes are hashed too, so callers must keep them deterministic.
template <typename T>
struct BytesHash {
  static_assert(std::is_standard_layout_v<T>, "T is not POD");

  size_t operator()(const T& pod) const noexcept {
    const auto* bytes = reinterpret_cast<const uint8_t*>(&pod);
    uint32_t value = 0x811C9DC5; // FNV-1a 32-bit offset basis.
    // size_t index avoids a signed/unsigned comparison against sizeof(T).
    for (size_t i = 0; i < sizeof(T); ++i) {
      value ^= bytes[i];
      value *= 0x01000193; // FNV 32-bit prime; wraps modulo 2^32.
    }
    return value;
  }
};
|
||||
|
||||
template <typename K, typename V>
|
||||
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
|
||||
|
||||
template <typename K, typename V>
|
||||
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
|
||||
|
||||
} // namespace mlx::core
|
||||
Reference in New Issue
Block a user