mirror of https://github.com/ml-explore/mlx.git
route to fallback for bfloat (#794)
This commit is contained in:
parent 1074674e32
commit afd5274049

mlx/fast.cpp (10 changed lines)
@@ -106,7 +106,6 @@ array rope(
       return std::vector<array>{concatenate(outs, 2, s)};
     }
   };
   // TODO change to condition for using custom prim
   auto stream = to_stream(s);
   if (stream.device == Device::gpu && x.shape(-1) == dims) {
     return array(
@@ -183,6 +182,12 @@ array scaled_dot_product_attention(
   }
 
   auto final_type = result_type({queries, keys, values});
+  if (!is_floating_point(final_type) || is_complex(final_type)) {
+    std::ostringstream msg;
+    msg << "[scaled_dot_product_attention] Received unsupported type "
+        << final_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
 
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
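The guard in this hunk operates on the promoted dtype of all three inputs rather than on any single argument. A small illustrative sketch, assuming the public mlx C++ headers; the shapes are arbitrary and only the promotion behavior matters:

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Mixed-precision inputs promote before the check: float16 queries
      // with float32 keys/values give final_type == float32, which passes.
      auto q = zeros({1, 32, 1, 128}, float16);
      auto k = zeros({1, 32, 16, 128}, float32);
      auto v = zeros({1, 32, 16, 128}, float32);
      auto final_type = result_type({q, k, v}); // float32

      // All-integer inputs promote to an integer dtype, which fails
      // is_floating_point, so the op throws std::invalid_argument
      // instead of silently computing attention.
      auto qi = zeros({1, 32, 1, 128}, int32);
      auto ki = zeros({1, 32, 16, 128}, int32);
      auto vi = zeros({1, 32, 16, 128}, int32);
      auto bad_type = result_type({qi, ki, vi}); // int32
      return 0;
    }

Since bfloat16 is a floating type it passes this guard; routing it away from the fused kernel happens later, in the use-case check below.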
@@ -197,6 +202,7 @@ array scaled_dot_product_attention(
    * * batch size > 1
    * * query sequence length > 1
    * * non-null mask
+   * * dtype is not fp32 or fp16
    */
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
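The comment lists when the generic path runs; the fallback lambda itself builds attention from regular ops. Below is a rough standalone sketch of that math, not the lambda's exact body: it assumes [batch, heads, tokens, head_dim] inputs and ignores the n_q_heads/n_kv_heads grouped-query handling that the real fallback performs:

    #include "mlx/mlx.h"

    using namespace mlx::core;

    // Sketch: unfused scaled-dot-product attention over [B, H, T, D] inputs.
    array attention_fallback(
        const array& q, const array& k, const array& v,
        float scale, StreamOrDevice s = {}) {
      // scores = (q * scale) @ k^T, contracting the head dimension
      auto scores = matmul(
          multiply(q, array(scale, q.dtype()), s),
          transpose(k, {0, 1, 3, 2}, s),
          s);
      // softmax over the keys axis, then weight the values
      auto probs = softmax(scores, std::vector<int>{-1}, s);
      return matmul(probs, v, s);
    }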
@@ -245,7 +251,7 @@ array scaled_dot_product_attention(
   const size_t query_sequence_length = q.shape(2);
   bool implementation_supports_use_case = batch_dim == 1 &&
       query_sequence_length == 1 && !mask.has_value() &&
-      query_head_dim == supported_head_dim;
+      query_head_dim == supported_head_dim && final_type != bfloat16;
 
   if (stream.device == Device::gpu && implementation_supports_use_case) {
     auto out = array(
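This one-line change is the fix named in the commit title: with final_type != bfloat16 added to the conjunction, bfloat16 inputs fail the fast-path test and run the fallback instead of the fused Metal kernel. Restated as a standalone predicate (helper name hypothetical, conditions copied from the diff):

    #include "mlx/mlx.h"

    // Hypothetical helper restating the post-change gate: the fused kernel
    // is used only for single-batch, single-token, unmasked decoding with
    // the supported head dim, and never for bfloat16.
    bool sdpa_uses_fused_kernel(
        const mlx::core::Stream& stream,
        size_t batch_dim,
        size_t query_sequence_length,
        bool has_mask,
        int query_head_dim,
        int supported_head_dim,
        mlx::core::Dtype final_type) {
      return stream.device == mlx::core::Device::gpu && batch_dim == 1 &&
          query_sequence_length == 1 && !has_mask &&
          query_head_dim == supported_head_dim &&
          final_type != mlx::core::bfloat16;
    }

Because the test uses final_type, the promoted dtype of queries, keys, and values, the reroute applies whenever promotion lands on bfloat16.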
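Taken together, after this commit a call like the following should run end to end on GPU by way of the fallback. This is a sketch, assuming the mlx::core::fast entry point with the (queries, keys, values, scale, mask) parameter order visible in this file, and illustrative shapes:

    #include <cmath>
    #include <optional>
    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      // All-bfloat16 Q/K/V: final_type is bfloat16, so the fused kernel is
      // skipped and the unfused fallback computes the result.
      auto q = mx::astype(mx::random::normal({1, 32, 1, 128}), mx::bfloat16);
      auto k = mx::astype(mx::random::normal({1, 32, 16, 128}), mx::bfloat16);
      auto v = mx::astype(mx::random::normal({1, 32, 16, 128}), mx::bfloat16);
      float scale = 1.0f / std::sqrt(128.0f);
      auto out = mx::fast::scaled_dot_product_attention(
          q, k, v, scale, std::nullopt);
      mx::eval({out});
      return 0;
    }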