comment + fix

This commit is contained in:
Awni Hannun 2025-06-12 16:34:44 -07:00
parent 4d95cb24b4
commit 850ad01914

View File

@ -2,8 +2,8 @@
#include "mlx/backend/common/ternary.h" #include "mlx/backend/common/ternary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/ternary_ops.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@ -69,7 +69,7 @@ __global__ void ternary_g(
b_strides.data(), b_strides.data(),
c_strides.data(), c_strides.data(),
ndim); ndim);
out[index] = Op{}(a[a_idx], b[b_idx]); out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
} }
} }
@ -79,7 +79,6 @@ template <typename Op>
void ternary_op_gpu_inplace( void ternary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
std::string_view op,
const Stream& s) { const Stream& s) {
assert(inputs.size() > 1); assert(inputs.size() > 1);
const auto& a = inputs[0]; const auto& a = inputs[0];
@ -162,20 +161,19 @@ template <typename Op>
void ternary_op_gpu( void ternary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
std::string_view op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto& c = inputs[2]; auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c); auto topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt); set_ternary_op_output_data(a, b, c, out, topt);
ternary_op_gpu_inplace<Op>(inputs, out, op, s); ternary_op_gpu_inplace<Op>(inputs, out, s);
} }
void Select::eval_gpu(const std::vector<array>& inputs, array& out) { void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("select::eval_gpu"); nvtx3::scoped_range r("select::eval_gpu");
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
ternary_op_gpu<cu::Select>(inputs, out, get_primitive_string(this), s); ternary_op_gpu<cu::Select>(inputs, out, s);
} }
} // namespace mlx::core } // namespace mlx::core