From f07eb684a67bba8f079ac24fa675b0f81fc66e2c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Jun 2025 20:24:23 -0700 Subject: [PATCH] fix --- mlx/backend/cuda/ternary.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index acfbdad47..bb79d4249 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -1,5 +1,4 @@ // Copyright © 2025 Apple Inc. - #include "mlx/backend/common/ternary.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/ternary_ops.cuh" @@ -80,7 +79,6 @@ void ternary_op_gpu_inplace( const std::vector& inputs, array& out, const Stream& s) { - assert(inputs.size() > 1); const auto& a = inputs[0]; const auto& b = inputs[1]; const auto& c = inputs[2]; @@ -94,7 +92,7 @@ void ternary_op_gpu_inplace( encoder.set_input_array(c); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, { + MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, { using DType = cuda_type_t; auto topt = get_ternary_op_type(a, b, c); @@ -110,7 +108,7 @@ void ternary_op_gpu_inplace( int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = &cu::ternary_g_nd; + auto kernel = cu::ternary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); kernel<<>>(