diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 8debf2267..5251dc7e9 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -5,8 +5,8 @@ template ::n> device const bool* a, device const T* b, device const T* c, - constant uint& size, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; for (int i = 0; i < N && (index + i) < size; ++i) { diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index e81ae1562..0b821151e 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -113,10 +113,10 @@ void ternary_op_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims; if (large) { - compute_encoder.set_bytes(out.data_size(), 2); + compute_encoder.set_bytes(out.data_size(), 4); grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); } else { - compute_encoder.set_bytes(out.data_size(), 2); + compute_encoder.set_bytes(out.data_size(), 4); grid_dims = MTL::Size(nthreads, 1, 1); } compute_encoder.dispatch_threads(grid_dims, group_dims);