mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix compilation error from integral_constant (#2326)
This commit is contained in:
@@ -320,7 +320,7 @@ void RMSNormVJP::eval_gpu(
|
||||
constexpr int N_READS = 4;
|
||||
auto kernel = cu::rms_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
has_w_constant.value,
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
|
||||
Reference in New Issue
Block a user