route to fallback for bfloat (#794)

This commit is contained in:
Awni Hannun 2024-03-06 15:39:12 -08:00 committed by GitHub
parent 1074674e32
commit afd5274049
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -106,7 +106,6 @@ array rope(
return std::vector<array>{concatenate(outs, 2, s)};
}
};
// TODO change to condition for using custom prim
auto stream = to_stream(s);
if (stream.device == Device::gpu && x.shape(-1) == dims) {
return array(
@ -183,6 +182,12 @@ array scaled_dot_product_attention(
}
auto final_type = result_type({queries, keys, values});
if (!is_floating_point(final_type) || is_complex(final_type)) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received unsupported type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
@ -197,6 +202,7 @@ array scaled_dot_product_attention(
* * batch size > 1
* * query sequence length > 1
* * non-null mask
* * dtype is not fp32 or fp16
*/
bool needs_mask = mask.has_value();
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
@ -245,7 +251,7 @@ array scaled_dot_product_attention(
const size_t query_sequence_length = q.shape(2);
bool implementation_supports_use_case = batch_dim == 1 &&
query_sequence_length == 1 && !mask.has_value() &&
query_head_dim == supported_head_dim;
query_head_dim == supported_head_dim && final_type != bfloat16;
if (stream.device == Device::gpu && implementation_supports_use_case) {
auto out = array(