mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
route to fallback (#828)
This commit is contained in:
parent
3f8b1668c4
commit
43abc402d8
66
mlx/fast.cpp
66
mlx/fast.cpp
@ -162,8 +162,8 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
|
||||
// K, V must have matching number of heads (n_kv_heads);
|
||||
size_t n_q_heads = queries.shape(-3);
|
||||
size_t n_kv_heads = keys.shape(-3);
|
||||
auto n_q_heads = queries.shape(-3);
|
||||
auto n_kv_heads = keys.shape(-3);
|
||||
|
||||
if (keys.shape(-3) != values.shape(-3)) {
|
||||
std::ostringstream msg;
|
||||
@ -207,53 +207,43 @@ array scaled_dot_product_attention(
|
||||
bool needs_mask = mask.has_value();
|
||||
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto& q_tensor = inputs[0];
|
||||
auto& k_tensor = inputs[1];
|
||||
auto& v_tensor = inputs[2];
|
||||
auto q_scaled = multiply(array(scale, q_tensor.dtype()), q_tensor, s);
|
||||
|
||||
auto tile_if_needs_repeat =
|
||||
[n_q_heads, n_kv_heads](const array& arr, StreamOrDevice& s) -> array {
|
||||
if (n_q_heads == n_kv_heads)
|
||||
return arr;
|
||||
int n_repeats = n_q_heads / n_kv_heads;
|
||||
constexpr const int heads_axis =
|
||||
1; // heads axis, assumes tensors arranged as [0, 1, 2, 3] ->
|
||||
// [Batch, Heads, Sequence, Hidden]
|
||||
auto ret = repeat(arr, n_repeats, heads_axis, s);
|
||||
return ret;
|
||||
};
|
||||
auto k_tensor_tiled = tile_if_needs_repeat(k_tensor, s);
|
||||
auto v_tensor_tiled = tile_if_needs_repeat(v_tensor, s);
|
||||
|
||||
// dim check on k, v; repeat if untiled, since naive matmul will have
|
||||
// dim mismatch for GQA (MQA could make use of broadcast)
|
||||
auto k_transposed = transpose(k_tensor_tiled, {0, 1, 3, 2}, s);
|
||||
auto s_tensor = matmul(q_scaled, k_transposed, s);
|
||||
if (needs_mask) {
|
||||
auto mask_tensor = inputs[3];
|
||||
s_tensor = add(s_tensor, mask_tensor, s);
|
||||
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
||||
int n_repeats = n_q_heads / n_kv_heads;
|
||||
int B = q.shape(0);
|
||||
int L = q.shape(2);
|
||||
auto k = inputs[1];
|
||||
auto v = inputs[2];
|
||||
if (n_repeats > 1) {
|
||||
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
|
||||
k = expand_dims(k, 2, s);
|
||||
v = expand_dims(v, 2, s);
|
||||
}
|
||||
auto p = astype(
|
||||
softmax(astype(s_tensor, float32, s), std::vector<int>{-1}, s),
|
||||
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
|
||||
if (needs_mask) {
|
||||
scores = add(scores, inputs[3], s);
|
||||
}
|
||||
scores = astype(
|
||||
softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
|
||||
final_type,
|
||||
s);
|
||||
auto out_tensor = matmul(p, v_tensor_tiled, s);
|
||||
return std::vector<array>{out_tensor};
|
||||
auto out = matmul(scores, v, s);
|
||||
if (n_repeats > 1) {
|
||||
out = reshape(out, {B, n_q_heads, L, -1}, s);
|
||||
}
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
|
||||
auto stream = to_stream(s);
|
||||
|
||||
// current implementation use case: batch size 1, query sequence length 1, no
|
||||
// mask. Likewise, requires head_dim == 128
|
||||
constexpr const int supported_head_dim = 128;
|
||||
const size_t query_head_dim = q.shape(-1);
|
||||
const size_t query_sequence_length = q.shape(2);
|
||||
bool implementation_supports_use_case = batch_dim == 1 &&
|
||||
query_sequence_length == 1 && !mask.has_value() &&
|
||||
query_head_dim == supported_head_dim && final_type != bfloat16;
|
||||
|
||||
if (stream.device == Device::gpu && implementation_supports_use_case) {
|
||||
query_head_dim == supported_head_dim && final_type != bfloat16 &&
|
||||
stream.device == Device::gpu;
|
||||
// TODO, update routing conditions post further tuning
|
||||
implementation_supports_use_case &= false;
|
||||
if (implementation_supports_use_case) {
|
||||
auto out = array(
|
||||
out_shape,
|
||||
final_type,
|
||||
|
Loading…
Reference in New Issue
Block a user