mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
[WIP] 2 pass sdpav
This commit is contained in:
@@ -4,6 +4,8 @@
|
|||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
#define NO_GPU_MULTI(func) \
|
#define NO_GPU_MULTI(func) \
|
||||||
void func::eval_gpu( \
|
void func::eval_gpu( \
|
||||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||||
|
|||||||
@@ -15,14 +15,632 @@
|
|||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
|
||||||
namespace fe = cudnn_frontend;
|
namespace fe = cudnn_frontend;
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace cu {} // namespace cu
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
#define PRAGMA_LOOP_UNROLL #pragma unroll
|
||||||
|
|
||||||
|
struct AttnParams {
|
||||||
|
int B;
|
||||||
|
int H;
|
||||||
|
int D;
|
||||||
|
|
||||||
|
int qL;
|
||||||
|
int kL;
|
||||||
|
|
||||||
|
int gqa_factor;
|
||||||
|
float scale;
|
||||||
|
|
||||||
|
int64_t Q_strides[3];
|
||||||
|
int64_t K_strides[3];
|
||||||
|
int64_t V_strides[3];
|
||||||
|
int64_t O_strides[3];
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, bool do_causal, int D>
|
||||||
|
__global__ void kernel_sdpav_1pass(
|
||||||
|
const T* Q,
|
||||||
|
const T* K,
|
||||||
|
const T* V,
|
||||||
|
T* O,
|
||||||
|
__grid_constant__ const AttnParams params) {
|
||||||
|
constexpr int BN = 32;
|
||||||
|
constexpr int BD = 32;
|
||||||
|
|
||||||
|
constexpr int v_per_thread = D / BD;
|
||||||
|
|
||||||
|
const int inner_k_stride = BN * int(params.K_strides[2]);
|
||||||
|
const int inner_v_stride = BN * int(params.V_strides[2]);
|
||||||
|
|
||||||
|
typedef float U;
|
||||||
|
|
||||||
|
U q[v_per_thread];
|
||||||
|
U k[v_per_thread];
|
||||||
|
U o[v_per_thread];
|
||||||
|
|
||||||
|
__shared__ U outputs[BN][BD + 1];
|
||||||
|
__shared__ U max_scores[BN];
|
||||||
|
__shared__ U sum_exp_scores[BN];
|
||||||
|
|
||||||
|
const U scale_log2 = params.scale * 1.44269504089f;
|
||||||
|
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<32>(block);
|
||||||
|
|
||||||
|
const int lane_idx = warp.thread_rank();
|
||||||
|
const int warp_idx = warp.meta_group_rank();
|
||||||
|
|
||||||
|
// Adjust to thread block and thread
|
||||||
|
const int batch_idx = blockIdx.z;
|
||||||
|
const int head_idx = blockIdx.x;
|
||||||
|
const int kv_head_idx = head_idx / params.gqa_factor;
|
||||||
|
|
||||||
|
const int q_seq_idx = blockIdx.y;
|
||||||
|
const int kv_seq_idx = warp_idx;
|
||||||
|
|
||||||
|
Q += batch_idx * params.Q_strides[0] + // Batch
|
||||||
|
head_idx * params.Q_strides[1] + // Head
|
||||||
|
q_seq_idx * params.Q_strides[2]; // Sequence
|
||||||
|
|
||||||
|
K += batch_idx * params.K_strides[0] + // Batch
|
||||||
|
kv_head_idx * params.K_strides[1] + // Head
|
||||||
|
kv_seq_idx * params.K_strides[2]; // Sequence
|
||||||
|
|
||||||
|
V += batch_idx * params.V_strides[0] + // Batch
|
||||||
|
kv_head_idx * params.V_strides[1] + // Head
|
||||||
|
kv_seq_idx * params.V_strides[2]; // Sequence
|
||||||
|
|
||||||
|
O += batch_idx * params.O_strides[0] + // Batch
|
||||||
|
head_idx * params.O_strides[1] + // Head
|
||||||
|
q_seq_idx * params.O_strides[2]; // Sequence
|
||||||
|
|
||||||
|
// Read the query and 0 the output accumulator
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
o[i] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
U max_score = -INFINITY;
|
||||||
|
U sum_exp_score = 0.f;
|
||||||
|
|
||||||
|
// For each key
|
||||||
|
for (int i = kv_seq_idx; i < params.kL; i += BN) {
|
||||||
|
bool use_key = true;
|
||||||
|
if constexpr (do_causal) {
|
||||||
|
use_key = i <= (params.kL - params.qL + q_seq_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_key) {
|
||||||
|
// Read the key
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int j = 0; j < v_per_thread; j++) {
|
||||||
|
k[j] = K[v_per_thread * lane_idx + j];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the i-th score
|
||||||
|
U score = 0.f;
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int j = 0; j < v_per_thread; j++) {
|
||||||
|
score += q[j] * k[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warp sum
|
||||||
|
score = cg::reduce(warp, score, cg::plus<U>());
|
||||||
|
|
||||||
|
// Update the accumulators
|
||||||
|
U new_max = max(max_score, score);
|
||||||
|
U factor = exp2f(max_score - new_max);
|
||||||
|
U exp_score = exp2f(score - new_max);
|
||||||
|
|
||||||
|
max_score = new_max;
|
||||||
|
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||||
|
|
||||||
|
// Update the output accumulator
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int j = 0; j < v_per_thread; j++) {
|
||||||
|
o[j] = o[j] * factor +
|
||||||
|
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the pointers to the next kv
|
||||||
|
K += inner_k_stride;
|
||||||
|
V += inner_v_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lane_idx == 0) {
|
||||||
|
max_scores[warp_idx] = max_score;
|
||||||
|
sum_exp_scores[warp_idx] = sum_exp_score;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
max_score = max_scores[lane_idx];
|
||||||
|
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||||
|
U factor = exp2f(max_score - new_max);
|
||||||
|
sum_exp_score =
|
||||||
|
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
|
||||||
|
sum_exp_score = __frcp_rn(sum_exp_score);
|
||||||
|
|
||||||
|
// Now we need to aggregate all the outputs
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
outputs[lane_idx][warp_idx] = o[i];
|
||||||
|
block.sync();
|
||||||
|
U ot = outputs[warp_idx][lane_idx] * factor;
|
||||||
|
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
|
||||||
|
block.sync();
|
||||||
|
}
|
||||||
|
|
||||||
|
// And write the output
|
||||||
|
if (lane_idx == 0) {
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, bool do_causal, int D>
|
||||||
|
__global__ void kernel_sdpav_2pass_1(
|
||||||
|
const T* Q,
|
||||||
|
const T* K,
|
||||||
|
const T* V,
|
||||||
|
float* partials,
|
||||||
|
float* sums,
|
||||||
|
float* maxs,
|
||||||
|
__grid_constant__ const AttnParams params) {
|
||||||
|
constexpr int BN = 8;
|
||||||
|
constexpr int BD = 32;
|
||||||
|
constexpr int blocks = 32;
|
||||||
|
|
||||||
|
constexpr int v_per_thread = D / BD;
|
||||||
|
|
||||||
|
const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
|
||||||
|
const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
|
||||||
|
|
||||||
|
typedef float U;
|
||||||
|
|
||||||
|
U q[v_per_thread];
|
||||||
|
U k[v_per_thread];
|
||||||
|
U o[v_per_thread];
|
||||||
|
|
||||||
|
__shared__ U outputs[BD][BN + 1];
|
||||||
|
__shared__ U max_scores[BN];
|
||||||
|
__shared__ U sum_exp_scores[BN];
|
||||||
|
|
||||||
|
const U scale_log2 = params.scale; // * 1.44269504089f;
|
||||||
|
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<32>(block);
|
||||||
|
|
||||||
|
const int lane_idx = warp.thread_rank();
|
||||||
|
const int warp_idx = warp.meta_group_rank();
|
||||||
|
|
||||||
|
// Adjust to thread block and thread
|
||||||
|
const int batch_idx = 0; // blockIdx.z / blocks;
|
||||||
|
const int block_idx = blockIdx.z % blocks;
|
||||||
|
const int head_idx = blockIdx.x;
|
||||||
|
const int kv_head_idx = head_idx / params.gqa_factor;
|
||||||
|
|
||||||
|
const int q_seq_idx = blockIdx.y;
|
||||||
|
const int kv_seq_idx = block_idx * BN + warp_idx;
|
||||||
|
|
||||||
|
Q += batch_idx * params.Q_strides[0] + // Batch
|
||||||
|
head_idx * params.Q_strides[1] + // Head
|
||||||
|
q_seq_idx * params.Q_strides[2]; // Sequence
|
||||||
|
|
||||||
|
K += batch_idx * params.K_strides[0] + // Batch
|
||||||
|
kv_head_idx * params.K_strides[1] + // Head
|
||||||
|
kv_seq_idx * params.K_strides[2]; // Sequence
|
||||||
|
|
||||||
|
V += batch_idx * params.V_strides[0] + // Batch
|
||||||
|
kv_head_idx * params.V_strides[1] + // Head
|
||||||
|
kv_seq_idx * params.V_strides[2]; // Sequence
|
||||||
|
|
||||||
|
const int p_stride_s = blocks;
|
||||||
|
const int p_stride_h = params.qL * p_stride_s;
|
||||||
|
const int p_stride_b = params.H * p_stride_h;
|
||||||
|
const int p_offset = batch_idx * p_stride_b + // Batch
|
||||||
|
head_idx * p_stride_h + // Head
|
||||||
|
q_seq_idx * p_stride_s + // Sequence
|
||||||
|
block_idx; // Block
|
||||||
|
|
||||||
|
partials += p_offset * D;
|
||||||
|
sums += p_offset;
|
||||||
|
maxs += p_offset;
|
||||||
|
|
||||||
|
// Read the query and 0 the output accumulator
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
o[i] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
U max_score = -1e9;
|
||||||
|
U sum_exp_score = 0.f;
|
||||||
|
|
||||||
|
// For each key
|
||||||
|
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
|
||||||
|
bool use_key = true;
|
||||||
|
if constexpr (do_causal) {
|
||||||
|
use_key = i <= (params.kL - params.qL + q_seq_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_key) {
|
||||||
|
// Read the key
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int j = 0; j < v_per_thread; j++) {
|
||||||
|
k[j] = K[v_per_thread * lane_idx + j];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the i-th score
|
||||||
|
U score = 0.f;
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int j = 0; j < v_per_thread; j++) {
|
||||||
|
score += q[j] * k[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warp sum
|
||||||
|
score = cg::reduce(warp, score, cg::plus<U>());
|
||||||
|
|
||||||
|
// Update the accumulators
|
||||||
|
U new_max = max(max_score, score);
|
||||||
|
U factor = expf(max_score - new_max);
|
||||||
|
U exp_score = expf(score - new_max);
|
||||||
|
|
||||||
|
max_score = new_max;
|
||||||
|
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||||
|
|
||||||
|
// Update the output accumulator
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int j = 0; j < v_per_thread; j++) {
|
||||||
|
o[j] = o[j] * factor +
|
||||||
|
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the pointers to the next kv
|
||||||
|
K += inner_k_stride;
|
||||||
|
V += inner_v_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lane_idx == 0) {
|
||||||
|
max_scores[warp_idx] = max_score;
|
||||||
|
sum_exp_scores[warp_idx] = sum_exp_score;
|
||||||
|
}
|
||||||
|
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
|
||||||
|
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||||
|
U factor = expf(max_score - new_max);
|
||||||
|
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
|
||||||
|
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
|
||||||
|
|
||||||
|
// Write the sum and new max
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
sums[0] = sum_exp_score;
|
||||||
|
maxs[0] = new_max;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we need to aggregate all the outputs
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
outputs[lane_idx][warp_idx] = o[i] * expf(max_scores[warp_idx] - new_max);
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
U ot = outputs[lane_idx][0];
|
||||||
|
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int j = 1; j < BN; j++) {
|
||||||
|
ot += outputs[lane_idx][0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// o[i] = ot;
|
||||||
|
partials[v_per_thread * lane_idx + i] = ot;
|
||||||
|
}
|
||||||
|
block.sync();
|
||||||
|
}
|
||||||
|
|
||||||
|
// if(warp_idx == 0) {
|
||||||
|
// PRAGMA_LOOP_UNROLL
|
||||||
|
// for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
// partials[v_per_thread * lane_idx + i] = o[i];
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, bool do_causal, int D>
|
||||||
|
__global__ void kernel_sdpav_2pass_2(
|
||||||
|
const float* partials,
|
||||||
|
const float* sums,
|
||||||
|
const float* maxs,
|
||||||
|
T* O,
|
||||||
|
__grid_constant__ const AttnParams params) {
|
||||||
|
constexpr int BN = 32;
|
||||||
|
constexpr int BD = 32;
|
||||||
|
constexpr int blocks = 32;
|
||||||
|
|
||||||
|
constexpr int v_per_thread = D / BD;
|
||||||
|
|
||||||
|
typedef float U;
|
||||||
|
|
||||||
|
U o[v_per_thread];
|
||||||
|
__shared__ U outputs[BN][BD + 1];
|
||||||
|
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<32>(block);
|
||||||
|
|
||||||
|
const int lane_idx = warp.thread_rank();
|
||||||
|
const int warp_idx = warp.meta_group_rank();
|
||||||
|
|
||||||
|
// Adjust to thread block and thread
|
||||||
|
const int batch_idx = blockIdx.z;
|
||||||
|
const int head_idx = blockIdx.x;
|
||||||
|
const int q_seq_idx = blockIdx.y;
|
||||||
|
|
||||||
|
const int p_stride_s = blocks;
|
||||||
|
const int p_stride_h = params.qL * p_stride_s;
|
||||||
|
const int p_stride_b = params.H * p_stride_h;
|
||||||
|
const int p_offset = batch_idx * p_stride_b + // Batch
|
||||||
|
head_idx * p_stride_h + // Head
|
||||||
|
q_seq_idx * p_stride_s; // Sequence
|
||||||
|
|
||||||
|
partials += p_offset * D + warp_idx * D;
|
||||||
|
sums += p_offset;
|
||||||
|
maxs += p_offset;
|
||||||
|
|
||||||
|
O += batch_idx * params.O_strides[0] + // Batch
|
||||||
|
head_idx * params.O_strides[1] + // Head
|
||||||
|
q_seq_idx * params.O_strides[2]; // Sequence
|
||||||
|
|
||||||
|
U max_score = maxs[lane_idx];
|
||||||
|
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||||
|
U factor = expf(max_score - new_max);
|
||||||
|
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
|
||||||
|
// sum_exp_score = __frcp_rn(sum_exp_score);
|
||||||
|
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
o[i] = partials[v_per_thread * lane_idx + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we need to aggregate all the outputs
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
outputs[lane_idx][warp_idx] = o[i];
|
||||||
|
block.sync();
|
||||||
|
U ot = outputs[warp_idx][lane_idx] * factor;
|
||||||
|
o[i] = cg::reduce(warp, ot, cg::plus<U>()) / sum_exp_score;
|
||||||
|
block.sync();
|
||||||
|
}
|
||||||
|
|
||||||
|
// And write the output
|
||||||
|
if (lane_idx == 0) {
|
||||||
|
PRAGMA_LOOP_UNROLL
|
||||||
|
for (int i = 0; i < v_per_thread; i++) {
|
||||||
|
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_headdim(int n, F&& f) {
|
||||||
|
switch (n) {
|
||||||
|
case 64:
|
||||||
|
f(std::integral_constant<int, 64>{});
|
||||||
|
break;
|
||||||
|
case 96:
|
||||||
|
f(std::integral_constant<int, 96>{});
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
f(std::integral_constant<int, 128>{});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void sdpa_vector_1pass_fallback(
|
||||||
|
const Stream& s,
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& v,
|
||||||
|
const float scale,
|
||||||
|
array& o,
|
||||||
|
bool do_causal_ = false) {
|
||||||
|
encoder.set_input_array(q);
|
||||||
|
encoder.set_input_array(k);
|
||||||
|
encoder.set_input_array(v);
|
||||||
|
encoder.set_output_array(o);
|
||||||
|
|
||||||
|
cu::AttnParams params{
|
||||||
|
/* int B = */ q.shape(0),
|
||||||
|
/* int H = */ q.shape(1),
|
||||||
|
/* int D = */ q.shape(3),
|
||||||
|
|
||||||
|
/* int qL = */ q.shape(2),
|
||||||
|
/* int kL = */ k.shape(2),
|
||||||
|
|
||||||
|
/* int gqa_factor = */ q.shape(1) / k.shape(1),
|
||||||
|
/* float scale = */ scale,
|
||||||
|
|
||||||
|
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||||
|
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||||
|
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||||
|
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
|
||||||
|
|
||||||
|
dim3 grid_dim(params.H, params.qL, params.B);
|
||||||
|
dim3 block_dim(1024, 1, 1);
|
||||||
|
|
||||||
|
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
|
||||||
|
dispatch_bool(do_causal_, [&](auto do_causal) {
|
||||||
|
dispatch_headdim(params.D, [&](auto headdim) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
|
||||||
|
auto kernel = cu::kernel_sdpav_1pass<DataType, do_causal(), headdim()>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid_dim,
|
||||||
|
block_dim,
|
||||||
|
q.data<DataType>(),
|
||||||
|
k.data<DataType>(),
|
||||||
|
v.data<DataType>(),
|
||||||
|
o.data<DataType>(),
|
||||||
|
params);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void sdpa_vector_2pass_fallback(
|
||||||
|
const Stream& s,
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& v,
|
||||||
|
const float scale,
|
||||||
|
array& o,
|
||||||
|
bool do_causal_ = false) {
|
||||||
|
cu::AttnParams params{
|
||||||
|
/* int B = */ q.shape(0),
|
||||||
|
/* int H = */ q.shape(1),
|
||||||
|
/* int D = */ q.shape(3),
|
||||||
|
|
||||||
|
/* int qL = */ q.shape(2),
|
||||||
|
/* int kL = */ k.shape(2),
|
||||||
|
|
||||||
|
/* int gqa_factor = */ q.shape(1) / k.shape(1),
|
||||||
|
/* float scale = */ scale,
|
||||||
|
|
||||||
|
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||||
|
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||||
|
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||||
|
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
|
||||||
|
|
||||||
|
// Allocate the intermediates
|
||||||
|
int blocks = 32;
|
||||||
|
|
||||||
|
Shape intermediate_shape;
|
||||||
|
intermediate_shape.reserve(o.ndim() + 1);
|
||||||
|
intermediate_shape.insert(
|
||||||
|
intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
|
||||||
|
intermediate_shape.push_back(blocks);
|
||||||
|
intermediate_shape.push_back(o.shape().back());
|
||||||
|
|
||||||
|
array intermediate(intermediate_shape, float32, nullptr, {});
|
||||||
|
intermediate_shape.pop_back();
|
||||||
|
array sums(intermediate_shape, float32, nullptr, {});
|
||||||
|
array maxs(std::move(intermediate_shape), float32, nullptr, {});
|
||||||
|
|
||||||
|
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||||
|
sums.set_data(allocator::malloc(sums.nbytes()));
|
||||||
|
maxs.set_data(allocator::malloc(maxs.nbytes()));
|
||||||
|
|
||||||
|
encoder.add_temporary(intermediate);
|
||||||
|
encoder.add_temporary(sums);
|
||||||
|
encoder.add_temporary(maxs);
|
||||||
|
|
||||||
|
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
|
||||||
|
dispatch_bool(do_causal_, [&](auto do_causal) {
|
||||||
|
dispatch_headdim(params.D, [&](auto headdim) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
|
||||||
|
{
|
||||||
|
auto kernel =
|
||||||
|
cu::kernel_sdpav_2pass_1<DataType, do_causal(), headdim()>;
|
||||||
|
|
||||||
|
encoder.set_input_array(q);
|
||||||
|
encoder.set_input_array(k);
|
||||||
|
encoder.set_input_array(v);
|
||||||
|
encoder.set_output_array(intermediate);
|
||||||
|
encoder.set_output_array(sums);
|
||||||
|
encoder.set_output_array(maxs);
|
||||||
|
|
||||||
|
dim3 grid_dim(params.H, params.qL, params.B * 32);
|
||||||
|
dim3 block_dim(8 * 32, 1, 1);
|
||||||
|
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid_dim,
|
||||||
|
block_dim,
|
||||||
|
q.data<DataType>(),
|
||||||
|
k.data<DataType>(),
|
||||||
|
v.data<DataType>(),
|
||||||
|
intermediate.data<float>(),
|
||||||
|
sums.data<float>(),
|
||||||
|
maxs.data<float>(),
|
||||||
|
params);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto kernel =
|
||||||
|
cu::kernel_sdpav_2pass_2<DataType, do_causal(), headdim()>;
|
||||||
|
|
||||||
|
encoder.set_input_array(intermediate);
|
||||||
|
encoder.set_input_array(sums);
|
||||||
|
encoder.set_input_array(maxs);
|
||||||
|
encoder.set_output_array(o);
|
||||||
|
|
||||||
|
dim3 grid_dim(params.H, params.qL, params.B);
|
||||||
|
dim3 block_dim(1024, 1, 1);
|
||||||
|
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid_dim,
|
||||||
|
block_dim,
|
||||||
|
intermediate.data<float>(),
|
||||||
|
sums.data<float>(),
|
||||||
|
maxs.data<float>(),
|
||||||
|
o.data<DataType>(),
|
||||||
|
params);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void sdpa_vector_fallback(
|
||||||
|
const Stream& s,
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& v,
|
||||||
|
const float scale,
|
||||||
|
array& o,
|
||||||
|
bool do_causal_ = false) {
|
||||||
|
int kL = k.shape(2);
|
||||||
|
|
||||||
|
if (false && kL > 1024) {
|
||||||
|
return sdpa_vector_2pass_fallback(
|
||||||
|
s, encoder, q, k, v, scale, o, do_causal_);
|
||||||
|
} else {
|
||||||
|
return sdpa_vector_1pass_fallback(
|
||||||
|
s, encoder, q, k, v, scale, o, do_causal_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct SDPACacheKey {
|
struct SDPACacheKey {
|
||||||
int device_id;
|
int device_id;
|
||||||
fe::DataType_t cudnn_type;
|
fe::DataType_t cudnn_type;
|
||||||
@@ -67,8 +685,6 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
nvtx3::scoped_range r("get_sdpa_forward_graph");
|
|
||||||
|
|
||||||
// Set up new graph
|
// Set up new graph
|
||||||
auto graph = std::make_shared<fe::graph::Graph>();
|
auto graph = std::make_shared<fe::graph::Graph>();
|
||||||
|
|
||||||
@@ -143,8 +759,6 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
|||||||
|
|
||||||
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
||||||
if (cudnnGetVersion() < 90600) {
|
if (cudnnGetVersion() < 90600) {
|
||||||
nvtx3::scoped_range r("get_sdpa_forward_graph::graph_building");
|
|
||||||
|
|
||||||
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
|
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
|
||||||
if (!build_status.is_good()) {
|
if (!build_status.is_good()) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -331,11 +945,6 @@ bool ScaledDotProductAttention::use_fallback(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& cu_device = cu::device(s.device);
|
|
||||||
if (cu_device.compute_capability_major() < 8) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int value_head_dim = v.shape(-1);
|
const int value_head_dim = v.shape(-1);
|
||||||
const int query_head_dim = q.shape(-1);
|
const int query_head_dim = q.shape(-1);
|
||||||
const int query_sequence_length = q.shape(2);
|
const int query_sequence_length = q.shape(2);
|
||||||
@@ -344,11 +953,7 @@ bool ScaledDotProductAttention::use_fallback(
|
|||||||
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
|
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
|
||||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
||||||
|
|
||||||
const bool supported_dtype = q.dtype() == float16 || q.dtype() == bfloat16;
|
return has_arr_mask || !sdpa_supported_head_dim;
|
||||||
|
|
||||||
const bool supported_config = supported_dtype && sdpa_supported_head_dim;
|
|
||||||
|
|
||||||
return has_arr_mask || !supported_config;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ScaledDotProductAttention::eval_gpu(
|
void ScaledDotProductAttention::eval_gpu(
|
||||||
@@ -432,7 +1037,8 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
o.set_data(allocator::malloc(o.nbytes()));
|
o.set_data(allocator::malloc(o.nbytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||||
|
// return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Full attention mode
|
// Full attention mode
|
||||||
|
|||||||
Reference in New Issue
Block a user