This commit is contained in:
Awni Hannun 2025-06-12 20:24:23 -07:00
parent 850ad01914
commit f07eb684a6

View File

@ -1,5 +1,4 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/common/ternary.h" #include "mlx/backend/common/ternary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/ternary_ops.cuh" #include "mlx/backend/cuda/device/ternary_ops.cuh"
@ -80,7 +79,6 @@ void ternary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const Stream& s) { const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0]; const auto& a = inputs[0];
const auto& b = inputs[1]; const auto& b = inputs[1];
const auto& c = inputs[2]; const auto& c = inputs[2];
@ -94,7 +92,7 @@ void ternary_op_gpu_inplace(
encoder.set_input_array(c); encoder.set_input_array(c);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, { MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
using DType = cuda_type_t<CTYPE>; using DType = cuda_type_t<CTYPE>;
auto topt = get_ternary_op_type(a, b, c); auto topt = get_ternary_op_type(a, b, c);
@ -110,7 +108,7 @@ void ternary_op_gpu_inplace(
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, { MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = &cu::ternary_g_nd<Op, DType, IdxT, NDIM>; auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large); get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(