[CUDA] Fix reductions (#2314)

This commit is contained in:
Angelos Katharopoulos
2025-06-27 12:59:20 -07:00
committed by GitHub
parent 2c11d10f8d
commit 772f471ff2
16 changed files with 862 additions and 419 deletions

View File

@@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
-&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(