mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
no sdpa in grad (#2054)
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
@@ -772,7 +773,7 @@ array scaled_dot_product_attention(
|
||||
mask_shape.back() = keys.shape(-2);
|
||||
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
|
||||
}
|
||||
if (implementation_supports_use_case) {
|
||||
if (!detail::in_grad_tracing() && implementation_supports_use_case) {
|
||||
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
|
||||
Reference in New Issue
Block a user