comment + fix

This commit is contained in:
Awni Hannun 2025-06-12 16:34:44 -07:00
parent 4d95cb24b4
commit 850ad01914

View File

@ -2,8 +2,8 @@
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/ternary_ops.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@ -69,7 +69,7 @@ __global__ void ternary_g(
b_strides.data(),
c_strides.data(),
ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
}
}
@ -79,7 +79,6 @@ template <typename Op>
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
@ -162,20 +161,19 @@ template <typename Op>
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
ternary_op_gpu_inplace<Op>(inputs, out, op, s);
ternary_op_gpu_inplace<Op>(inputs, out, s);
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("select::eval_gpu");
auto& s = out.primitive().stream();
ternary_op_gpu<cu::Select>(inputs, out, get_primitive_string(this), s);
ternary_op_gpu<cu::Select>(inputs, out, s);
}
} // namespace mlx::core