diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 232de2f31..9a95e08ac 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -106,7 +106,6 @@ array rope(
       return std::vector<array>{concatenate(outs, 2, s)};
     }
   };
-  // TODO change to condition for using custom prim
   auto stream = to_stream(s);
   if (stream.device == Device::gpu && x.shape(-1) == dims) {
     return array(
@@ -183,6 +182,12 @@ array scaled_dot_product_attention(
   }
 
   auto final_type = result_type({queries, keys, values});
+  if (!is_floating_point(final_type) || is_complex(final_type)) {
+    std::ostringstream msg;
+    msg << "[scaled_dot_product_attention] Received unsupported type "
+        << final_type << ".";
+    throw std::invalid_argument(msg.str());
+  }
 
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
@@ -197,6 +202,7 @@ array scaled_dot_product_attention(
    * * batch size > 1
    * * query sequence length > 1
    * * non-null mask
+   * * dtype is not fp32 or fp16
    */
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
@@ -245,7 +251,7 @@ array scaled_dot_product_attention(
   const size_t query_sequence_length = q.shape(2);
   bool implementation_supports_use_case = batch_dim == 1 &&
       query_sequence_length == 1 && !mask.has_value() &&
-      query_head_dim == supported_head_dim;
+      query_head_dim == supported_head_dim && final_type != bfloat16;
 
   if (stream.device == Device::gpu && implementation_supports_use_case) {
     auto out = array(
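
A minimal sketch (not part of the patch) of how the new checks surface to a caller. It assumes the public signature fast::scaled_dot_product_attention(queries, keys, values, scale, mask, stream) from mlx/fast.h with mask defaulted, and assumes 128 matches the supported_head_dim referenced in the hunk above; the shapes and the namespace alias are illustrative only.

// sdpa_dtype_check_sketch.cpp -- illustrative only, assumptions noted above.
#include <cmath>
#include <iostream>
#include <stdexcept>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Shapes are (batch, n_heads, seq_len, head_dim); 128 is assumed to be
  // the supported head dimension for the fused kernel.
  auto q = mx::zeros({1, 8, 1, 128});
  auto k = mx::zeros({1, 8, 32, 128});
  auto v = mx::zeros({1, 8, 32, 128});
  float scale = 1.0f / std::sqrt(128.0f);

  // fp32 inputs with batch 1, query length 1, and no mask are eligible for
  // the fused kernel when running on a GPU stream.
  mx::eval(mx::fast::scaled_dot_product_attention(q, k, v, scale));

  // Integer inputs are now rejected up front by the new dtype guard.
  try {
    mx::fast::scaled_dot_product_attention(
        mx::astype(q, mx::int32),
        mx::astype(k, mx::int32),
        mx::astype(v, mx::int32),
        scale);
  } catch (const std::invalid_argument& e) {
    // "[scaled_dot_product_attention] Received unsupported type int32."
    std::cout << e.what() << std::endl;
  }

  // bfloat16 inputs remain valid, but the added `final_type != bfloat16`
  // condition routes them through the generic fallback rather than the
  // fused kernel.
  mx::eval(mx::fast::scaled_dot_product_attention(
      mx::astype(q, mx::bfloat16),
      mx::astype(k, mx::bfloat16),
      mx::astype(v, mx::bfloat16),
      scale));
  return 0;
}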