Do not use the fused SDPA primitive while tracing gradients (#2054)

This commit is contained in:
Awni Hannun
2025-04-08 19:13:54 -07:00
committed by GitHub
parent 00794c42bc
commit e5d35aa187
3 changed files with 22 additions and 8 deletions

View File

@@ -9,6 +9,7 @@
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
#include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
namespace mlx::core::fast {
@@ -772,7 +773,7 @@ array scaled_dot_product_attention(
mask_shape.back() = keys.shape(-2);
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
}
-if (implementation_supports_use_case) {
+if (!detail::in_grad_tracing() && implementation_supports_use_case) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array(
std::move(out_shape),