Mirror of https://github.com/ml-explore/mlx.git, synced 2025-07-21 16:51:15 +08:00
route to fallback (#828)
parent 3f8b1668c4
commit 43abc402d8

mlx/fast.cpp (64 changed lines)
@@ -162,8 +162,8 @@ array scaled_dot_product_attention(
   }
 
   // K, V must have matching number of heads (n_kv_heads);
-  size_t n_q_heads = queries.shape(-3);
-  size_t n_kv_heads = keys.shape(-3);
+  auto n_q_heads = queries.shape(-3);
+  auto n_kv_heads = keys.shape(-3);
 
   if (keys.shape(-3) != values.shape(-3)) {
     std::ostringstream msg;
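Note on this hunk: in mlx, array::shape(-3) returns an int, and the rewritten fallback below feeds these head counts into integer arithmetic and a braced std::vector<int> shape, where size_t operands would force narrowing conversions. A minimal standalone sketch of the distinction (illustrative values, not mlx code):

    #include <vector>

    int main() {
      // With auto, the head counts deduce to int, matching shape(-3).
      auto n_q_heads = 8;
      auto n_kv_heads = 2;
      int n_repeats = n_q_heads / n_kv_heads;
      int B = 1, L = 1;
      // Braced init of a std::vector<int> shape compiles cleanly with ints;
      // runtime size_t operands here would be narrowing and ill-formed.
      std::vector<int> shape{B, n_kv_heads, n_repeats, L, -1};
      return shape.size() == 5 ? 0 : 1;
    }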
@@ -207,53 +207,43 @@ array scaled_dot_product_attention(
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
-    auto& q_tensor = inputs[0];
-    auto& k_tensor = inputs[1];
-    auto& v_tensor = inputs[2];
-    auto q_scaled = multiply(array(scale, q_tensor.dtype()), q_tensor, s);
-
-    auto tile_if_needs_repeat =
-        [n_q_heads, n_kv_heads](const array& arr, StreamOrDevice& s) -> array {
-      if (n_q_heads == n_kv_heads)
-        return arr;
-      int n_repeats = n_q_heads / n_kv_heads;
-      constexpr const int heads_axis =
-          1; // heads axis, assumes tensors arranged as [0, 1, 2, 3] ->
-             // [Batch, Heads, Sequence, Hidden]
-      auto ret = repeat(arr, n_repeats, heads_axis, s);
-      return ret;
-    };
-    auto k_tensor_tiled = tile_if_needs_repeat(k_tensor, s);
-    auto v_tensor_tiled = tile_if_needs_repeat(v_tensor, s);
-
-    // dim check on k, v; repeat if untiled, since naive matmul will have
-    // dim mismatch for GQA (MQA could make use of broadcast)
-    auto k_transposed = transpose(k_tensor_tiled, {0, 1, 3, 2}, s);
-    auto s_tensor = matmul(q_scaled, k_transposed, s);
+    auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
+    int n_repeats = n_q_heads / n_kv_heads;
+    int B = q.shape(0);
+    int L = q.shape(2);
+    auto k = inputs[1];
+    auto v = inputs[2];
+    if (n_repeats > 1) {
+      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
+      k = expand_dims(k, 2, s);
+      v = expand_dims(v, 2, s);
+    }
+    auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
     if (needs_mask) {
-      auto mask_tensor = inputs[3];
-      s_tensor = add(s_tensor, mask_tensor, s);
+      scores = add(scores, inputs[3], s);
     }
-    auto p = astype(
-        softmax(astype(s_tensor, float32, s), std::vector<int>{-1}, s),
+    scores = astype(
+        softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
         final_type,
         s);
-    auto out_tensor = matmul(p, v_tensor_tiled, s);
-    return std::vector<array>{out_tensor};
+    auto out = matmul(scores, v, s);
+    if (n_repeats > 1) {
+      out = reshape(out, {B, n_q_heads, L, -1}, s);
+    }
+    return std::vector<array>{out};
   };
 
   auto stream = to_stream(s);
 
-  // current implementation use case: batch size 1, query sequence length 1, no
-  // mask. Likewise, requires head_dim == 128
   constexpr const int supported_head_dim = 128;
   const size_t query_head_dim = q.shape(-1);
   const size_t query_sequence_length = q.shape(2);
   bool implementation_supports_use_case = batch_dim == 1 &&
       query_sequence_length == 1 && !mask.has_value() &&
-      query_head_dim == supported_head_dim && final_type != bfloat16;
-  if (stream.device == Device::gpu && implementation_supports_use_case) {
+      query_head_dim == supported_head_dim && final_type != bfloat16 &&
+      stream.device == Device::gpu;
+  // TODO, update routing conditions post further tuning
+  implementation_supports_use_case &= false;
+  if (implementation_supports_use_case) {
     auto out = array(
         out_shape,
         final_type,
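Note on the fallback rewrite: instead of materializing repeated K/V heads with repeat() along the heads axis, the new code reshapes Q to [B, n_kv_heads, n_repeats, L, D] and inserts a size-1 axis into K and V with expand_dims, so matmul broadcasts across the n_repeats axis and no O(n_repeats) copy of K/V is made. A standalone sketch of the same trick, assuming the mlx C++ API from mlx/mlx.h (zeros, reshape, expand_dims, swapaxes, matmul, softmax); shapes are illustrative, and the scale multiply and mask add from the diff are omitted for brevity:

    #include <iostream>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // GQA setup: 8 query heads share 2 KV heads (4 repeats).
      int B = 1, n_q_heads = 8, n_kv_heads = 2, L = 1, S = 16, D = 128;
      int n_repeats = n_q_heads / n_kv_heads;

      auto q = zeros({B, n_q_heads, L, D});
      auto k = zeros({B, n_kv_heads, S, D});
      auto v = zeros({B, n_kv_heads, S, D});

      // Fold query heads: [B, n_kv_heads, n_repeats, L, D].
      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1});
      // Size-1 axis lets K/V broadcast across n_repeats in matmul.
      k = expand_dims(k, 2);
      v = expand_dims(v, 2);

      auto scores = matmul(q, swapaxes(k, -1, -2)); // [B, Hkv, R, L, S]
      scores = softmax(scores, std::vector<int>{-1});
      auto out = reshape(matmul(scores, v), {B, n_q_heads, L, -1});

      for (auto d : out.shape()) // expect: 1 8 1 128
        std::cout << d << ' ';
      std::cout << '\n';
    }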
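Note on the routing change that gives the commit its title: the device check is folded into implementation_supports_use_case, and the following `&= false` then disables the fused GPU kernel outright, so every call takes the fallback composed of standard ops until the routing conditions are re-tuned. Reduced to a minimal standalone sketch (route_to_fused_kernel is an illustrative name, not an mlx function):

    #include <iostream>

    bool route_to_fused_kernel(bool checks_pass) {
      bool implementation_supports_use_case = checks_pass;
      // Kill switch: mask off the fast path pending further tuning.
      implementation_supports_use_case &= false;
      return implementation_supports_use_case;
    }

    int main() {
      // Prints 0 even when all checks pass: the fallback is always taken.
      std::cout << route_to_fused_kernel(true) << '\n';
    }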