Fix compilation error from integral_constant (#2326)

This commit is contained in:
Cheng
2025-07-02 22:04:38 +09:00
committed by GitHub
parent cfb6a244ea
commit e76e9b87f0
4 changed files with 12 additions and 10 deletions

View File

@@ -364,7 +364,7 @@ void LayerNormVJP::eval_gpu(
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm_vjp<
DataType,
has_w_constant(),
has_w_constant.value,
block_dim(),
N_READS>;
kernel<<<n_rows, block_dim(), 0, stream>>>(