Faster triu, tril, where with scalar (#2644)

This commit is contained in:
Awni Hannun
2025-10-02 12:21:27 -07:00
committed by GitHub
parent e88f2d4a8e
commit b0cc71ae71
7 changed files with 101 additions and 42 deletions

View File

@@ -156,7 +156,25 @@ void ternary_op_gpu_inplace(
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) {
if (topt == TernaryOpType::VectorVectorVector ||
topt == TernaryOpType::ScalarScalarScalar) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
} else {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
@@ -225,23 +243,6 @@ void ternary_op_gpu_inplace(
ndim);
}
});
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
}
});
}