diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index 484602c6d..eb1f248d5 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -8,19 +8,13 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"
 
-// cudnn_frontend.h redefines this macro.
-#undef CHECK_CUDA_ERROR
-
-#include <cudnn_frontend.h>
-#include
 #include
 #include
 #include
 
-namespace fe = cudnn_frontend;
-
 namespace mlx::core {
 
 namespace cu {
 
@@ -645,294 +639,6 @@ void sdpa_vector_fallback(
   }
 }
 
-struct SDPACacheKey {
-  int device_id;
-  fe::DataType_t cudnn_type;
-
-  int B;
-  int H;
-  int D;
-
-  int qL;
-  int kL;
-
-  int gqa_factor;
-  float scale;
-
-  int64_t Q_strides[3];
-  int64_t K_strides[3];
-  int64_t V_strides[3];
-  int64_t O_strides[3];
-
-  bool generate_stats;
-  bool causal_mask;
-};
-
-auto& sdpa_cache() {
-  static LRUBytesKeyCache<SDPACacheKey, std::shared_ptr<fe::graph::Graph>>
-      cache(
-          /* capacity */ 128);
-  return cache;
-}
-
-#define Q_UID 1
-#define K_UID 2
-#define V_UID 3
-#define O_UID 4
-#define STATS_UID 5
-
-std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
-    cu::CommandEncoder& encoder,
-    const SDPACacheKey& cache_key) {
-  // Check if graph has already been fully built
-  if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) {
-    return it->second;
-  }
-
-  // Set up new graph
-  auto graph = std::make_shared<fe::graph::Graph>();
-
-  graph->set_io_data_type(cache_key.cudnn_type)
-      .set_intermediate_data_type(fe::DataType_t::FLOAT)
-      .set_compute_data_type(fe::DataType_t::FLOAT);
-
-  auto Q = graph->tensor(
-      fe::graph::Tensor_attributes()
-          .set_name("Q")
-          .set_uid(Q_UID)
-          .set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
-          .set_stride(
-              {cache_key.Q_strides[0],
-               cache_key.Q_strides[1],
-               cache_key.Q_strides[2],
-               1}));
-
-  int h_kv = cache_key.H / cache_key.gqa_factor;
-  auto K =
-      graph->tensor(fe::graph::Tensor_attributes()
-                        .set_name("K")
-                        .set_uid(K_UID)
-                        .set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
-                        .set_stride(
-                            {cache_key.K_strides[0],
-                             cache_key.K_strides[1],
-                             cache_key.V_strides[2],
-                             1}));
-
-  auto V =
-      graph->tensor(fe::graph::Tensor_attributes()
-                        .set_name("V")
-                        .set_uid(V_UID)
-                        .set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
-                        .set_stride(
-                            {cache_key.V_strides[0],
-                             cache_key.V_strides[1],
-                             cache_key.V_strides[2],
-                             1}));
-
-  auto sdpa_options = fe::graph::SDPA_attributes()
-                          .set_name("flash_attention")
-                          .set_is_inference(!cache_key.generate_stats)
-                          .set_attn_scale(cache_key.scale);
-
-  if (cache_key.causal_mask && cache_key.qL > 1) {
-    sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
-        .set_diagonal_band_right_bound(0);
-  }
-
-  auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
-
-  O->set_output(true)
-      .set_uid(O_UID)
-      .set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
-      .set_stride(
-          {cache_key.O_strides[0],
-           cache_key.O_strides[1],
-           cache_key.O_strides[2],
-           1});
-
-  if (cache_key.generate_stats) {
-    Stats->set_output(true)
-        .set_data_type(fe::DataType_t::FLOAT)
-        .set_uid(STATS_UID);
-  }
-
-  // Build and Validate cudnn graph
-
-  auto handle = encoder.device().cudnn_handle();
-
-  // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
-  if (cudnnGetVersion() < 90600) {
-    auto build_status = graph->build(handle, {fe::HeurMode_t::A});
-    if (!build_status.is_good()) {
-      throw std::runtime_error(
-          "Unable to build cudnn graph for attention."
- " Failed with message: " + - build_status.get_message()); - } - - } else { - auto val_status = graph->validate(); - auto op_status = graph->build_operation_graph(handle); - - auto plan_stauts = - graph->create_execution_plans({cudnn_frontend::HeurMode_t::A}); - if (!plan_stauts.is_good()) { - throw std::runtime_error( - "Unable to create exec plan for cudnn attention." - " Failed with message: " + - plan_stauts.get_message()); - } - - graph->select_behavior_notes( - {cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); - - auto support_status = graph->check_support(handle); - if (!support_status.is_good()) { - throw std::runtime_error( - "No cuda graph support for cudnn attention." - " Failed with message: " + - support_status.get_message()); - } - - auto build_status = graph->build_plans(handle); - if (!build_status.is_good()) { - throw std::runtime_error( - "Unable to build cudnn graph for attention." - " Failed with message: " + - build_status.get_message()); - } - } - - auto [it, _] = sdpa_cache().emplace(cache_key, graph); - - return it->second; -} - -inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) { - switch (dtype) { - case int8: - return fe::DataType_t::INT8; - case int32: - return fe::DataType_t::INT32; - case uint8: - return fe::DataType_t::UINT8; - case float16: - return fe::DataType_t::HALF; - case bfloat16: - return fe::DataType_t::BFLOAT16; - case float32: - return fe::DataType_t::FLOAT; - case float64: - return fe::DataType_t::DOUBLE; - default: - throw std::runtime_error(fmt::format( - "Unsupported dtype in SDPA: {}.", dtype_to_string(dtype))); - } -} - -void sdpa_cudnn( - const Stream& s, - cu::CommandEncoder& encoder, - const array& q, - const array& k, - const array& v, - const float scale, - array& o, - bool do_causal_ = false) { - encoder.set_input_array(q); - encoder.set_input_array(k); - encoder.set_input_array(v); - encoder.set_output_array(o); - - auto cudnn_type = dtype_to_cudnn_type(q.dtype()); - - int B = q.shape(0); - int H = q.shape(1); - int D = q.shape(3); - int gqa_factor = q.shape(1) / k.shape(1); - - int qL = q.shape(2); - int kL = k.shape(2); - - SDPACacheKey cache_key{ - /* int device_id = */ encoder.device().cuda_device(), - /* fe::DataType_t cudnn_type = */ cudnn_type, - - /* int B = */ B, - /* int H = */ H, - /* int D = */ D, - - /* int qL = */ qL, - /* int kL = */ kL, - - /* int gqa_factor = */ gqa_factor, - /* float scale = */ scale, - - /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, - /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, - /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, - /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}, - - /* bool generate_stats = */ false, - /* bool causal_mask = */ do_causal_}; - - auto graph = get_sdpa_forward_graph(encoder, cache_key); - - int64_t workspace_size = 0; - auto workspace_status = graph->get_workspace_size(workspace_size); - if (!workspace_status.is_good()) { - throw std::runtime_error("Unable to get workspace for cudnn attention."); - } - - array workspace( - allocator::malloc(workspace_size), {int(workspace_size)}, uint8); - auto workspace_ptr = workspace.data(); - - std::unordered_map variant_pack = { - {Q_UID, const_cast(q.data())}, - {K_UID, const_cast(k.data())}, - {V_UID, const_cast(v.data())}, - {O_UID, o.data()}}; - - auto handle = encoder.device().cudnn_handle(); - cudnnSetStream(handle, encoder.stream()); - - // cuDNN only supports native CUDA graphs for sdpa in 9.6 or 
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
   if (s.device == Device::cpu) {
     return true;
   }
@@ -960,15 +669,7 @@ bool ScaledDotProductAttention::use_fallback(
   const bool supported_vector_config =
       sdpa_supported_head_dim && query_sequence_length < 4;
 
-  auto& cu_device = cu::device(s.device);
-
-  const bool supported_matrix_config = query_sequence_length > 4 &&
-      cu_device.compute_capability_major() >= 8 &&
-      query_sequence_length == key_sequence_length &&
-      (q.dtype() == float16 || q.dtype() == bfloat16);
-
-  const bool supported_config =
-      (supported_matrix_config || supported_vector_config);
+  const bool supported_config = supported_vector_config;
 
   return has_arr_mask || !supported_config;
 }
@@ -1002,10 +703,6 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
 
-  auto is_matrix_contiguous = [](const array& arr) {
-    return arr.strides(-1) == 1;
-  };
-
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
@@ -1059,7 +756,7 @@ void ScaledDotProductAttention::eval_gpu(
 
     array::Flags flags{
         /* bool contiguous = */ 1,
-        /* bool row_contiguous = */ 0,
+        /* bool row_contiguous = */ o.shape(2) == 1,
         /* bool col_contiguous = */ 0,
     };
 
@@ -1073,35 +770,9 @@ void ScaledDotProductAttention::eval_gpu(
     return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
   }
 
-  // Full attention mode
+  // Full attention mode should never reach here
   else {
-    const auto& q = copy_unless(is_matrix_contiguous, q_pre);
-    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
-    const auto& v = copy_unless(is_matrix_contiguous, v_pre);
-
-    for (const auto& cp : copies) {
-      encoder.add_temporary(cp);
-    }
-
-    int64_t str_oD = 1;
-    int64_t str_oH = o.shape(3);
-    int64_t str_oL = o.shape(1) * str_oH;
-    int64_t str_oB = o.shape(2) * str_oL;
-    size_t data_size = o.shape(0) * str_oB;
-
-    array::Flags flags{
-        /* bool contiguous = */ 1,
-        /* bool row_contiguous = */ 0,
-        /* bool col_contiguous = */ 0,
-    };
-
-    o.set_data(
-        allocator::malloc(o.nbytes()),
-        data_size,
-        {str_oB, str_oH, str_oL, str_oD},
-        flags);
-
-    return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
+    throw std::runtime_error("Matrix attention is not supported yet.");
   }
 }
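Note on the row_contiguous change in the vector path above: the kernel writes
its output in (B, L, H, D) memory order while the logical shape stays
(B, H, L, D); the corresponding strides appear as {str_oB, str_oH, str_oL,
str_oD} in the removed matrix branch. An axis of extent 1 never advances the
pointer, so when the query length o.shape(2) is 1 this layout coincides with
row-major order. A minimal standalone sketch of that argument (illustrative
only; the shapes and helper function are not part of the patch):

// row_contiguous_sketch.cpp -- verifies that a (B, H, L, D) view over a
// buffer written in (B, L, H, D) order is row-contiguous exactly when L == 1.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Standard row-contiguity check: strides must match the row-major strides
// on every axis whose extent is greater than 1.
static bool is_row_contiguous(const int shape[4], const int64_t strides[4]) {
  int64_t expected = 1;
  for (int i = 3; i >= 0; --i) {
    if (shape[i] > 1 && strides[i] != expected) {
      return false;
    }
    expected *= shape[i];
  }
  return true;
}

int main() {
  const int B = 2, H = 8, D = 64;
  for (int L : {1, 2}) { // L plays the role of o.shape(2)
    const int shape[4] = {B, H, L, D};
    // Strides of the logical (B, H, L, D) view over a (B, L, H, D) buffer,
    // mirroring {str_oB, str_oH, str_oL, str_oD} above.
    const int64_t strides[4] = {
        int64_t(L) * H * D, // str_oB
        D, // str_oH
        int64_t(H) * D, // str_oL
        1}; // str_oD
    std::printf(
        "L = %d -> row_contiguous = %s\n",
        L,
        is_row_contiguous(shape, strides) ? "true" : "false");
  }
  return 0;
}

Compiled as-is, this prints true for L = 1 and false for L = 2, matching the
new o.shape(2) == 1 flag.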