mirror of https://github.com/ml-explore/mlx.git
route to fallback for bfloat (#794)
parent: 1074674e32
commit: afd5274049

mlx/fast.cpp (10 changed lines)
@@ -106,7 +106,6 @@ array rope(
       return std::vector<array>{concatenate(outs, 2, s)};
     }
   };
-  // TODO change to condition for using custom prim
   auto stream = to_stream(s);
   if (stream.device == Device::gpu && x.shape(-1) == dims) {
     return array(
@@ -183,6 +182,12 @@ array scaled_dot_product_attention(
   }
 
   auto final_type = result_type({queries, keys, values});
+  if (!is_floating_point(final_type) || is_complex(final_type)) {
+    std::ostringstream msg;
+    msg << "[scaled_dot_product_attention] Received unsupported type "
+        << final_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
 
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
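The guard added above follows a standard validation pattern: reject non-floating-point or complex dtypes up front with a descriptive exception instead of failing later inside a kernel. Below is a minimal, self-contained sketch of that pattern; the Dtype enum and the predicate helpers are hypothetical stand-ins for illustration, not MLX's actual API.

#include <iostream>
#include <sstream>
#include <stdexcept>

// Hypothetical dtype tags standing in for MLX's Dtype; illustration only.
enum class Dtype { float32, float16, bfloat16, int32, complex64 };

bool is_floating_point(Dtype t) {
  return t == Dtype::float32 || t == Dtype::float16 || t == Dtype::bfloat16;
}

bool is_complex(Dtype t) {
  return t == Dtype::complex64;
}

// Mirrors the shape of the added check: bail out early with a clear message.
void check_sdpa_dtype(Dtype final_type) {
  if (!is_floating_point(final_type) || is_complex(final_type)) {
    std::ostringstream msg;
    msg << "[scaled_dot_product_attention] Received unsupported type "
        << static_cast<int>(final_type) << ".";
    throw std::invalid_argument(msg.str());
  }
}

int main() {
  try {
    check_sdpa_dtype(Dtype::int32); // integer inputs are rejected up front
  } catch (const std::invalid_argument& e) {
    std::cout << e.what() << std::endl;
  }
}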
@@ -197,6 +202,7 @@ array scaled_dot_product_attention(
    * * batch size > 1
    * * query sequence length > 1
    * * non-null mask
+   * * dtype is not fp32 or fp16
    */
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
@@ -245,7 +251,7 @@ array scaled_dot_product_attention(
   const size_t query_sequence_length = q.shape(2);
   bool implementation_supports_use_case = batch_dim == 1 &&
       query_sequence_length == 1 && !mask.has_value() &&
-      query_head_dim == supported_head_dim;
+      query_head_dim == supported_head_dim && final_type != bfloat16;
 
   if (stream.device == Device::gpu && implementation_supports_use_case) {
     auto out = array(
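With this hunk, the use-case gate rather than the dispatch site decides where bfloat16 goes: the fast GPU kernel is chosen only when every condition holds, so bfloat16 inputs now fall through to the generic fallback built from MLX primitives. A minimal sketch of that gating logic, using hypothetical stand-in values and types (the real check lives in mlx/fast.cpp and uses the library's own dtype objects):

#include <iostream>

// Hypothetical stand-ins for the quantities checked in fast.cpp;
// the names mirror the diff, but this is an illustration, not MLX code.
enum class Dtype { float32, float16, bfloat16 };

bool implementation_supports_use_case(
    int batch_dim,
    int query_sequence_length,
    bool has_mask,
    int query_head_dim,
    int supported_head_dim,
    Dtype final_type) {
  // After this commit, bfloat16 fails the gate and takes the fallback path.
  return batch_dim == 1 && query_sequence_length == 1 && !has_mask &&
      query_head_dim == supported_head_dim && final_type != Dtype::bfloat16;
}

int main() {
  // Single-token decode in fp16 with a matching head dim: fast kernel eligible.
  std::cout << implementation_supports_use_case(1, 1, false, 128, 128, Dtype::float16)
            << "\n"; // prints 1
  // Same shapes in bfloat16: routed to the generic fallback instead.
  std::cout << implementation_supports_use_case(1, 1, false, 128, 128, Dtype::bfloat16)
            << "\n"; // prints 0
}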