mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
comment + fix
This commit is contained in:
parent
4d95cb24b4
commit
850ad01914
@ -2,8 +2,8 @@
|
||||
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/kernels/ternary_ops.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@ -69,7 +69,7 @@ __global__ void ternary_g(
|
||||
b_strides.data(),
|
||||
c_strides.data(),
|
||||
ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,7 +79,6 @@ template <typename Op>
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
@ -162,20 +161,19 @@ template <typename Op>
|
||||
void ternary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
std::string_view op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
ternary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
ternary_op_gpu_inplace<Op>(inputs, out, s);
|
||||
}
|
||||
|
||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("select::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
ternary_op_gpu<cu::Select>(inputs, out, get_primitive_string(this), s);
|
||||
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user