More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -670,8 +670,7 @@ array scaled_dot_product_attention(
supports_sdpa_full || supports_sdpa_vector;
if (implementation_supports_use_case) {
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array(
std::move(out_shape),
final_type,