route to fallback (#828)

Awni Hannun 2024-03-13 19:56:04 -07:00 committed by GitHub
parent 3f8b1668c4
commit 43abc402d8
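In brief: the naive fallback for scaled_dot_product_attention is rewritten to handle grouped-query attention by reshaping the queries and broadcasting the two matmuls instead of tiling K and V with repeat, and the routing condition is reworked (and, for now, forced off) so that all calls are sent to this fallback.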


@@ -162,8 +162,8 @@ array scaled_dot_product_attention(
   }
   // K, V must have matching number of heads (n_kv_heads);
-  size_t n_q_heads = queries.shape(-3);
-  size_t n_kv_heads = keys.shape(-3);
+  auto n_q_heads = queries.shape(-3);
+  auto n_kv_heads = keys.shape(-3);
   if (keys.shape(-3) != values.shape(-3)) {
     std::ostringstream msg;
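Note on the hunk above: shape(-3) indexes the heads axis under the [Batch, Heads, Sequence, Hidden] layout assumed in this file, and it returns a signed int (MLX shapes are std::vector<int>), so auto deduces int and avoids the signed/unsigned mixing that the size_t declarations introduced into the lambda captures and the n_q_heads / n_kv_heads division below. A minimal sketch of the convention, with illustrative shapes that are not from the commit:

    // Layout: [Batch, Heads, Sequence, Hidden]
    // queries: {1, 32, 7, 128}; keys/values: {1, 8, 1024, 128}
    auto n_q_heads = queries.shape(-3); // 32
    auto n_kv_heads = keys.shape(-3);   // 8, i.e. 4 query heads per K/V head (GQA)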
@@ -207,53 +207,43 @@ array scaled_dot_product_attention(
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
-    auto& q_tensor = inputs[0];
-    auto& k_tensor = inputs[1];
-    auto& v_tensor = inputs[2];
-    auto q_scaled = multiply(array(scale, q_tensor.dtype()), q_tensor, s);
-    auto tile_if_needs_repeat =
-        [n_q_heads, n_kv_heads](const array& arr, StreamOrDevice& s) -> array {
-      if (n_q_heads == n_kv_heads)
-        return arr;
-      int n_repeats = n_q_heads / n_kv_heads;
-      constexpr const int heads_axis =
-          1; // heads axis, assumes tensors arranged as [0, 1, 2, 3] ->
-             // [Batch, Heads, Sequence, Hidden]
-      auto ret = repeat(arr, n_repeats, heads_axis, s);
-      return ret;
-    };
-    auto k_tensor_tiled = tile_if_needs_repeat(k_tensor, s);
-    auto v_tensor_tiled = tile_if_needs_repeat(v_tensor, s);
-    // dim check on k, v; repeat if untiled, since naive matmul will have
-    // dim mismatch for GQA (MQA could make use of broadcast)
-    auto k_transposed = transpose(k_tensor_tiled, {0, 1, 3, 2}, s);
-    auto s_tensor = matmul(q_scaled, k_transposed, s);
-    if (needs_mask) {
-      auto mask_tensor = inputs[3];
-      s_tensor = add(s_tensor, mask_tensor, s);
+    auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
+    int n_repeats = n_q_heads / n_kv_heads;
+    int B = q.shape(0);
+    int L = q.shape(2);
+    auto k = inputs[1];
+    auto v = inputs[2];
+    if (n_repeats > 1) {
+      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
+      k = expand_dims(k, 2, s);
+      v = expand_dims(v, 2, s);
     }
-    auto p = astype(
-        softmax(astype(s_tensor, float32, s), std::vector<int>{-1}, s),
+    auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
+    if (needs_mask) {
+      scores = add(scores, inputs[3], s);
+    }
+    scores = astype(
+        softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
         final_type,
         s);
-    auto out_tensor = matmul(p, v_tensor_tiled, s);
-    return std::vector<array>{out_tensor};
+    auto out = matmul(scores, v, s);
+    if (n_repeats > 1) {
+      out = reshape(out, {B, n_q_heads, L, -1}, s);
+    }
+    return std::vector<array>{out};
   };
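The new fallback body replaces the old repeat-based tiling: rather than copying K and V n_repeats times along the heads axis, it folds the query heads into a [B, n_kv_heads, n_repeats, L, head_dim] view and gives K and V a singleton axis at position 2, so both matmuls broadcast over n_repeats without materializing any copies; the softmax is still taken in float32 and cast back to final_type, as before. A minimal standalone sketch of the shape flow, assuming the mlx/mlx.h umbrella header and illustrative sizes (B=1, n_q_heads=32, n_kv_heads=8, L=7, head_dim=128, source length 1024) that are not from the commit:

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      auto q = zeros({1, 32, 7, 128}, float32);
      auto k = zeros({1, 8, 1024, 128}, float32);
      auto v = zeros({1, 8, 1024, 128}, float32);
      q = reshape(q, {1, 8, 4, 7, 128}); // split heads: n_repeats = 32 / 8 = 4
      k = expand_dims(k, 2);             // {1, 8, 1, 1024, 128}
      v = expand_dims(v, 2);             // {1, 8, 1, 1024, 128}
      // k's singleton axis broadcasts against n_repeats = 4 in the matmul:
      auto scores = matmul(q, swapaxes(k, -1, -2)); // {1, 8, 4, 7, 1024}
      auto out = matmul(scores, v);                 // {1, 8, 4, 7, 128}
      out = reshape(out, {1, 32, 7, 128});          // back to [B, n_q_heads, L, D]
      return 0;
    }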
   auto stream = to_stream(s);
   // current implementation use case: batch size 1, query sequence length 1, no
   // mask. Likewise, requires head_dim == 128
   constexpr const int supported_head_dim = 128;
   const size_t query_head_dim = q.shape(-1);
   const size_t query_sequence_length = q.shape(2);
   bool implementation_supports_use_case = batch_dim == 1 &&
       query_sequence_length == 1 && !mask.has_value() &&
-      query_head_dim == supported_head_dim && final_type != bfloat16;
-  if (stream.device == Device::gpu && implementation_supports_use_case) {
+      query_head_dim == supported_head_dim && final_type != bfloat16 &&
+      stream.device == Device::gpu;
+  // TODO, update routing conditions post further tuning
+  implementation_supports_use_case &= false;
+  if (implementation_supports_use_case) {
     auto out = array(
         out_shape,
         final_type,
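Note on the final hunk: the Device::gpu check moves out of the if statement and into implementation_supports_use_case itself, and the following "&= false" then unconditionally clears the flag, so the fused path is never taken and every call falls through to the fallback above; per the TODO and the commit title, this routing is temporary pending further tuning.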