mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compilation error from integral_constant (#2326)
This commit is contained in:
@@ -364,7 +364,7 @@ void LayerNormVJP::eval_gpu(
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
has_w_constant.value,
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
|
||||
Reference in New Issue
Block a user