mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -670,8 +670,7 @@ array scaled_dot_product_attention(
|
||||
supports_sdpa_full || supports_sdpa_vector;
|
||||
|
||||
if (implementation_supports_use_case) {
|
||||
auto out_shape =
|
||||
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
|
||||
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
final_type,
|
||||
|
||||
Reference in New Issue
Block a user