This commit is contained in:
Awni Hannun 2025-06-12 20:24:23 -07:00
parent 850ad01914
commit f07eb684a6

View File

@ -1,5 +1,4 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
@ -80,7 +79,6 @@ void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
const auto& c = inputs[2];
@ -94,7 +92,7 @@ void ternary_op_gpu_inplace(
encoder.set_input_array(c);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
using DType = cuda_type_t<CTYPE>;
auto topt = get_ternary_op_type(a, b, c);
@ -110,7 +108,7 @@ void ternary_op_gpu_inplace(
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = &cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(