From 850ad019143edef8f7968454ccdfceeb57ae36e3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Jun 2025 16:34:44 -0700 Subject: [PATCH] comment + fix --- mlx/backend/cuda/ternary.cu | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 0a2c67f76..acfbdad47 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -2,8 +2,8 @@ #include "mlx/backend/common/ternary.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/ternary_ops.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/ternary_ops.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -69,7 +69,7 @@ __global__ void ternary_g( b_strides.data(), c_strides.data(), ndim); - out[index] = Op{}(a[a_idx], b[b_idx]); + out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); } } @@ -79,7 +79,6 @@ template void ternary_op_gpu_inplace( const std::vector& inputs, array& out, - std::string_view op, const Stream& s) { assert(inputs.size() > 1); const auto& a = inputs[0]; @@ -162,20 +161,19 @@ template void ternary_op_gpu( const std::vector& inputs, array& out, - std::string_view op, const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; auto& c = inputs[2]; auto topt = get_ternary_op_type(a, b, c); set_ternary_op_output_data(a, b, c, out, topt); - ternary_op_gpu_inplace(inputs, out, op, s); + ternary_op_gpu_inplace(inputs, out, s); } void Select::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("select::eval_gpu"); auto& s = out.primitive().stream(); - ternary_op_gpu(inputs, out, get_primitive_string(this), s); + ternary_op_gpu(inputs, out, s); } } // namespace mlx::core