mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
comment + fix
This commit is contained in:
parent
4d95cb24b4
commit
850ad01914
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/ternary.h"
|
#include "mlx/backend/common/ternary.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/ternary_ops.cuh"
|
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ __global__ void ternary_g(
|
|||||||
b_strides.data(),
|
b_strides.data(),
|
||||||
c_strides.data(),
|
c_strides.data(),
|
||||||
ndim);
|
ndim);
|
||||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,7 +79,6 @@ template <typename Op>
|
|||||||
void ternary_op_gpu_inplace(
|
void ternary_op_gpu_inplace(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
std::string_view op,
|
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
assert(inputs.size() > 1);
|
assert(inputs.size() > 1);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
@ -162,20 +161,19 @@ template <typename Op>
|
|||||||
void ternary_op_gpu(
|
void ternary_op_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
array& out,
|
array& out,
|
||||||
std::string_view op,
|
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
auto& c = inputs[2];
|
auto& c = inputs[2];
|
||||||
auto topt = get_ternary_op_type(a, b, c);
|
auto topt = get_ternary_op_type(a, b, c);
|
||||||
set_ternary_op_output_data(a, b, c, out, topt);
|
set_ternary_op_output_data(a, b, c, out, topt);
|
||||||
ternary_op_gpu_inplace<Op>(inputs, out, op, s);
|
ternary_op_gpu_inplace<Op>(inputs, out, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("select::eval_gpu");
|
nvtx3::scoped_range r("select::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
ternary_op_gpu<cu::Select>(inputs, out, get_primitive_string(this), s);
|
ternary_op_gpu<cu::Select>(inputs, out, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user