mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] cuDNN forward attention (#2743)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Separate sdpa kernels in another file * Initial support for cuDNN SDPA * Diable a few corner cases * Remove scaled_dot_product_attention.h * Use cuDNN attention for prefilling * cuDNN SDPA requires Ampere and later * Address reviews * Do contiguous copy of inputs
This commit is contained in:
@@ -44,6 +44,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
|
||||
@@ -281,7 +281,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
Dtype dtype = out.dtype();
|
||||
|
||||
// Search cache.
|
||||
ConvCacheKey cache_key{
|
||||
BytesKey<ConvCacheKey> cache_key;
|
||||
cache_key.pod = {
|
||||
encoder.device().cuda_device(),
|
||||
dtype_to_cudnn_type(dtype),
|
||||
vector_key(in.shape()),
|
||||
|
||||
@@ -44,13 +44,13 @@ inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
// There are 2 differences from the const_param util from kernel_utils.cuh:
|
||||
// 1. The rest of array is filled with 0.
|
||||
// 2. This util can be used in .cpp files.
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
|
||||
inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::array<T, NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -135,12 +135,19 @@ class LRUCache {
|
||||
};
|
||||
|
||||
// Turn a POD struct into a container key by doing bytes compare.
|
||||
//
|
||||
// Usage:
|
||||
// BytesKey<MyKey> key;
|
||||
// key.pod = { ... };
|
||||
template <typename T>
|
||||
struct BytesKey {
|
||||
T pod;
|
||||
static_assert(std::is_standard_layout_v<T>, "T is not POD");
|
||||
|
||||
BytesKey(T pod) : pod(std::move(pod)) {}
|
||||
BytesKey() {
|
||||
// Make sure the paddings between members are filled with 0.
|
||||
memset(&pod, 0, sizeof(T));
|
||||
}
|
||||
|
||||
BytesKey(const BytesKey& other) {
|
||||
memcpy(&pod, &other.pod, sizeof(T));
|
||||
|
||||
321
mlx/backend/cuda/scaled_dot_product_attention.cpp
Normal file
321
mlx/backend/cuda/scaled_dot_product_attention.cpp
Normal file
@@ -0,0 +1,321 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace fe = cudnn_frontend;
|
||||
|
||||
namespace {
|
||||
|
||||
#define CHECK_CUDNN_FE_ERROR(cmd) \
|
||||
do { \
|
||||
auto error = cmd; \
|
||||
if (!error.is_good()) { \
|
||||
throw std::runtime_error( \
|
||||
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
std::vector<int64_t> normalized_strides(const array& x) {
|
||||
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
|
||||
if (!x.flags().row_contiguous || x.ndim() < 2) {
|
||||
return strides;
|
||||
}
|
||||
for (int i = x.ndim() - 2; i >= 0; --i) {
|
||||
if (x.shape(i) == 1) {
|
||||
strides[i] = x.shape(i + 1) * strides[i + 1];
|
||||
}
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
void set_tensor_attrs(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x) {
|
||||
tensor->set_uid(uid)
|
||||
.set_dim({x.shape().begin(), x.shape().end()})
|
||||
.set_stride(normalized_strides(x));
|
||||
}
|
||||
|
||||
array prepare_sdpa_input(const array& x, Stream s) {
|
||||
// SDPA kernel's requirements on inputs:
|
||||
// 1. last dim's stride be 1;
|
||||
// 2. pointer be aligned.
|
||||
if (x.strides(-1) != 1 || get_alignment(x) < 16) {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
constexpr int QKV_NDIM = 4;
|
||||
|
||||
struct SDPACacheKey {
|
||||
int device_id;
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
std::array<int, QKV_NDIM> q_shape;
|
||||
std::array<int, QKV_NDIM> k_shape;
|
||||
std::array<int, QKV_NDIM> v_shape;
|
||||
std::array<int64_t, QKV_NDIM> q_strides;
|
||||
std::array<int64_t, QKV_NDIM> k_strides;
|
||||
std::array<int64_t, QKV_NDIM> v_strides;
|
||||
bool do_causal;
|
||||
};
|
||||
|
||||
auto& sdpa_cache() {
|
||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
enum UIDS {
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
SCALE,
|
||||
O,
|
||||
};
|
||||
|
||||
fe::graph::Graph build_sdpa_graph(
|
||||
cudnnHandle_t handle,
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool do_causal,
|
||||
const array& o) {
|
||||
auto dtype = fe::DataType_t::HALF;
|
||||
if (q.dtype() == bfloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
|
||||
fe::graph::Graph graph;
|
||||
graph.set_io_data_type(dtype)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
|
||||
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
|
||||
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
|
||||
set_tensor_attrs(q_, Q, q);
|
||||
set_tensor_attrs(k_, K, k);
|
||||
set_tensor_attrs(v_, V, v);
|
||||
|
||||
auto scale = graph.tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Scale")
|
||||
.set_uid(SCALE)
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_is_pass_by_value(true)
|
||||
.set_data_type(fe::DataType_t::FLOAT));
|
||||
|
||||
auto sdpa_options = fe::graph::SDPA_attributes()
|
||||
.set_name("sdpa_cudnn")
|
||||
.set_attn_scale(scale)
|
||||
.set_causal_mask(do_causal)
|
||||
.set_generate_stats(false);
|
||||
|
||||
auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options);
|
||||
o_->set_output(true);
|
||||
set_tensor_attrs(o_, O, o);
|
||||
|
||||
CHECK_CUDNN_FE_ERROR(graph.validate());
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
|
||||
graph.select_behavior_notes(
|
||||
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
|
||||
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool supports_sdpa_cudnn(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
|
||||
if (!enabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// cuDNN SDPA requires Ampere and later.
|
||||
if (cu::device(s.device).compute_capability_major() < 8) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (has_mask) {
|
||||
// TODO: Support array masks.
|
||||
if (!do_causal) {
|
||||
return false;
|
||||
}
|
||||
// FIXME: Causal mask generates wrong results when L_Q != L_K.
|
||||
if (q.shape(2) != k.shape(2)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Only use cuDNN for prefilling.
|
||||
if (q.shape(2) != k.shape(2)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// D_qk and D_v must be a multiple of 8 with maximum value 128.
|
||||
if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||
|
||||
(v.shape(-1) > 128)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Dtype dtype = q.dtype();
|
||||
return dtype == float16 || dtype == bfloat16;
|
||||
}
|
||||
|
||||
void sdpa_cudnn(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
float scale,
|
||||
array& o,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
// TODO: Handle donation.
|
||||
// TODO: Make O use same memory layout with Q.
|
||||
o.set_data(cu::malloc_async(o.nbytes(), encoder.stream()));
|
||||
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
encoder.set_output_array(o);
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
// Search cache.
|
||||
BytesKey<SDPACacheKey> cache_key;
|
||||
cache_key.pod = {
|
||||
encoder.device().cuda_device(),
|
||||
dtype_to_cudnn_type(q.dtype()),
|
||||
vector_key<QKV_NDIM>(q.shape()),
|
||||
vector_key<QKV_NDIM>(k.shape()),
|
||||
vector_key<QKV_NDIM>(v.shape()),
|
||||
vector_key<QKV_NDIM>(q.strides()),
|
||||
vector_key<QKV_NDIM>(k.strides()),
|
||||
vector_key<QKV_NDIM>(v.strides()),
|
||||
do_causal,
|
||||
};
|
||||
auto it = sdpa_cache().find(cache_key);
|
||||
if (it == sdpa_cache().end()) {
|
||||
it =
|
||||
sdpa_cache()
|
||||
.emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o))
|
||||
.first;
|
||||
}
|
||||
auto& graph = it->second;
|
||||
|
||||
std::unordered_map<int64_t, void*> variant_pack{
|
||||
{Q, const_cast<void*>(gpu_ptr<void>(q))},
|
||||
{K, const_cast<void*>(gpu_ptr<void>(k))},
|
||||
{V, const_cast<void*>(gpu_ptr<void>(v))},
|
||||
{SCALE, &scale},
|
||||
{O, gpu_ptr<void>(o)}};
|
||||
|
||||
int64_t workspace_size = 0;
|
||||
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
|
||||
void* workspace_ptr = nullptr;
|
||||
if (workspace_size > 0) {
|
||||
array workspace(
|
||||
cu::malloc_async(workspace_size, encoder.stream()),
|
||||
{static_cast<int>(workspace_size)},
|
||||
uint8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = gpu_ptr<void>(workspace);
|
||||
}
|
||||
|
||||
CudaGraph cuda_graph(encoder.device());
|
||||
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
|
||||
handle, variant_pack, workspace_ptr, cuda_graph));
|
||||
encoder.add_graph_node(cuda_graph);
|
||||
}
|
||||
|
||||
// Defined in scaled_dot_product_attention.cu file.
|
||||
bool supports_sdpa_vector(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal);
|
||||
void sdpa_vector(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
float scale,
|
||||
array& o,
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks,
|
||||
Stream s);
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
if (detail::in_grad_tracing()) {
|
||||
return true;
|
||||
}
|
||||
if (s.device == Device::cpu) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) &&
|
||||
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
|
||||
|
||||
auto& s = stream();
|
||||
|
||||
array q = prepare_sdpa_input(inputs[0], s);
|
||||
array k = prepare_sdpa_input(inputs[1], s);
|
||||
array v = prepare_sdpa_input(inputs[2], s);
|
||||
bool has_mask = inputs.size() - has_sinks_ > 3;
|
||||
bool has_arr_mask = has_mask && !do_causal_;
|
||||
|
||||
if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) {
|
||||
if (has_sinks_) {
|
||||
sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
|
||||
} else {
|
||||
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
|
||||
}
|
||||
} else {
|
||||
sdpa_cudnn(q, k, v, scale_, out, do_causal_, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -6,10 +6,6 @@
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
@@ -663,23 +659,13 @@ void sdpa_vector_fallback(
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace fast {
|
||||
|
||||
bool ScaledDotProductAttention::use_fallback(
|
||||
bool supports_sdpa_vector(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
if (detail::in_grad_tracing()) {
|
||||
return true;
|
||||
}
|
||||
if (s.device == Device::cpu) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool do_causal) {
|
||||
const int value_head_dim = v.shape(-1);
|
||||
const int query_head_dim = q.shape(-1);
|
||||
const int query_sequence_length = q.shape(2);
|
||||
@@ -691,29 +677,24 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
const bool supported_vector_config =
|
||||
sdpa_supported_head_dim && query_sequence_length < 4;
|
||||
|
||||
const bool supported_config = supported_vector_config;
|
||||
|
||||
return has_arr_mask || !supported_config;
|
||||
return supported_vector_config && !has_arr_mask;
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
|
||||
|
||||
auto& s = stream();
|
||||
void sdpa_vector(
|
||||
const array& q_pre,
|
||||
const array& k_pre,
|
||||
const array& v_pre,
|
||||
float scale,
|
||||
array& o,
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks_pre,
|
||||
Stream s) {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
auto& q_pre = inputs[0];
|
||||
auto& k_pre = inputs[1];
|
||||
auto& v_pre = inputs[2];
|
||||
auto& o = out;
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
copies.reserve(inputs.size());
|
||||
copies.reserve(4);
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
@@ -731,8 +712,8 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
};
|
||||
|
||||
std::optional<array> sinks = std::nullopt;
|
||||
if (has_sinks_) {
|
||||
sinks = copy_unless(is_matrix_contiguous, inputs.back());
|
||||
if (sinks_pre) {
|
||||
sinks = copy_unless(is_matrix_contiguous, sinks_pre.value());
|
||||
}
|
||||
|
||||
// We are in vector mode ie single query
|
||||
@@ -798,8 +779,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(
|
||||
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
|
||||
sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||
}
|
||||
|
||||
// Full attention mode should never reach here
|
||||
@@ -808,6 +788,4 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user