From a5c89754c7381716522954c712f523b7879583ff Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Fri, 25 Jul 2025 12:30:50 -0700 Subject: [PATCH] Add sdpa file --- .../cuda/scaled_dot_product_attention.cu | 464 ++++++++++++++++++ 1 file changed, 464 insertions(+) create mode 100644 mlx/backend/cuda/scaled_dot_product_attention.cu diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu new file mode 100644 index 000000000..05e6fbf29 --- /dev/null +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -0,0 +1,464 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/config.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/lru_cache.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +// cudnn_frontend.h redefines this macro. +#undef CHECK_CUDA_ERROR + +#include +#include +#include + +namespace fe = cudnn_frontend; + +namespace mlx::core { + +namespace cu {} // namespace cu + +namespace { + +struct SDPACacheKey { + int device_id; + fe::DataType_t cudnn_type; + + int B; + int H; + int D; + + int qL; + int kL; + + int gqa_factor; + float scale; + + int64_t Q_strides[3]; + int64_t K_strides[3]; + int64_t V_strides[3]; + int64_t O_strides[3]; + + bool generate_stats; + bool causal_mask; +}; + +auto& sdpa_cache() { + static LRUBytesKeyCache> + cache( + /* capacity */ 128); + return cache; +} + +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 + +std::shared_ptr get_sdpa_forward_graph( + cu::CommandEncoder& encoder, + const SDPACacheKey& cache_key) { + // Check if graph has already been fully built + if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) { + return it->second; + } + + // Set up new graph + auto graph = std::make_shared(); + + graph->set_io_data_type(cache_key.cudnn_type) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q = graph->tensor( + fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D}) + .set_stride( + {cache_key.Q_strides[0], + cache_key.Q_strides[1], + cache_key.Q_strides[2], + 1})); + + int h_kv = cache_key.H / cache_key.gqa_factor; + auto K = + graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D}) + .set_stride( + {cache_key.K_strides[0], + cache_key.K_strides[1], + cache_key.V_strides[2], + 1})); + + auto V = + graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D}) + .set_stride( + {cache_key.V_strides[0], + cache_key.V_strides[1], + cache_key.V_strides[2], + 1})); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_is_inference(!cache_key.generate_stats) + .set_attn_scale(cache_key.scale); + + if (cache_key.causal_mask && cache_key.qL > 1) { + sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options); + + O->set_output(true) + .set_uid(O_UID) + .set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D}) + .set_stride( + {cache_key.O_strides[0], + cache_key.O_strides[1], + cache_key.O_strides[2], + 1}); + + if (cache_key.generate_stats) { + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_uid(STATS_UID); + } + + // Build and Validate cudnn graph + + auto handle = encoder.device().cudnn_handle(); + + // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above. + if (cudnnGetVersion() < 90600) { + auto build_status = graph->build(handle, {fe::HeurMode_t::A}); + if (!build_status.is_good()) { + throw std::runtime_error( + "Unable to build cudnn graph for attention." + " Failed with message: " + + build_status.get_message()); + } + + } else { + auto val_status = graph->validate(); + auto op_status = graph->build_operation_graph(handle); + + auto plan_stauts = + graph->create_execution_plans({cudnn_frontend::HeurMode_t::A}); + if (!plan_stauts.is_good()) { + throw std::runtime_error( + "Unable to create exec plan for cudnn attention." + " Failed with message: " + + plan_stauts.get_message()); + } + + graph->select_behavior_notes( + {cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); + + auto support_status = graph->check_support(handle); + if (!support_status.is_good()) { + throw std::runtime_error( + "No cuda graph support for cudnn attention." + " Failed with message: " + + support_status.get_message()); + } + + auto build_status = graph->build_plans(handle); + if (!build_status.is_good()) { + throw std::runtime_error( + "Unable to build cudnn graph for attention." + " Failed with message: " + + build_status.get_message()); + } + } + + auto [it, _] = sdpa_cache().emplace(cache_key, graph); + + return it->second; +} + +inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) { + switch (dtype) { + case int8: + return fe::DataType_t::INT8; + case int32: + return fe::DataType_t::INT32; + case uint8: + return fe::DataType_t::UINT8; + case float16: + return fe::DataType_t::HALF; + case bfloat16: + return fe::DataType_t::BFLOAT16; + case float32: + return fe::DataType_t::FLOAT; + case float64: + return fe::DataType_t::DOUBLE; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in SDPA: {}.", dtype_to_string(dtype))); + } +} + +void sdpa_cudnn( + const Stream& s, + cu::CommandEncoder& encoder, + const array& q, + const array& k, + const array& v, + const float scale, + array& o, + bool do_causal_ = false) { + encoder.set_input_array(q); + encoder.set_input_array(k); + encoder.set_input_array(v); + encoder.set_output_array(o); + + auto cudnn_type = dtype_to_cudnn_type(q.dtype()); + + int B = q.shape(0); + int H = q.shape(1); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); + + int qL = q.shape(2); + int kL = k.shape(2); + + SDPACacheKey cache_key{ + /* int device_id = */ encoder.device().cuda_device(), + /* fe::DataType_t cudnn_type = */ cudnn_type, + + /* int B = */ B, + /* int H = */ H, + /* int D = */ D, + + /* int qL = */ qL, + /* int kL = */ kL, + + /* int gqa_factor = */ gqa_factor, + /* float scale = */ scale, + + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, + /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, + /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}, + + /* bool generate_stats = */ false, + /* bool causal_mask = */ do_causal_}; + + auto graph = get_sdpa_forward_graph(encoder, cache_key); + + int64_t workspace_size = 0; + auto workspace_status = graph->get_workspace_size(workspace_size); + if (!workspace_status.is_good()) { + throw std::runtime_error("Unable to get workspace for cudnn attention."); + } + + array workspace( + allocator::malloc(workspace_size), {int(workspace_size)}, uint8); + auto workspace_ptr = workspace.data(); + + std::unordered_map variant_pack = { + {Q_UID, const_cast(q.data())}, + {K_UID, const_cast(k.data())}, + {V_UID, const_cast(v.data())}, + {O_UID, o.data()}}; + + auto handle = encoder.device().cudnn_handle(); + cudnnSetStream(handle, encoder.stream()); + + // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above. + if (cudnnGetVersion() < 90600) { + auto capture = encoder.capture_context(); + auto exec_status = graph->execute(handle, variant_pack, workspace_ptr); + + if (!exec_status.is_good()) { + capture.discard = true; + throw std::runtime_error( + "Unable to execute cudnn attention." + " Failed with message: " + + exec_status.get_message()); + } + } else { + cudaGraph_t cu_graph; + cudaGraphCreate(&cu_graph, 0); + + std::unique_ptr graph_freer( + &cu_graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); }); + + auto cu_graph_status = graph->populate_cuda_graph( + handle, variant_pack, workspace_ptr, cu_graph); + + if (!cu_graph_status.is_good()) { + throw std::runtime_error( + "Unable to add cuda graph for cudnn attention." + " Failed with message: " + + cu_graph_status.get_message()); + } + + encoder.add_graph_node(cu_graph); + } + + encoder.add_temporary(workspace); +} + +} // namespace + +namespace fast { + +bool ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + Stream s) { + if (s.device == Device::cpu) { + return true; + } + + auto& cu_device = cu::device(s.device); + if (cu_device.compute_capability_major() < 8) { + return true; + } + + const int value_head_dim = v.shape(-1); + const int query_head_dim = q.shape(-1); + const int query_sequence_length = q.shape(2); + const int key_sequence_length = k.shape(2); + + const bool sdpa_supported_head_dim = query_head_dim == value_head_dim && + (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); + + const bool supported_dtype = q.dtype() == float16 || q.dtype() == bfloat16; + + const bool supported_config = supported_dtype && sdpa_supported_head_dim; + + return has_arr_mask || !supported_config; +} + +void ScaledDotProductAttention::eval_gpu( + const std::vector& inputs, + array& out) { + nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + auto& q_pre = inputs[0]; + auto& k_pre = inputs[1]; + auto& v_pre = inputs[2]; + auto& o = out; + + std::vector copies; + + // Define some copy functions to ensure the layout of the inputs is as + // expected. + copies.reserve(3); + auto copy_unless = [&copies, &s]( + auto predicate, const array& arr) -> const array& { + if (!predicate(arr)) { + array arr_copy = contiguous_copy_gpu(arr, s); + copies.push_back(std::move(arr_copy)); + return copies.back(); + } else { + return arr; + } + }; + + auto is_matrix_contiguous = [](const array& arr) { + return arr.strides(-1) == 1; + }; + + // We are in vector mode ie single query + if (q_pre.shape(2) <= 1) { + auto q_copy_unless = [](const array& arr) { + if (arr.flags().row_contiguous) { + return true; + } + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (shape[0] == 1 || shape[1] == 1) { + // If either the batch or head dimension is a singleton, the other can + // be transposed with the sequence dimension + auto bidx = shape[0] == 1 ? 1 : 0; + return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) && + (strides[bidx] == shape[3]); + } + return false; + }; + + auto kv_copy_unless = [](const array& arr) { + // keys and values should be copied if: + // - the last dimension is not contiguous + // - the batch and head dim are not contiguous + auto& strides = arr.strides(); + auto& shape = arr.shape(); + if (strides.back() != 1) { + return false; + } + if (shape[0] == 1 || shape[1] == 1) { + return true; + } + return (strides[0] == strides[1] * shape[1]); + }; + + const auto& q = copy_unless(q_copy_unless, q_pre); + const auto& k = copy_unless(kv_copy_unless, k_pre); + const auto& v = copy_unless(kv_copy_unless, v_pre); + + for (const auto& cp : copies) { + encoder.add_temporary(cp); + } + + // Donate the query if possible + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { + o.copy_shared_buffer(q); + } else { + o.set_data(allocator::malloc(o.nbytes())); + } + + return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_); + } + + // Full attention mode + else { + const auto& q = copy_unless(is_matrix_contiguous, q_pre); + const auto& k = copy_unless(is_matrix_contiguous, k_pre); + const auto& v = copy_unless(is_matrix_contiguous, v_pre); + + int64_t str_oD = 1; + int64_t str_oH = o.shape(3); + int64_t str_oL = o.shape(1) * str_oH; + int64_t str_oB = o.shape(2) * str_oL; + size_t data_size = o.shape(0) * str_oB; + + array::Flags flags{ + /* bool contiguous = */ 1, + /* bool row_contiguous = */ 0, + /* bool col_contiguous = */ 0, + }; + + o.set_data( + allocator::malloc(o.nbytes()), + data_size, + {str_oB, str_oH, str_oL, str_oD}, + flags); + + return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_); + } +} + +} // namespace fast + +} // namespace mlx::core \ No newline at end of file