Fix compilation error from integral_constant (#2326)

This commit is contained in:
Cheng 2025-07-02 22:04:38 +09:00 committed by GitHub
parent cfb6a244ea
commit e76e9b87f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 10 deletions

View File

@ -364,7 +364,7 @@ void LayerNormVJP::eval_gpu(
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm_vjp<
DataType,
has_w_constant(),
has_w_constant.value,
block_dim(),
N_READS>;
kernel<<<n_rows, block_dim(), 0, stream>>>(

View File

@ -319,10 +319,10 @@ void row_reduce_looped(
T,
U,
OP,
reduce_ndim(),
threads_constant(),
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant();
block.x = threads_constant.value;
});
});

View File

@ -320,7 +320,7 @@ void RMSNormVJP::eval_gpu(
constexpr int N_READS = 4;
auto kernel = cu::rms_norm_vjp<
DataType,
has_w_constant(),
has_w_constant.value,
block_dim(),
N_READS>;
kernel<<<n_rows, block_dim(), 0, stream>>>(

View File

@ -315,7 +315,8 @@ void RoPE::eval_gpu(
dispatch_bool(forward_, [&](auto forward) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (single && !with_freqs) {
auto kernel = cu::rope_single<DataType, traditional(), forward()>;
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
@ -327,8 +328,8 @@ void RoPE::eval_gpu(
mat_size,
dims);
} else if (single) {
auto kernel =
cu::rope_single_freqs<DataType, traditional(), forward()>;
auto kernel = cu::
rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>(
@ -341,7 +342,8 @@ void RoPE::eval_gpu(
dims,
inputs[2].strides(0));
} else if (with_freqs) {
auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
@ -359,7 +361,7 @@ void RoPE::eval_gpu(
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, traditional(), forward()>;
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;