CUDA backend: indexing ops (#2277)

This commit is contained in:
Cheng
2025-06-13 13:44:19 +09:00
committed by GitHub
parent 2188199ff8
commit c8b4787e4e
11 changed files with 830 additions and 6 deletions

View File

@@ -148,24 +148,32 @@ bool compiler_supports_device_sass(Device& device) {
#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh",
INCLUDE_PREFIX "binary_ops.cuh",
INCLUDE_PREFIX "cast_op.cuh",
INCLUDE_PREFIX "config.h",
INCLUDE_PREFIX "cucomplex_math.cuh",
INCLUDE_PREFIX "fp16_math.cuh",
INCLUDE_PREFIX "indexing.cuh",
INCLUDE_PREFIX "scatter_ops.cuh",
INCLUDE_PREFIX "unary_ops.cuh",
INCLUDE_PREFIX "ternary_ops.cuh",
INCLUDE_PREFIX "utils.cuh",
};
#undef INCLUDE_PREFIX
constexpr const char* g_headers[] = {
jit_source_atomic_ops,
jit_source_binary_ops,
jit_source_cast_op,
jit_source_config,
jit_source_cucomplex_math,
jit_source_fp16_math,
jit_source_indexing,
jit_source_scatter_ops,
jit_source_unary_ops,
jit_source_ternary_ops,
jit_source_utils,
};