mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 08:11:13 +08:00
Fix compilation error from integral_constant (#2326)
This commit is contained in:
parent
cfb6a244ea
commit
e76e9b87f0
@ -364,7 +364,7 @@ void LayerNormVJP::eval_gpu(
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::layer_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
has_w_constant.value,
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
|
@ -319,10 +319,10 @@ void row_reduce_looped(
|
||||
T,
|
||||
U,
|
||||
OP,
|
||||
reduce_ndim(),
|
||||
threads_constant(),
|
||||
reduce_ndim.value,
|
||||
threads_constant.value,
|
||||
N_READS>;
|
||||
block.x = threads_constant();
|
||||
block.x = threads_constant.value;
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -320,7 +320,7 @@ void RMSNormVJP::eval_gpu(
|
||||
constexpr int N_READS = 4;
|
||||
auto kernel = cu::rms_norm_vjp<
|
||||
DataType,
|
||||
has_w_constant(),
|
||||
has_w_constant.value,
|
||||
block_dim(),
|
||||
N_READS>;
|
||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||
|
@ -315,7 +315,8 @@ void RoPE::eval_gpu(
|
||||
dispatch_bool(forward_, [&](auto forward) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if (single && !with_freqs) {
|
||||
auto kernel = cu::rope_single<DataType, traditional(), forward()>;
|
||||
auto kernel =
|
||||
cu::rope_single<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
@ -327,8 +328,8 @@ void RoPE::eval_gpu(
|
||||
mat_size,
|
||||
dims);
|
||||
} else if (single) {
|
||||
auto kernel =
|
||||
cu::rope_single_freqs<DataType, traditional(), forward()>;
|
||||
auto kernel = cu::
|
||||
rope_single_freqs<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
kernel<<<grid, block, 0, stream>>>(
|
||||
@ -341,7 +342,8 @@ void RoPE::eval_gpu(
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else if (with_freqs) {
|
||||
auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
|
||||
auto kernel =
|
||||
cu::rope_freqs<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
@ -359,7 +361,7 @@ void RoPE::eval_gpu(
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, traditional(), forward()>;
|
||||
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
|
Loading…
Reference in New Issue
Block a user