mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-05 16:51:13 +08:00
Fix compilation error from integral_constant (#2326)
This commit is contained in:
parent
cfb6a244ea
commit
e76e9b87f0
@ -364,7 +364,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
auto kernel = cu::layer_norm_vjp<
|
auto kernel = cu::layer_norm_vjp<
|
||||||
DataType,
|
DataType,
|
||||||
has_w_constant(),
|
has_w_constant.value,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
N_READS>;
|
N_READS>;
|
||||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||||
|
@ -319,10 +319,10 @@ void row_reduce_looped(
|
|||||||
T,
|
T,
|
||||||
U,
|
U,
|
||||||
OP,
|
OP,
|
||||||
reduce_ndim(),
|
reduce_ndim.value,
|
||||||
threads_constant(),
|
threads_constant.value,
|
||||||
N_READS>;
|
N_READS>;
|
||||||
block.x = threads_constant();
|
block.x = threads_constant.value;
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -320,7 +320,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
auto kernel = cu::rms_norm_vjp<
|
auto kernel = cu::rms_norm_vjp<
|
||||||
DataType,
|
DataType,
|
||||||
has_w_constant(),
|
has_w_constant.value,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
N_READS>;
|
N_READS>;
|
||||||
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
kernel<<<n_rows, block_dim(), 0, stream>>>(
|
||||||
|
@ -315,7 +315,8 @@ void RoPE::eval_gpu(
|
|||||||
dispatch_bool(forward_, [&](auto forward) {
|
dispatch_bool(forward_, [&](auto forward) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
if (single && !with_freqs) {
|
if (single && !with_freqs) {
|
||||||
auto kernel = cu::rope_single<DataType, traditional(), forward()>;
|
auto kernel =
|
||||||
|
cu::rope_single<DataType, traditional.value, forward.value>;
|
||||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||||
kernel<<<grid, block, 0, stream>>>(
|
kernel<<<grid, block, 0, stream>>>(
|
||||||
@ -327,8 +328,8 @@ void RoPE::eval_gpu(
|
|||||||
mat_size,
|
mat_size,
|
||||||
dims);
|
dims);
|
||||||
} else if (single) {
|
} else if (single) {
|
||||||
auto kernel =
|
auto kernel = cu::
|
||||||
cu::rope_single_freqs<DataType, traditional(), forward()>;
|
rope_single_freqs<DataType, traditional.value, forward.value>;
|
||||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||||
kernel<<<grid, block, 0, stream>>>(
|
kernel<<<grid, block, 0, stream>>>(
|
||||||
@ -341,7 +342,8 @@ void RoPE::eval_gpu(
|
|||||||
dims,
|
dims,
|
||||||
inputs[2].strides(0));
|
inputs[2].strides(0));
|
||||||
} else if (with_freqs) {
|
} else if (with_freqs) {
|
||||||
auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
|
auto kernel =
|
||||||
|
cu::rope_freqs<DataType, traditional.value, forward.value>;
|
||||||
uint3 dims =
|
uint3 dims =
|
||||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||||
dims.z = (dims.z + 3) / 4;
|
dims.z = (dims.z + 3) / 4;
|
||||||
@ -359,7 +361,7 @@ void RoPE::eval_gpu(
|
|||||||
dims,
|
dims,
|
||||||
inputs[2].strides(0));
|
inputs[2].strides(0));
|
||||||
} else {
|
} else {
|
||||||
auto kernel = cu::rope<DataType, traditional(), forward()>;
|
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
|
||||||
uint3 dims =
|
uint3 dims =
|
||||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||||
dims.z = (dims.z + 3) / 4;
|
dims.z = (dims.z + 3) / 4;
|
||||||
|
Loading…
Reference in New Issue
Block a user