From b2dd60c1ddbd14e44ee135e8a1929f2f4764dfdb Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 30 May 2025 07:47:32 +0000 Subject: [PATCH 1/3] CUDA backend: compile --- mlx/backend/cuda/CMakeLists.txt | 23 ++ mlx/backend/cuda/bin2h.cmake | 150 +++++++++++++ mlx/backend/cuda/compiled.cpp | 228 +++++++++++++++++++ mlx/backend/cuda/jit_module.cpp | 340 +++++++++++++++++++++++++++++ mlx/backend/cuda/jit_module.h | 113 ++++++++++ mlx/backend/cuda/kernels/config.h | 12 + mlx/backend/cuda/kernels/utils.cuh | 8 +- mlx/backend/cuda/primitives.cu | 1 - mlx/backend/cuda/utils.cpp | 17 ++ mlx/backend/cuda/utils.h | 5 + 10 files changed, 890 insertions(+), 7 deletions(-) create mode 100644 mlx/backend/cuda/bin2h.cmake create mode 100644 mlx/backend/cuda/compiled.cpp create mode 100644 mlx/backend/cuda/jit_module.cpp create mode 100644 mlx/backend/cuda/jit_module.h create mode 100644 mlx/backend/cuda/kernels/config.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index d5041b2ae..89245da53 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu @@ -18,6 +19,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu @@ -37,6 +39,24 @@ target_sources( target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) +# Embed kernel sources in binary for JIT compilation. +file( + GLOB MLX_JIT_SOURCES + RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.cuh") +string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) +add_custom_command( + OUTPUT gen/cuda_jit_sources.h + COMMAND + ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR} + -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P + "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake" + DEPENDS bin2h.cmake ${MLX_JIT_SOURCES}) +add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h) +add_dependencies(mlx cuda_jit_sources) +target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen") + # Enable defining device lambda functions. target_compile_options(mlx PRIVATE "$<$:--extended-lambda>") @@ -83,6 +103,9 @@ target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) # Use cublasLt. target_link_libraries(mlx PRIVATE CUDA::cublasLt) +# Use NVRTC and driver APIs. +target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver) + # Suppress nvcc warnings on MLX headers. 
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe --diag_suppress=997>)
diff --git a/mlx/backend/cuda/bin2h.cmake b/mlx/backend/cuda/bin2h.cmake
new file mode 100644
index 000000000..b791d3d1a
--- /dev/null
+++ b/mlx/backend/cuda/bin2h.cmake
@@ -0,0 +1,150 @@
+# Based on: https://github.com/sivachandran/cmake-bin2h
+#
+# Copyright 2020 Sivachandran Paramasivam
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+include(CMakeParseArguments)
+
+# Function to wrap a given string into multiple lines at the given column
+# position.
+#
+# Parameters:
+#
+# * VARIABLE - The name of the CMake variable holding the string.
+# * AT_COLUMN - The column position at which the string will be wrapped.
+function(WRAP_STRING)
+  set(oneValueArgs VARIABLE AT_COLUMN)
+  cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
+
+  string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
+  math(EXPR offset "0")
+
+  while(stringLength GREATER 0)
+    if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
+      math(EXPR length "${WRAP_STRING_AT_COLUMN}")
+    else()
+      math(EXPR length "${stringLength}")
+    endif()
+
+    string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
+    set(lines "${lines}\n  ${line}")
+
+    math(EXPR stringLength "${stringLength} - ${length}")
+    math(EXPR offset "${offset} + ${length}")
+  endwhile()
+
+  set(${WRAP_STRING_VARIABLE}
+      "${lines}"
+      PARENT_SCOPE)
+endfunction()
+
+# Function to embed the contents of files as byte arrays in a C/C++ header
+# file (.h). The header file will contain one byte array per source file,
+# named "${VARIABLE_NAME}_${FILE_NAME}".
+#
+# Parameters:
+#
+# * SOURCE_FILES - The paths of source files whose contents will be embedded
+#   in the header file.
+# * VARIABLE_NAME - The prefix of the variable name for each byte array; the
+#   source file's name, sanitized to a valid C identifier, is appended to it.
+# * HEADER_FILE - The path of the header file.
+# * APPEND - If specified, appends to the header file instead of overwriting
+#   it.
+# * HEADER_NAMESPACE - The namespace in which the arrays should be located.
+# * NULL_TERMINATE - If specified, a null byte (zero) will be appended to each
+#   byte array.
+#
+# Usage:
+#
+#   bin2h(SOURCE_FILES "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO")
+function(BIN2H)
+  set(options APPEND NULL_TERMINATE)
+  set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
+  set(multiValueArgs SOURCE_FILES)
+  cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN})
+
+  set(arrayDefinition "")
+  foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
+    # get filename without extension
+    get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
+    # convert the filename to a valid C identifier
+    string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
+
+    # read source file contents as a hex string
+    file(READ ${SOURCE_FILE} hexString HEX)
+
+    # append null
+    if(BIN2H_NULL_TERMINATE)
+      string(APPEND hexString "00")
+    endif()
+
+    # wrap the hex string into multiple lines
+    wrap_string(VARIABLE hexString AT_COLUMN 24)
+
+    # replace the © character (UTF-8 bytes c2 a9) with two spaces (hex 20 20)
+    # so the embedded source stays plain ASCII
+    string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
+
+    string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
+                 ${arrayValues})
+
+    # make a full variable name for the array
+    set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
+
+    # declare the byte array
+    string(APPEND arrayDefinition
+           "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
+  endforeach()
+
+  # add namespace wrapper if defined
+  if(DEFINED BIN2H_HEADER_NAMESPACE)
+    set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
+    set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
+    set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
+  endif()
+
+  set(arrayIncludes "#pragma once")
+  string(PREPEND declarations "${arrayIncludes}\n\n")
+
+  if(BIN2H_APPEND)
+    file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
+  else()
+    file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
+  endif()
+endfunction()
+
+# ----------------------------- CLI args -----------------------------
+
+string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
+foreach(source ${MLX_JIT_SOURCES_LIST})
+  list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
+endforeach()
+
+bin2h(
+  SOURCE_FILES
+  ${MLX_JIT_SOURCES_ABS}
+  NULL_TERMINATE
+  VARIABLE_NAME
+  "jit_source"
+  HEADER_NAMESPACE
+  "mlx::core"
+  HEADER_FILE
+  "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp
new file mode 100644
index 000000000..de004b482
--- /dev/null
+++ b/mlx/backend/cuda/compiled.cpp
@@ -0,0 +1,228 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/compiled.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/jit_module.h"
+#include "mlx/graph_utils.h"
+#include "mlx/primitives.h"
+
+#include <fmt/format.h>
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+namespace cu {
+
+struct FusedKernelBuilder {
+  std::string os;
+  const std::string& kernel_name;
+  const std::vector<array>& inputs;
+  const std::vector<array>& outputs;
+  const std::vector<array>& tape;
+  const std::function<bool(size_t)>& is_constant;
+
+  void build(const char* name, bool contiguous) {
+    NodeNamer namer;
+
+    // Function parameters.
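+    // For illustration only (a sketch; the names come from NodeNamer and are
+    // hypothetical): with one non-scalar float input "a" and one float output
+    // "b", the strided variant's parameter list comes out roughly as
+    //   const float* a,
+    //   const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
+    //   float* b,
+    //   const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
+    //   IdxT size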
+    std::vector<std::string> params;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      if (is_constant(i)) {
+        continue;
+      }
+      const auto& x = inputs[i];
+      const std::string& xname = namer.get_name(x);
+      params.push_back(
+          fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
+      if (!is_scalar(x) && !contiguous) {
+        params.push_back(fmt::format(
+            "const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
+            xname));
+      }
+    }
+    for (const auto& x : outputs) {
+      params.push_back(fmt::format(
+          "{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
+    }
+    if (!contiguous) {
+      params.push_back(
+          "const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
+    }
+    params.push_back("IdxT size");
+
+    // Build function signature.
+    if (contiguous) {
+      os += "template <typename IdxT = uint32_t>\n";
+    } else {
+      os += "template <int NDIM, typename IdxT = uint32_t>\n";
+    }
+    os += fmt::format("__global__ void {}(\n", kernel_name + name);
+    for (size_t i = 0; i < params.size(); ++i) {
+      os += "    ";
+      os += params[i];
+      if (i != params.size() - 1) {
+        os += ",\n";
+      }
+    }
+    os += ") {\n";
+
+    // Index.
+    os +=
+        "  IdxT index = cg::this_grid().thread_rank();\n"
+        "  if (index >= size) {\n"
+        "    return;\n"
+        "  }\n";
+
+    // Read inputs.
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      const auto& x = inputs[i];
+      const std::string& xname = namer.get_name(x);
+      std::string type = dtype_to_cuda_type(x.dtype());
+      std::string value;
+      if (is_constant(i)) {
+        std::ostringstream ss;
+        print_constant(ss, x);
+        value = fmt::format("static_cast<{}>({})", type, ss.str());
+      } else if (is_scalar(x)) {
+        value = fmt::format("{}[0]", xname);
+      } else if (contiguous) {
+        value = fmt::format("{}[index]", xname);
+      } else {
+        std::string index = fmt::format(
+            "elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
+            xname);
+        value = fmt::format("{}[{}]", xname, index);
+      }
+      os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
+    }
+
+    // Write tape.
+    for (const auto& x : tape) {
+      const std::string& xname = namer.get_name(x);
+      std::string type = dtype_to_cuda_type(x.dtype());
+      std::string value;
+      if (is_static_cast(x.primitive())) {
+        value = fmt::format(
+            "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
+      } else {
+        std::ostringstream ss;
+        x.primitive().print(ss);
+        value = ss.str();
+        value += "{}(";
+        for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
+          value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
+        }
+        value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
+      }
+      os += fmt::format("  {} tmp_{} = {};\n", type, xname, value);
+    }
+
+    // Write output.
+    for (const auto& x : outputs) {
+      os += fmt::format("  {0}[index] = tmp_{0};\n", namer.get_name(x));
+    }
+
+    os += "}\n";
+  }
+};
+
+} // namespace cu
+
+constexpr const char* g_jit_includes = R"(
+#include "mlx/backend/cuda/kernels/binary_ops.cuh"
+#include "mlx/backend/cuda/kernels/unary_ops.cuh"
+#include "mlx/backend/cuda/kernels/utils.cuh"
+
+#include <cooperative_groups.h>
+
+)";
+
+void Compiled::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("Compiled::eval_gpu");
+  auto& s = stream();
+
+  cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
+    // Build source code.
+    cu::FusedKernelBuilder builder{
+        g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
+    builder.os +=
+        "namespace mlx::core::cu {\n\n"
+        "namespace cg = cooperative_groups;\n\n";
+    builder.build("_contiguous", true);
+    builder.os += "\n";
+    builder.build("_strided", false);
+    builder.os += "\n} // namespace mlx::core::cu\n";
+    // Build kernel names.
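+    // Every template instantiation that may later be launched has to be
+    // registered with NVRTC by name. For a library named "k1" (an
+    // illustrative name) this yields "mlx::core::cu::k1_contiguous<uint32_t>",
+    // "mlx::core::cu::k1_strided<2, int64_t>", and so on.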
+    std::vector<std::string> kernel_names = {
+        fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
+        fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
+    };
+    for (int i = 1; i <= MAX_NDIM; ++i) {
+      kernel_names.push_back(fmt::format(
+          "mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
+      kernel_names.push_back(
+          fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
+    }
+    return std::make_pair(std::move(builder.os), std::move(kernel_names));
+  });
+
+  // Collapse contiguous dims to route to a faster kernel if possible. Also
+  // handle all broadcasting.
+  auto [contiguous, shape, strides_vec] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Whether to use large index.
+  bool large = compiled_use_large_index(inputs, outputs, contiguous);
+
+  // Put inputs.
+  int strides_index = 1;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (is_constant_(i)) {
+      continue;
+    }
+    const auto& x = inputs[i];
+    mod.append_arg(x);
+    if (!contiguous && !is_scalar(x)) {
+      mod.append_arg(strides_vec[strides_index++]);
+    }
+  }
+
+  // Put outputs.
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
+  for (auto& x : outputs) {
+    mod.append_arg(x);
+  }
+
+  // Put shape and size.
+  if (!contiguous) {
+    mod.append_arg(shape);
+  }
+  if (large) {
+    mod.append_arg<int64_t>(outputs[0].data_size());
+  } else {
+    mod.append_arg<uint32_t>(outputs[0].data_size());
+  }
+
+  // Launch kernel.
+  const char* index_type = large ? "int64_t" : "uint32_t";
+  std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
+  if (contiguous) {
+    kernel_name += fmt::format("_contiguous<{}>", index_type);
+  } else {
+    kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
+  }
+  auto& encoder = cu::get_command_encoder(s);
+  for (const auto& in : inputs) {
+    encoder.set_input_array(in);
+  }
+  for (const auto& out : outputs) {
+    encoder.set_output_array(out);
+  }
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    mod.launch_kernel(stream, kernel_name, outputs[0], large);
+  });
}

+} // namespace mlx::core
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
new file mode 100644
index 000000000..3c00dd7f0
--- /dev/null
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -0,0 +1,340 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/jit_module.h"
+#include "mlx/backend/cuda/device.h"
+
+#include "cuda_jit_sources.h"
+
+#include <cstdlib>
+#include <filesystem>
+#include <fstream>
+#include <vector>
+
+#include <fmt/format.h>
+#include <nvrtc.h>
+
+namespace mlx::core::cu {
+
+namespace {
+
+#define CHECK_NVRTC_ERROR(cmd) check_nvrtc_error(#cmd, (cmd))
+
+void check_nvrtc_error(const char* name, nvrtcResult err) {
+  if (err != NVRTC_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}", name, nvrtcGetErrorString(err)));
+  }
+}
+
+#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
+
+void check_cu_error(const char* name, CUresult err) {
+  if (err != CUDA_SUCCESS) {
+    const char* err_str = "Unknown error";
+    cuGetErrorString(err, &err_str);
+    throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
+  }
+}
+
+// Return the location of the CUDA toolkit.
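+// Search order: $CUDA_HOME, then $CUDA_PATH, then /usr/local/cuda on Linux;
+// the result is only used to pass --include-path to NVRTC below.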
+const char* cuda_home() { + const char* home = std::getenv("CUDA_HOME"); + if (home) { + return home; + } + home = std::getenv("CUDA_PATH"); + if (home) { + return home; + } +#if defined(__linux__) + home = "/usr/local/cuda"; + if (std::filesystem::exists(home)) { + return home; + } +#endif + throw std::runtime_error( + "Environment variable CUDA_HOME or CUDA_PATH is not set."); +} + +// Get the cache directory for storing compiled results. +bool get_ptx_cache_dir(std::filesystem::path* result) { + auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx"; + if (!std::filesystem::is_directory(path)) { + std::error_code error; + if (!std::filesystem::create_directories(path, error)) { + return false; + } + } + *result = path; + return true; +} + +// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. +bool read_cached_ptx( + const std::filesystem::path& cache_dir, + const std::string& module_name, + std::vector* ptx, + std::vector>* ptx_kernels) { + auto ptx_path = cache_dir / (module_name + ".ptx"); + std::error_code error; + auto ptx_size = std::filesystem::file_size(ptx_path, error); + if (error) { + return false; + } + std::ifstream ptx_file(ptx_path, std::ios::binary); + if (!ptx_file.good()) { + return false; + } + ptx->resize(ptx_size); + ptx_file.read(ptx->data(), ptx_size); + + std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + std::string line; + while (std::getline(txt_file, line)) { + auto tab = line.find('\t'); + if (tab != std::string::npos) { + ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1)); + } + } + return true; +} + +// Write the |ptx| and |ptx_kernels| to |cache_dir| with |name|. +void write_cached_ptx( + const std::filesystem::path& cache_dir, + const std::string& module_name, + const std::vector& ptx, + const std::vector>& ptx_kernels) { + std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); + if (!ptx.empty()) { + ptx_file.write(&ptx.front(), ptx.size()); + } + std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); + for (const auto& [name, mangled] : ptx_kernels) { + txt_file << name << "\t" << mangled << std::endl; + } +} + +// Return if |device|'s version is not newer than |major|.|minor| version. +inline bool version_lower_equal(Device& device, int major, int minor) { + if (device.compute_capability_major() < major) { + return true; + } else if (device.compute_capability_major() == major) { + return device.compute_capability_minor() <= minor; + } else { + return false; + } +} + +// Return whether NVRTC supports compiling to |device|'s SASS code. 
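+// PTX is the portable IR that the driver JIT-compiles at module-load time;
+// SASS is the final device-specific binary. When supported, compiling
+// straight to SASS (e.g. --gpu-architecture=sm_80 rather than compute_80)
+// skips the driver's PTX JIT step.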
+bool compiler_supports_device_sass(Device& device) { + int nvrtc_major, nvrtc_minor; + CHECK_NVRTC_ERROR(nvrtcVersion(&nvrtc_major, &nvrtc_minor)); + if (nvrtc_major < 9) { + return false; + } else if (nvrtc_major == 9) { + return version_lower_equal(device, 7, 2); + } else if (nvrtc_major == 10) { + return version_lower_equal(device, 7, 5); + } else if (nvrtc_major == 11 && nvrtc_minor == 0) { + return version_lower_equal(device, 8, 0); + } else if (nvrtc_major == 11 && nvrtc_minor < 8) { + return version_lower_equal(device, 8, 6); + } else { + return true; + } +} + +#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/" + +constexpr const char* g_include_names[] = { + INCLUDE_PREFIX "binary_ops.cuh", + INCLUDE_PREFIX "cast_op.cuh", + INCLUDE_PREFIX "config.h", + INCLUDE_PREFIX "cucomplex_math.cuh", + INCLUDE_PREFIX "fp16_math.cuh", + INCLUDE_PREFIX "unary_ops.cuh", + INCLUDE_PREFIX "utils.cuh", +}; + +#undef INCLUDE_PREFIX + +constexpr const char* g_headers[] = { + jit_source_binary_ops, + jit_source_cast_op, + jit_source_config, + jit_source_cucomplex_math, + jit_source_fp16_math, + jit_source_unary_ops, + jit_source_utils, +}; + +} // namespace + +JitModule::JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder) { + // Check cache. + std::filesystem::path cache_dir; + std::vector ptx; + std::vector> ptx_kernels; + if (!get_ptx_cache_dir(&cache_dir) || + !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) { + // Create program. + auto [source_code, kernel_names] = builder(); + nvrtcProgram prog; + CHECK_NVRTC_ERROR(nvrtcCreateProgram( + &prog, + source_code.c_str(), + (module_name + ".cu").c_str(), + std::size(g_headers), + g_headers, + g_include_names)); + std::unique_ptr prog_freer( + &prog, + [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); + for (const auto& name : kernel_names) { + CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); + } + + // Compile program. + bool use_sass = compiler_supports_device_sass(device); + std::string compute = fmt::format( + "--gpu-architecture={}_{}{}", + use_sass ? "sm" : "compute", + device.compute_capability_major(), + device.compute_capability_minor()); + std::string include = fmt::format("--include-path={}/include", cuda_home()); + const char* args[] = {compute.c_str(), include.c_str()}; + nvrtcResult compile_result = + nvrtcCompileProgram(prog, std::size(args), args); + if (compile_result != NVRTC_SUCCESS) { + size_t log_size; + CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); + std::vector log(log_size + 1, 0); + CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data())); + throw std::runtime_error( + fmt::format("Failed to compile kernel: {}.", log.data())); + } + + // Get mangled names of kernel names. + for (const auto& name : kernel_names) { + const char* mangled; + CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled)); + ptx_kernels.emplace_back(name, mangled); + } + + // Get ptx data. + size_t ptx_size; + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size)); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size)); + } + ptx.resize(ptx_size, 0); + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data())); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); + } + write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels); + } + + // Load module. 
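+  // cuModuleLoadDataEx accepts either PTX or CUBIN; the options below route
+  // the driver's JIT error log into a local buffer so failures are readable.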
+ char jit_log[4089] = {}; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; + void* values[] = {jit_log, reinterpret_cast(std::size(jit_log) - 1)}; + CUresult jit_result = cuModuleLoadDataEx( + &module_, ptx.data(), std::size(options), options, values); + if (jit_result != CUDA_SUCCESS) { + throw std::runtime_error(fmt::format( + "Failed to load compiled {} kernel: {}.", module_name, jit_log)); + } + + // Load kernels. + for (const auto& [name, mangled] : ptx_kernels) { + CUfunction kernel; + CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); + kernels_[name] = kernel; + } +} + +JitModule::~JitModule() { + CHECK_CU_ERROR(cuModuleUnload(module_)); +} + +void JitModule::launch_kernel( + CUstream stream, + const std::string& kernel_name, + const array& arr, + bool large, + int work_per_thread) { + CUfunction kernel = get_kernel(kernel_name); + size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); + int _, block_dim; + CHECK_CU_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + if (block_dim > nthreads) { + block_dim = nthreads; + } + Dims num_blocks{1, 1, 1}; + if (large) { + num_blocks = + get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread); + std::get<0>(num_blocks) = + (std::get<0>(num_blocks) + block_dim - 1) / block_dim; + } else { + std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim; + } + launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1}); +} + +void JitModule::launch_kernel( + CUstream stream, + CUfunction kernel, + Dims num_blocks, + Dims block_dims) { + CHECK_CU_ERROR(cuLaunchKernel( + kernel, + std::get<0>(num_blocks), + std::get<1>(num_blocks), + std::get<2>(num_blocks), + std::get<0>(block_dims), + std::get<1>(block_dims), + std::get<2>(block_dims), + 0, + stream, + args_.data(), + nullptr)); + args_.clear(); + storage_.clear(); +} + +CUfunction JitModule::get_kernel(const std::string& kernel_name) { + auto it = kernels_.find(kernel_name); + if (it == kernels_.end()) { + throw std::runtime_error( + fmt::format("There is no kernel named {}.", kernel_name)); + } + return it->second; +} + +void JitModule::append_ptr_arg(const void* v) { + args_.push_back(const_cast(v)); +} + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder) { + static std::unordered_map map; + auto it = map.find(name); + if (it == map.end()) { + it = map.try_emplace(name, cu::device(device), name, builder).first; + } + return it->second; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h new file mode 100644 index 000000000..fcaa1fb3e --- /dev/null +++ b/mlx/backend/cuda/jit_module.h @@ -0,0 +1,113 @@ +// Copyright © 2025 Apple Inc. 
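+//
+// A minimal usage sketch (names below are illustrative, not part of the API):
+//
+//   auto& mod = cu::get_jit_module(device, "module_name", builder);
+//   mod.append_arg(input_array);     // device pointer
+//   mod.append_arg<uint32_t>(size);  // scalar, kept alive in storage_
+//   mod.launch_kernel(stream, "mlx::core::cu::kernel<uint32_t>", out, false);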
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/kernels/config.h"
+
+#include <deque>
+#include <functional>
+#include <unordered_map>
+#include <variant>
+
+#include <cuda.h>
+#include <fmt/format.h>
+
+namespace mlx::core::cu {
+
+class Device;
+
+using KernelBuilderResult = std::pair<
+    /* source code */ std::string,
+    /* kernel names */ std::vector<std::string>>;
+using KernelBuilder = std::function<KernelBuilderResult()>;
+
+class JitModule {
+ public:
+  JitModule(
+      Device& device,
+      const std::string& module_name,
+      const KernelBuilder& builder);
+  ~JitModule();
+
+  JitModule(const JitModule&) = delete;
+  JitModule& operator=(const JitModule&) = delete;
+
+  void append_arg(const array& a) {
+    append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
+  }
+
+  template <typename T>
+  void append_arg(T val) {
+    storage_.emplace_back(val);
+    append_ptr_arg(&storage_.back());
+  }
+
+  template <typename T>
+  void append_arg(std::vector<T> vec) {
+    if (vec.empty()) {
+      // A nullptr can not be passed as an arg, so pass something non-null.
+      append_arg(std::monostate{});
+    } else {
+      append_ptr_arg(vec.data());
+      storage_.emplace_back(std::move(vec));
+    }
+  }
+
+  // Make sure the arg is copied to an array with size of NDIM.
+  template <size_t NDIM = MAX_NDIM, typename T>
+  void append_ndim_arg(const std::vector<T>& vec) {
+    if (vec.size() > NDIM) {
+      throw std::runtime_error(
+          fmt::format("ndim can not be larger than {}.", NDIM));
+    }
+    std::vector<T> copied(NDIM);
+    std::copy(vec.begin(), vec.end(), copied.data());
+    append_arg(std::move(copied));
+  }
+
+  // Launch the kernel named |kernel_name|, with each thread working on
+  // |work_per_thread| elements of |arr|.
+  void launch_kernel(
+      CUstream stream,
+      const std::string& kernel_name,
+      const array& arr,
+      bool large,
+      int work_per_thread = 1);
+
+  void launch_kernel(
+      CUstream stream,
+      CUfunction kernel,
+      Dims num_blocks,
+      Dims block_dims);
+
+  CUfunction get_kernel(const std::string& kernel_name);
+
+ private:
+  void append_ptr_arg(const void* v);
+
+  CUmodule module_{nullptr};
+  std::unordered_map<std::string, CUfunction> kernels_;
+  std::vector<void*> args_;
+
+  // The cuLaunchKernel API requires passing pointers to arguments, so store
+  // temporary values until the kernel is launched.
+  using Arg = std::variant<
+      std::monostate,
+      CUdeviceptr,
+      int32_t,
+      uint32_t,
+      int64_t,
+      std::vector<const void*>,
+      std::vector<int32_t>,
+      std::vector<int64_t>>;
+  std::deque<Arg> storage_;
+};
+
+JitModule& get_jit_module(
+    const mlx::core::Device& device,
+    const std::string& name,
+    const KernelBuilder& builder);
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/kernels/config.h b/mlx/backend/cuda/kernels/config.h
new file mode 100644
index 000000000..0933cc8b5
--- /dev/null
+++ b/mlx/backend/cuda/kernels/config.h
@@ -0,0 +1,12 @@
+// Copyright © 2025 Apple Inc.
+
+// This file is used by both CUDA kernel code and host-only C++ code.
+
+#pragma once
+
+// The maximum dimensions of shape/strides passed as kernel parameters.
+#define MAX_NDIM 8
+
+// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
+// warpSize variable exists, using it would prevent compile-time optimizations.
+#define WARP_SIZE 32
diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/kernels/utils.cuh
index 7636710dc..e59095996 100644
--- a/mlx/backend/cuda/kernels/utils.cuh
+++ b/mlx/backend/cuda/kernels/utils.cuh
@@ -8,6 +8,8 @@
 #pragma once
 
+#include "mlx/backend/cuda/kernels/config.h"
+
 #include <cuComplex.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
@@ -21,14 +23,8 @@ namespace mlx::core::cu {
 // CUDA kernel utils
 ///////////////////////////////////////////////////////////////////////////////
 
-// All existing NVIDIA hardware has a fixed 32 warp size.
Though a built-in -// warpSize variable exists, using it would prevent compile-time optimizations. -#define WARP_SIZE 32 - // To pass shape/strides to kernels via constant memory, their size must be // known at compile time. -#define MAX_NDIM 8 - using Shape = cuda::std::array; using Strides = cuda::std::array; diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 8de4f92f9..ded0d80c7 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -73,7 +73,6 @@ bool fast::ScaledDotProductAttention::use_fallback( NO_GPU(ArgPartition) NO_GPU(BlockMaskedMM) -NO_GPU_MULTI(Compiled) NO_GPU(Convolution) NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 2a11a518e..2f5e2a4c8 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/device.h" +#include "mlx/dtype_utils.h" #include @@ -23,4 +24,20 @@ void check_cuda_error(const char* name, cudaError_t err) { } } +const char* dtype_to_cuda_type(const Dtype& dtype) { + if (dtype == float16) { + return "__half"; + } + if (dtype == bfloat16) { + return "__nv_bfloat16"; + } +#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ + if (dtype == DTYPE) { \ + return #CPP_TYPE; \ + } + MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) +#undef SPECIALIZE_DtypeToString + return nullptr; +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 6eaec8984..6d98cdcd5 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -12,6 +12,8 @@ namespace cu { class Device; } +struct Dtype; + // Cuda stream managed with RAII. class CudaStream { public: @@ -35,4 +37,7 @@ void check_cuda_error(const char* name, cudaError_t err); // The macro version that prints the command that failed. #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) +// Convert Dtype to CUDA C++ types. 
+const char* dtype_to_cuda_type(const Dtype& dtype); + } // namespace mlx::core From ef9495fb8f457803b1c2ea6714c1c52ccabe2b1b Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 12 Jun 2025 23:39:06 +0000 Subject: [PATCH 2/3] Rename kernels/ to device/ --- mlx/backend/cuda/binary.cu | 4 ++-- mlx/backend/cuda/compiled.cpp | 6 +++--- mlx/backend/cuda/copy/copy.cuh | 2 +- mlx/backend/cuda/{kernels => device}/arange.cuh | 0 mlx/backend/cuda/{kernels => device}/binary_ops.cuh | 2 +- mlx/backend/cuda/{kernels => device}/cast_op.cuh | 0 mlx/backend/cuda/{kernels => device}/config.h | 0 mlx/backend/cuda/{kernels => device}/cucomplex_math.cuh | 0 mlx/backend/cuda/{kernels => device}/fp16_math.cuh | 0 mlx/backend/cuda/{kernels => device}/unary_ops.cuh | 4 ++-- mlx/backend/cuda/{kernels => device}/utils.cuh | 2 +- mlx/backend/cuda/jit_module.h | 2 +- mlx/backend/cuda/kernel_utils.cuh | 4 ++-- mlx/backend/cuda/logsumexp.cu | 2 +- mlx/backend/cuda/primitives.cu | 4 ++-- mlx/backend/cuda/reduce/col_reduce.cu | 2 +- mlx/backend/cuda/reduce/reduce.cuh | 2 +- mlx/backend/cuda/reduce/reduce_ops.cuh | 2 +- mlx/backend/cuda/reduce/row_reduce.cu | 2 +- mlx/backend/cuda/reduce/segmented_reduce.cu | 2 +- mlx/backend/cuda/softmax.cu | 4 ++-- mlx/backend/cuda/unary.cu | 4 ++-- 22 files changed, 25 insertions(+), 25 deletions(-) rename mlx/backend/cuda/{kernels => device}/arange.cuh (100%) rename mlx/backend/cuda/{kernels => device}/binary_ops.cuh (99%) rename mlx/backend/cuda/{kernels => device}/cast_op.cuh (100%) rename mlx/backend/cuda/{kernels => device}/config.h (100%) rename mlx/backend/cuda/{kernels => device}/cucomplex_math.cuh (100%) rename mlx/backend/cuda/{kernels => device}/fp16_math.cuh (100%) rename mlx/backend/cuda/{kernels => device}/unary_ops.cuh (98%) rename mlx/backend/cuda/{kernels => device}/utils.cuh (99%) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 360772998..47efc44d2 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -2,9 +2,9 @@ #include "mlx/backend/common/binary.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/binary_ops.cuh" -#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index de004b482..a6b8223e0 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -129,9 +129,9 @@ struct FusedKernelBuilder { } // namespace cu constexpr const char* g_jit_includes = R"( -#include "mlx/backend/cuda/kernels/binary_ops.cuh" -#include "mlx/backend/cuda/kernels/unary_ops.cuh" -#include "mlx/backend/cuda/kernels/utils.cuh" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/unary_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index dd1d09d30..0c1eff774 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -3,8 +3,8 @@ #pragma once #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cast_op.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" diff --git a/mlx/backend/cuda/kernels/arange.cuh b/mlx/backend/cuda/device/arange.cuh 
similarity index 100%
rename from mlx/backend/cuda/kernels/arange.cuh
rename to mlx/backend/cuda/device/arange.cuh
diff --git a/mlx/backend/cuda/kernels/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
similarity index 99%
rename from mlx/backend/cuda/kernels/binary_ops.cuh
rename to mlx/backend/cuda/device/binary_ops.cuh
index 3bc30eb02..4779a6f33 100644
--- a/mlx/backend/cuda/kernels/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -1,6 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/kernels/fp16_math.cuh"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
 
 #include <cuComplex.h>
 #include <cuda_fp16.h>
diff --git a/mlx/backend/cuda/kernels/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh
similarity index 100%
rename from mlx/backend/cuda/kernels/cast_op.cuh
rename to mlx/backend/cuda/device/cast_op.cuh
diff --git a/mlx/backend/cuda/kernels/config.h b/mlx/backend/cuda/device/config.h
similarity index 100%
rename from mlx/backend/cuda/kernels/config.h
rename to mlx/backend/cuda/device/config.h
diff --git a/mlx/backend/cuda/kernels/cucomplex_math.cuh b/mlx/backend/cuda/device/cucomplex_math.cuh
similarity index 100%
rename from mlx/backend/cuda/kernels/cucomplex_math.cuh
rename to mlx/backend/cuda/device/cucomplex_math.cuh
diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/device/fp16_math.cuh
similarity index 100%
rename from mlx/backend/cuda/kernels/fp16_math.cuh
rename to mlx/backend/cuda/device/fp16_math.cuh
diff --git a/mlx/backend/cuda/kernels/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
similarity index 98%
rename from mlx/backend/cuda/kernels/unary_ops.cuh
rename to mlx/backend/cuda/device/unary_ops.cuh
index 6637a6eeb..af7c30e64 100644
--- a/mlx/backend/cuda/kernels/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -2,8 +2,8 @@
 
 #pragma once
 
-#include "mlx/backend/cuda/kernels/fp16_math.cuh"
-#include "mlx/backend/cuda/kernels/utils.cuh"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
+#include "mlx/backend/cuda/device/utils.cuh"
 
 namespace mlx::core::cu {
 
diff --git a/mlx/backend/cuda/kernels/utils.cuh b/mlx/backend/cuda/device/utils.cuh
similarity index 99%
rename from mlx/backend/cuda/kernels/utils.cuh
rename to mlx/backend/cuda/device/utils.cuh
index e59095996..a1d387201 100644
--- a/mlx/backend/cuda/kernels/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -8,7 +8,7 @@
 
 #pragma once
 
-#include "mlx/backend/cuda/kernels/config.h"
+#include "mlx/backend/cuda/device/config.h"
 
 #include <cuComplex.h>
 #include <cuda_bf16.h>
diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h
index fcaa1fb3e..bbfaa45b0 100644
--- a/mlx/backend/cuda/jit_module.h
+++ b/mlx/backend/cuda/jit_module.h
@@ -4,7 +4,7 @@
 
 #include "mlx/array.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/cuda/kernels/config.h"
+#include "mlx/backend/cuda/device/config.h"
 
 #include <deque>
 #include <functional>
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 656ddebea..7e957bbbd 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -1,13 +1,13 @@
 // Copyright © 2025 Apple Inc.
 
 // This file includes host-only utilities for writing CUDA kernels; the difference
-// from backend/cuda/kernels/utils.cuh is that the latter file only includes
+// from backend/cuda/device/utils.cuh is that the latter file only includes
 // device-only code.
#pragma once #include "mlx/array.h" -#include "mlx/backend/cuda/kernels/utils.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include #include diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index e539ac559..f57f82ea8 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -1,8 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cast_op.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index ded0d80c7..48b189626 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -1,9 +1,9 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/arange.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/arange.cuh" -#include "mlx/backend/cuda/kernels/fp16_math.cuh" #include "mlx/distributed/primitives.h" #include "mlx/dtype_utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 1ca50d854..9911a6fe0 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 0148022ab..a673e052e 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -1,8 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/reduce.h" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index f06eb8541..832787222 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/cuda/kernels/utils.cuh" +#include "mlx/backend/cuda/device/utils.cuh" namespace mlx::core::cu { diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 3a5c4a591..ae54a27d6 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include diff --git a/mlx/backend/cuda/reduce/segmented_reduce.cu b/mlx/backend/cuda/reduce/segmented_reduce.cu index 563b056e4..114d71809 100644 --- a/mlx/backend/cuda/reduce/segmented_reduce.cu +++ b/mlx/backend/cuda/reduce/segmented_reduce.cu @@ -1,7 +1,7 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/kernels/cast_op.cuh" +#include "mlx/backend/cuda/device/cast_op.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" #include diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 605fc0df8..fc001ae75 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -1,9 +1,9 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cast_op.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cast_op.cuh" -#include "mlx/backend/cuda/kernels/fp16_math.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 0ee31ee28..f9d373455 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -2,10 +2,10 @@ #include "mlx/backend/common/unary.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/iterators/general_iterator.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/kernels/cucomplex_math.cuh" -#include "mlx/backend/cuda/kernels/unary_ops.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" From ef88907c635c82a0d4881c58e4d53469d7f642b4 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 02:15:09 +0000 Subject: [PATCH 3/3] CUDA backend: indexing ops --- mlx/backend/cuda/CMakeLists.txt | 9 +- mlx/backend/cuda/device/atomic_ops.cuh | 72 ++++ mlx/backend/cuda/device/gather.cuh | 53 +++ mlx/backend/cuda/device/gather_axis.cuh | 65 ++++ mlx/backend/cuda/device/indexing.cuh | 30 ++ mlx/backend/cuda/device/scatter.cuh | 68 ++++ mlx/backend/cuda/device/scatter_axis.cuh | 67 ++++ mlx/backend/cuda/device/scatter_ops.cuh | 44 +++ mlx/backend/cuda/indexing.cpp | 420 +++++++++++++++++++++++ mlx/backend/cuda/jit_module.cpp | 6 + mlx/backend/cuda/primitives.cu | 4 - 11 files changed, 830 insertions(+), 8 deletions(-) create mode 100644 mlx/backend/cuda/device/atomic_ops.cuh create mode 100644 mlx/backend/cuda/device/gather.cuh create mode 100644 mlx/backend/cuda/device/gather_axis.cuh create mode 100644 mlx/backend/cuda/device/indexing.cuh create mode 100644 mlx/backend/cuda/device/scatter.cuh create mode 100644 mlx/backend/cuda/device/scatter_axis.cuh create mode 100644 mlx/backend/cuda/device/scatter_ops.cuh create mode 100644 mlx/backend/cuda/indexing.cpp diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 89245da53..62ff8aaeb 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -1,8 +1,8 @@ # Filename rules in cuda backend: # # * Use .cu/.cuh if code contains device code, and .cpp/.h if not. -# * Device-only kernel code should be put in kernels/ subdir. -# * Files in kernels/ subdir should not include files outside. +# * Device-only code should be put in device/ subdir. +# * Files in device/ subdir should not include files outside. 
target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp @@ -20,6 +20,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu @@ -43,8 +44,8 @@ target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) file( GLOB MLX_JIT_SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.h" - "${CMAKE_CURRENT_SOURCE_DIR}/kernels/*.cuh") + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh") string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) add_custom_command( OUTPUT gen/cuda_jit_sources.h diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh new file mode 100644 index 000000000..b6915606e --- /dev/null +++ b/mlx/backend/cuda/device/atomic_ops.cuh @@ -0,0 +1,72 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" + +#include + +namespace mlx::core::cu { + +template +inline __device__ void atomic_add(T* out, T val) { + cuda::atomic_ref ref(*out); + ref += val; +} + +template +inline __device__ void atomic_prod(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old * val)) { + } +} + +template +inline __device__ void atomic_max(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_max(val); +} + +template +inline __device__ void atomic_min(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_min(val); +} + +// Somehow cuda::atomic_ref does not provide atomic add for following types. +template +inline __device__ void atomic_add_general(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old + val)) { + } +} + +inline __device__ void atomic_add(__half* out, __half val) { + atomicAdd(out, val); +} + +inline __device__ void atomic_add(cuComplex* out, cuComplex val) { +#if __CUDA_ARCH__ < 900 + atomic_add_general(out, val); +#else + atomicAdd(out, val); +#endif +} + +inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { +#if __CUDA_ARCH__ < 800 +#if CCCL_VERSION >= 2008000 + atomic_add_general(out, val); +#else + bool cccl_version_too_old_for_bfloat16_atomic_add = false; + assert(cccl_version_too_old_for_bfloat16_atomic_add); +#endif +#else + atomicAdd(out, val); +#endif +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather.cuh b/mlx/backend/cuda/device/gather.cuh new file mode 100644 index 000000000..7dbd84ac3 --- /dev/null +++ b/mlx/backend/cuda/device/gather.cuh @@ -0,0 +1,53 @@ +// Copyright © 2025 Apple Inc. 
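+//
+// One thread per output element: out_idx decomposes into a position within
+// the gathered slice (src_elem) and a position in the index arrays
+// (idx_elem); each indexed axis then offsets src_loc by its index value.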
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template +__global__ void gather( + const T* src, + T* out, + LocT size, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim, + const __grid_constant__ Shape slice_sizes, + uint32_t slice_size, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT out_idx = cg::this_grid().thread_rank(); + if (out_idx >= size) { + return; + } + + LocT src_elem = out_idx % slice_size; + LocT idx_elem = out_idx / slice_size; + + LocT src_loc = + elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]); + src_loc += idx_val * src_strides[axis]; + } + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather_axis.cuh b/mlx/backend/cuda/device/gather_axis.cuh new file mode 100644 index 000000000..f863b2d95 --- /dev/null +++ b/mlx/backend/cuda/device/gather_axis.cuh @@ -0,0 +1,65 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + int NDIM, + bool SrcC, + bool IdxC, + typename LocT> +__global__ void gather_axis( + const T* src, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array src_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t src_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT src_loc = idx_val * src_stride_axis; + if constexpr (SrcC) { + src_loc += elem_idx * axis_size + x; + } else { + src_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), src_strides.data()); + } + + LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x; + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/indexing.cuh b/mlx/backend/cuda/device/indexing.cuh new file mode 100644 index 000000000..31cba1a90 --- /dev/null +++ b/mlx/backend/cuda/device/indexing.cuh @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. 
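+//
+// Worked example for index_to_dims below: with dim1 = 3 and dim2 = 4,
+// index 23 maps to (x, y, z) = (1, 2, 3), since 23 = 1 * (3 * 4) + 2 * 4 + 3.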
+ +#include +#include + +namespace mlx::core::cu { + +// Convert an absolute index to positions in a 3d grid, assuming the index is +// calculated with: +// index = x * dim1 * dim2 + y * dim2 + z +template +inline __host__ __device__ cuda::std::tuple +index_to_dims(T index, T dim1, T dim2) { + T x = index / (dim1 * dim2); + T y = (index % (dim1 * dim2)) / dim2; + T z = index % dim2; + return cuda::std::make_tuple(x, y, z); +} + +// Get absolute index from possible negative index. +template +inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) { + if constexpr (cuda::std::is_unsigned_v) { + return idx; + } else { + return static_cast(idx < 0 ? idx + size : idx); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh new file mode 100644 index 000000000..b2f640350 --- /dev/null +++ b/mlx/backend/cuda/device/scatter.cuh @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + int IDX_NDIM, + typename LocT> +__global__ void scatter( + const T* upd, + T* out, + LocT size, + const __grid_constant__ Shape upd_shape, + const __grid_constant__ Strides upd_strides, + int32_t upd_ndim, + LocT upd_post_idx_size, + const __grid_constant__ Shape out_shape, + const __grid_constant__ Strides out_strides, + int32_t out_ndim, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT upd_idx = cg::this_grid().thread_rank(); + if (upd_idx >= size) { + return; + } + + LocT out_elem = upd_idx % upd_post_idx_size; + LocT idx_elem = upd_idx / upd_post_idx_size; + + LocT out_idx = elem_to_loc( + out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]); + out_idx += idx_val * out_strides[axis]; + } + + LocT upd_loc = elem_to_loc( + out_elem + idx_elem * upd_post_idx_size, + upd_shape.data(), + upd_strides.data(), + upd_ndim); + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_axis.cuh b/mlx/backend/cuda/device/scatter_axis.cuh new file mode 100644 index 000000000..1f30f2ebd --- /dev/null +++ b/mlx/backend/cuda/device/scatter_axis.cuh @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. 
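+//
+// Mirror of gather_axis, but writes go through an Op functor from
+// scatter_ops.cuh; the reducing ops use atomics so that duplicate indices
+// combine correctly instead of racing.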
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NDIM, + bool UpdC, + bool IdxC, + typename LocT> +__global__ void scatter_axis( + const T* upd, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array upd_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t upd_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT upd_loc = y * upd_stride_axis; + if constexpr (UpdC) { + upd_loc += elem_idx * idx_size_axis + x; + } else { + upd_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), upd_strides.data()); + } + + LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x; + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_ops.cuh b/mlx/backend/cuda/device/scatter_ops.cuh new file mode 100644 index 000000000..d88f896ad --- /dev/null +++ b/mlx/backend/cuda/device/scatter_ops.cuh @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/atomic_ops.cuh" + +namespace mlx::core::cu { + +struct ScatterAssign { + template + __device__ void operator()(T* out, T val) const { + *out = val; + } +}; + +struct ScatterSum { + template + __device__ void operator()(T* out, T val) const { + atomic_add(out, val); + } +}; + +struct ScatterProd { + template + __device__ void operator()(T* out, T val) const { + atomic_prod(out, val); + } +}; + +struct ScatterMax { + template + __device__ void operator()(T* out, T val) const { + atomic_max(out, val); + } +}; + +struct ScatterMin { + template + __device__ void operator()(T* out, T val) const { + atomic_min(out, val); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp new file mode 100644 index 000000000..3603605c4 --- /dev/null +++ b/mlx/backend/cuda/indexing.cpp @@ -0,0 +1,420 @@ +// Copyright © 2025 Apple Inc. 
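+//
+// Host-side dispatch for the indexing primitives: each (dtype, index dtype,
+// op, number of index arrays) combination becomes one JIT module, and every
+// (ndim, contiguity, index width) variant in it is instantiated by name.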
+
+#include "mlx/backend/common/compiled.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/jit_module.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include "cuda_jit_sources.h"
+
+#include <fmt/format.h>
+#include <nvtx3/nvtx3.hpp>
+
+#include <cassert>
+#include <numeric>
+
+namespace mlx::core {
+
+namespace {
+
+constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
+
+void append_indices_arg(
+    cu::JitModule& mod,
+    const std::vector<array>& inputs,
+    int nidx,
+    int idx_ndim) {
+  std::vector<const void*> indices(nidx);
+  for (int i = 0; i < nidx; ++i) {
+    indices[i] = inputs[i + 1].data<void>();
+  }
+  mod.append_arg(std::move(indices));
+  std::vector<int32_t> indices_shape(nidx * idx_ndim);
+  for (int i = 0; i < nidx; ++i) {
+    std::copy_n(
+        inputs[i + 1].shape().begin(),
+        idx_ndim,
+        indices_shape.data() + i * idx_ndim);
+  }
+  mod.append_arg(std::move(indices_shape));
+  std::vector<int64_t> indices_strides(nidx * idx_ndim);
+  for (int i = 0; i < nidx; ++i) {
+    std::copy_n(
+        inputs[i + 1].strides().begin(),
+        idx_ndim,
+        indices_strides.data() + i * idx_ndim);
+  }
+  mod.append_arg(std::move(indices_strides));
+}
+
+} // namespace
+
+void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Gather::eval_gpu");
+  assert(inputs.size() > 0);
+  const auto& src = inputs[0];
+
+  out.set_data(allocator::malloc(out.nbytes()));
+  if (out.size() == 0) {
+    return;
+  }
+
+  int nidx = inputs.size() - 1;
+  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
+  int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
+
+  bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
+      (src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
+
+  uint32_t slice_size = std::accumulate(
+      slice_sizes_.begin(),
+      slice_sizes_.end(),
+      1,
+      std::multiplies<uint32_t>());
+
+  std::string module_name = fmt::format(
+      "gather_{}_{}_{}",
+      dtype_to_string(out.dtype()),
+      dtype_to_string(idx_dtype),
+      nidx);
+
+  auto& s = stream();
+  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
+    std::vector<std::string> kernel_names;
+    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
+      for (int large = 0; large <= 1; ++large) {
+        kernel_names.push_back(fmt::format(
+            "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
+            dtype_to_cuda_type(out.dtype()),
+            dtype_to_cuda_type(idx_dtype),
+            nidx,
+            ndim,
+            large ? "int64_t" : "uint32_t"));
+      }
+    }
+    return std::make_pair(jit_source_gather, std::move(kernel_names));
+  });
+
+  mod.append_arg(src);
+  mod.append_arg(out);
+  if (large) {
+    mod.append_arg<int64_t>(out.size());
+  } else {
+    mod.append_arg<uint32_t>(out.size());
+  }
+  mod.append_ndim_arg(src.shape());
+  mod.append_ndim_arg(src.strides());
+  mod.append_arg<int32_t>(src.ndim());
+  mod.append_ndim_arg(slice_sizes_);
+  mod.append_arg(slice_size);
+  mod.append_arg(axes_);
+  append_indices_arg(mod, inputs, nidx, idx_ndim);
+
+  std::string kernel_name = fmt::format(
+      "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
+      dtype_to_cuda_type(out.dtype()),
+      dtype_to_cuda_type(idx_dtype),
+      nidx,
+      idx_ndim,
+      large ? "int64_t" : "uint32_t");
+
+  auto& encoder = cu::get_command_encoder(s);
+  for (const auto& in : inputs) {
+    encoder.set_input_array(in);
+  }
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    mod.launch_kernel(stream, kernel_name, out, large);
+  });
+}
+
+void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Scatter::eval_gpu");
+  assert(inputs.size() > 1);
+  auto& upd = inputs.back();
+
+  // Copy src into out.
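+  // Scatter updates out in place, so out must start as a copy of src; pick
+  // the cheapest copy matching the input's layout (scalar broadcast,
+  // contiguous vector copy, or general strided copy).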
+  CopyType copy_type;
+  if (inputs[0].data_size() == 1) {
+    copy_type = CopyType::Scalar;
+  } else if (inputs[0].flags().row_contiguous) {
+    copy_type = CopyType::Vector;
+  } else {
+    copy_type = CopyType::General;
+  }
+  copy_gpu(inputs[0], out, copy_type);
+
+  // Empty update.
+  if (upd.size() == 0) {
+    return;
+  }
+
+  int nidx = axes_.size();
+  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
+  int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
+
+  bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
+      (upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);
+
+  uint32_t upd_post_idx_size = std::accumulate(
+      upd.shape().begin() + idx_ndim,
+      upd.shape().end(),
+      1,
+      std::multiplies<uint32_t>());
+
+  const char* op = g_scatter_ops[reduce_type_];
+  std::string module_name = fmt::format(
+      "scatter_{}_{}_{}_{}",
+      dtype_to_string(out.dtype()),
+      dtype_to_string(idx_dtype),
+      op,
+      nidx);
+
+  auto& s = stream();
+  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
+    std::vector<std::string> kernel_names;
+    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
+      for (int large = 0; large <= 1; ++large) {
+        kernel_names.push_back(fmt::format(
+            "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
+            dtype_to_cuda_type(out.dtype()),
+            dtype_to_cuda_type(idx_dtype),
+            op,
+            nidx,
+            ndim,
+            large ? "int64_t" : "uint32_t"));
+      }
+    }
+    return std::make_pair(jit_source_scatter, std::move(kernel_names));
+  });
+
+  mod.append_arg(upd);
+  mod.append_arg(out);
+  if (large) {
+    mod.append_arg<int64_t>(upd.size());
+  } else {
+    mod.append_arg<uint32_t>(upd.size());
+  }
+  mod.append_ndim_arg(upd.shape());
+  mod.append_ndim_arg(upd.strides());
+  mod.append_arg<int32_t>(upd.ndim());
+  if (large) {
+    mod.append_arg<int64_t>(upd_post_idx_size);
+  } else {
+    mod.append_arg<uint32_t>(upd_post_idx_size);
+  }
+  mod.append_ndim_arg(out.shape());
+  mod.append_ndim_arg(out.strides());
+  mod.append_arg<int32_t>(out.ndim());
+  mod.append_arg(axes_);
+  append_indices_arg(mod, inputs, nidx, idx_ndim);
+
+  std::string kernel_name = fmt::format(
+      "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
+      dtype_to_cuda_type(out.dtype()),
+      dtype_to_cuda_type(idx_dtype),
+      op,
+      nidx,
+      idx_ndim,
+      large ? "int64_t" : "uint32_t");
+
+  auto& encoder = cu::get_command_encoder(s);
+  for (const auto& in : inputs) {
+    encoder.set_input_array(in);
+  }
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    mod.launch_kernel(stream, kernel_name, upd, large);
+  });
+}
+
+void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("GatherAxis::eval_gpu");
+  assert(inputs.size() > 1);
+  const auto& src = inputs[0];
+  const auto& idx = inputs[1];
+
+  out.set_data(allocator::malloc(out.nbytes()));
+  if (out.size() == 0) {
+    return;
+  }
+
+  bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;
+
+  std::string module_name = fmt::format(
+      "gather_axis_{}_{}",
+      dtype_to_string(out.dtype()),
+      dtype_to_string(idx.dtype()));
+
+  auto& s = stream();
+  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
+    std::vector<std::string> kernel_names;
+    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
+      for (int contiguous = 0; contiguous < 4; ++contiguous) {
+        for (int large = 0; large <= 1; ++large) {
+          kernel_names.push_back(fmt::format(
+              "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
+              dtype_to_cuda_type(out.dtype()),
+              dtype_to_cuda_type(idx.dtype()),
+              ndim,
+              contiguous & 1 ? true : false,
+              contiguous & 2 ? true : false,
+              large ? "int64_t" : "uint32_t"));
+        }
+      }
+    }
+    return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
+  });
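+
+  // Split the index array's shape around axis_ into pre/axis/post extents,
+  // e.g. idx shape (2, 3, 4) with axis_ == 1 gives pre == 2, axis == 3,
+  // post == 4; the kernel recovers per-thread coordinates from these
+  // products via index_to_dims() in indexing.cuh.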
"int64_t" : "uint32_t")); + } + } + } + return std::make_pair(jit_source_gather_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + mod.append_arg(src); + mod.append_arg(idx); + mod.append_arg(out); + if (large) { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } else { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } + mod.append_arg(remove_index(idx.shape(), axis_)); + mod.append_arg(remove_index(src.strides(), axis_)); + mod.append_arg(remove_index(idx.strides(), axis_)); + mod.append_arg(axis_); + mod.append_arg(src.shape(axis_)); + mod.append_arg(src.strides(axis_)); + mod.append_arg(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + src.ndim() - 1, + src.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, idx, large); + }); +} + +void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ScatterAxis::eval_gpu"); + assert(inputs.size() > 2); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + const auto& upd = inputs[2]; + + // Copy src into out. + CopyType copy_type; + if (src.data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (src.flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(src, out, copy_type); + + // Empty update. + if (upd.size() == 0) { + return; + } + + bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + + const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; + std::string module_name = fmt::format( + "scatter_axis_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx.dtype()), + op); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int contiguous = 0; contiguous < 4; ++contiguous) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + ndim, + contiguous & 1 ? true : false, + contiguous & 2 ? true : false, + large ? 
"int64_t" : "uint32_t")); + } + } + } + return std::make_pair(jit_source_scatter_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + mod.append_arg(upd); + mod.append_arg(idx); + mod.append_arg(out); + if (large) { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } else { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } + mod.append_arg(remove_index(idx.shape(), axis_)); + mod.append_arg(remove_index(upd.strides(), axis_)); + mod.append_arg(remove_index(idx.strides(), axis_)); + mod.append_arg(axis_); + mod.append_arg(out.shape(axis_)); + mod.append_arg(upd.strides(axis_)); + mod.append_arg(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + idx.ndim() - 1, + upd.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, idx, large); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 3c00dd7f0..4a9e03841 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -148,11 +148,14 @@ bool compiler_supports_device_sass(Device& device) { #define INCLUDE_PREFIX "mlx/backend/cuda/kernels/" constexpr const char* g_include_names[] = { + INCLUDE_PREFIX "atomic_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "cucomplex_math.cuh", INCLUDE_PREFIX "fp16_math.cuh", + INCLUDE_PREFIX "indexing.cuh", + INCLUDE_PREFIX "scatter_ops.cuh", INCLUDE_PREFIX "unary_ops.cuh", INCLUDE_PREFIX "utils.cuh", }; @@ -160,11 +163,14 @@ constexpr const char* g_include_names[] = { #undef INCLUDE_PREFIX constexpr const char* g_headers[] = { + jit_source_atomic_ops, jit_source_binary_ops, jit_source_cast_op, jit_source_config, jit_source_cucomplex_math, jit_source_fp16_math, + jit_source_indexing, + jit_source_scatter_ops, jit_source_unary_ops, jit_source_utils, }; diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 48b189626..7805e5c04 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -78,8 +78,6 @@ NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) NO_GPU(DynamicSliceUpdate) NO_GPU(FFT) -NO_GPU(Gather) -NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) NO_GPU(Hadamard) @@ -89,8 +87,6 @@ NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(Scan) -NO_GPU(Scatter) -NO_GPU(ScatterAxis) NO_GPU(Select) NO_GPU(SliceUpdate) NO_GPU_MULTI(SVD)