From e397177f6ec71dac7ac62bb543700145e2f5344e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 20 Aug 2025 17:20:22 -0700 Subject: [PATCH] Custom cuda kernel (#2517) --- docs/src/python/cuda.rst | 9 + docs/src/python/fast.rst | 1 + mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/compiled.cpp | 3 +- mlx/backend/cuda/custom_kernel.cpp | 379 ++++++++++++++++++++++++++++ mlx/backend/cuda/indexing.cpp | 10 +- mlx/backend/cuda/jit_module.cpp | 221 +++++++++------- mlx/backend/cuda/jit_module.h | 21 +- mlx/backend/cuda/no_cuda.cpp | 40 ++- mlx/backend/cuda/primitives.cpp | 4 - mlx/backend/metal/custom_kernel.cpp | 7 +- mlx/backend/metal/no_metal.cpp | 8 +- mlx/fast.h | 33 ++- mlx/fast_primitives.h | 18 +- python/src/CMakeLists.txt | 1 + python/src/cuda.cpp | 19 ++ python/src/fast.cpp | 320 +++++++++++++++++++---- python/src/mlx.cpp | 2 + python/tests/test_fast.py | 156 ++++++++---- 19 files changed, 1042 insertions(+), 211 deletions(-) create mode 100644 docs/src/python/cuda.rst create mode 100644 mlx/backend/cuda/custom_kernel.cpp create mode 100644 python/src/cuda.cpp diff --git a/docs/src/python/cuda.rst b/docs/src/python/cuda.rst new file mode 100644 index 000000000..932d36b5e --- /dev/null +++ b/docs/src/python/cuda.rst @@ -0,0 +1,9 @@ +CUDA +===== + +.. currentmodule:: mlx.core.cuda + +.. 
autosummary:: + :toctree: _autosummary + + is_available diff --git a/docs/src/python/fast.rst b/docs/src/python/fast.rst index f78f40563..b250dcb18 100644 --- a/docs/src/python/fast.rst +++ b/docs/src/python/fast.rst @@ -13,3 +13,4 @@ Fast rope scaled_dot_product_attention metal_kernel + cuda_kernel diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 994307284..2e12c8c3e 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -20,6 +20,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 9e63a269b..419f48789 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -267,7 +267,8 @@ void Compiled::eval_gpu( } } - return std::make_pair(std::move(builder.os), std::move(kernel_names)); + return std::make_tuple( + false, std::move(builder.os), std::move(kernel_names)); }); // Collapse contiguous dims to route to a faster kernel if possible. Also diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp new file mode 100644 index 000000000..ee1778fd8 --- /dev/null +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -0,0 +1,379 @@ +// Copyright © 2025 Apple Inc. 
+ +#include + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/fast.h" +#include "mlx/fast_primitives.h" + +#include +#include + +namespace mlx::core::fast { + +namespace { + +constexpr const char* default_header = R"( +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +#define inf cuda::std::numeric_limits::infinity() + +)"; + +std::string template_arguments_hash( + const std::vector>& template_args) { + if (template_args.empty()) { + return ""; + } + + std::string hash; + hash.reserve(512); + + for (const auto& [name, arg] : template_args) { // hash name and value: equal values under different names must not share a kernel name + if (std::holds_alternative(arg)) { + hash += fmt::format("_{}_{}", name, std::get(arg)); + } else if (std::holds_alternative(arg)) { + hash += fmt::format("_{}{}", name, std::get(arg) ? "_t" : "_f"); + } else if (std::holds_alternative(arg)) { + hash += "_" + name + "_"; + hash += get_type_string(std::get(arg)); + } + } + + return hash; +} + +std::string build_kernel( + const std::string& func_name, + const std::string& header, + const std::string& source, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + const std::vector& output_dtypes, + const std::vector>& template_args, + const std::vector& shape_infos) { + std::string kernel_source; + kernel_source.reserve(header.size() + source.size() + 8192); + kernel_source += default_header; + kernel_source += header; + kernel_source += + "namespace mlx::core::cu {\n\n" + "namespace cg = cooperative_groups;\n\n"; + + kernel_source += "__global__ void "; + kernel_source += func_name; + kernel_source += "(\n"; + + // Add inputs + for (int i = 0; i < inputs.size(); ++i) { + const auto& name = input_names[i]; + const auto& arr = inputs[i]; + kernel_source += " const "; + kernel_source += dtype_to_cuda_type(arr.dtype()); + kernel_source += "* "; + kernel_source += name; + kernel_source += ",\n"; + // Add input shape, strides and ndim if present in the 
source + if (arr.ndim() > 0) { + if (shape_infos[i].shape) { + kernel_source += " const __grid_constant__ Shape "; + kernel_source += name; + kernel_source += "_shape,\n"; + } + if (shape_infos[i].strides) { + kernel_source += " const __grid_constant__ Strides "; + kernel_source += name; + kernel_source += "_strides,\n"; + } + if (shape_infos[i].ndim) { + kernel_source += " const __grid_constant__ int "; + kernel_source += name; + kernel_source += "_ndim,\n"; + } + } + } + + // Add outputs + for (int i = 0; i < output_names.size(); ++i) { + const auto& name = output_names[i]; + const auto& dtype = output_dtypes[i]; + kernel_source += " "; + kernel_source += dtype_to_cuda_type(dtype); + kernel_source += "* "; + kernel_source += name; + if (i < output_names.size() - 1) { + kernel_source += ",\n"; + } else { + kernel_source += ") {\n"; + } + } + + // Set compile time constants + if (!template_args.empty()) { + for (const auto& [name, arg] : template_args) { + if (std::holds_alternative(arg)) { + kernel_source += + fmt::format(" constexpr int {} = {};\n", name, std::get(arg)); + } else if (std::holds_alternative(arg)) { + kernel_source += fmt::format( + " constexpr bool {} = {};\n", name, std::get(arg)); + } else { + kernel_source += fmt::format( + " using {} = {};\n", + name, + dtype_to_cuda_type(std::get(arg))); + } + } + kernel_source += "\n"; + } + + kernel_source += source; + kernel_source += "\n}\n\n} // namespace mlx::core::cu\n"; + + return kernel_source; +} + +} // namespace + +CustomKernelFunction cuda_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header, + bool ensure_row_contiguous, + int shared_memory) { + if (output_names.empty()) { + throw std::invalid_argument( + "[custom_kernel] Must specify at least one output."); + } + + std::vector shape_infos; + for (auto& n : input_names) { + CustomKernelShapeInfo shape_info; + shape_info.shape = 
source.find(n + "_shape") != std::string::npos; + shape_info.strides = source.find(n + "_strides") != std::string::npos; + shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + shape_infos.push_back(shape_info); + } + + return [=, shape_infos = std::move(shape_infos)]( + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::vector>& + template_args = {}, + std::optional init_value = std::nullopt, + bool verbose = false, + StreamOrDevice s_ = {}) { + if (inputs.size() != input_names.size()) { + std::ostringstream msg; + msg << "[custom_kernel] Expected `inputs` to have size " + << input_names.size() << " but got size " << inputs.size() << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_shapes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[custom_kernel] Expected `output_shapes` to have size " + << output_names.size() << " but got size " << output_shapes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_dtypes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[custom_kernel] Expected `output_dtypes` to have size " + << output_names.size() << " but got size " << output_dtypes.size() + << "." 
<< std::endl; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + if (s.device != Device::gpu) { + throw std::invalid_argument("[custom_kernel] Only supports the GPU."); + } + + std::string kernel_name = + "custom_kernel_" + name + template_arguments_hash(template_args); + std::string kernel_source = build_kernel( + kernel_name, + header, + source, + input_names, + inputs, + output_names, + output_dtypes, + template_args, + shape_infos); + + if (verbose) { + std::cout << "Generated source code for `" << kernel_name + << "`:" << std::endl + << "```" << std::endl + << kernel_source << std::endl + << "```" << std::endl; + } + + return array::make_arrays( + std::move(output_shapes), + std::move(output_dtypes), + std::make_shared( + s, + std::move(kernel_name), + std::move(kernel_source), + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value, + std::vector{}, + false, + shared_memory), + std::move(inputs)); + }; +} + +std::vector precompiled_cuda_kernel( + const std::string& name, + const std::string& compiled_source, + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars, + std::tuple grid, + std::tuple threadgroup, + int shared_memory, + std::optional init_value, + bool ensure_row_contiguous, + StreamOrDevice s) { + std::vector shape_infos( + inputs.size(), CustomKernelShapeInfo{false, false, false}); + return array::make_arrays( + output_shapes, + output_dtypes, + std::make_shared( + to_stream(s), + name, + compiled_source, + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value, + scalars, + true, + shared_memory), + inputs); +} + +void CustomKernel::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("CustomKernel::eval_gpu"); + auto& s = stream(); + + std::vector copies; + + // Allocate and initialize the output arrays + for (auto& out : outputs) { + if (init_value_) { + 
copies.emplace_back(init_value_.value(), out.dtype()); + fill_gpu(copies.back(), out, s); + } else { + out.set_data(allocator::malloc(out.nbytes())); + } + } + + // Create the input arrays and copy if needed + auto check_input = [&copies, &s, this](const array& x) -> const array { + bool no_copy = x.flags().row_contiguous; + if (!ensure_row_contiguous_ || no_copy) { + return x; + } else { + copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); + copy_gpu(x, copies.back(), CopyType::General, s); + return copies.back(); + } + }; + std::vector checked_inputs; + for (const array& in : inputs) { + checked_inputs.push_back(check_input(in)); + } + + // Compile the custom kernel + std::string kernel_name = + (is_precompiled_) ? name_ : "mlx::core::cu::" + name_; + cu::JitModule& mod = cu::get_jit_module( + s.device, + name_, + [&]() { + return std::make_tuple( + is_precompiled_, source_, std::vector{kernel_name}); + }, + false); + + // Make the arguments + cu::KernelArgs args; + for (int i = 0; i < checked_inputs.size(); i++) { + const array& in = checked_inputs[i]; + auto& shape_info = shape_infos_[i]; + args.append(in); + if (shape_info.shape) { + args.append_ndim(in.shape()); + } + if (shape_info.strides) { + args.append_ndim(in.strides()); + } + if (shape_info.ndim) { + args.append(in.ndim()); + } + } + for (auto& out : outputs) { + args.append(out); + } + for (auto& s : scalar_arguments_) { + if (std::holds_alternative(s)) { + args.append(std::get(s)); + } else if (std::holds_alternative(s)) { + args.append(std::get(s)); + } else if (std::holds_alternative(s)) { + args.append(std::get(s)); + } + } + + // Make the grid + const auto [tx, ty, tz] = threadgroup_; + const auto [gx, gy, gz] = grid_; + dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); + dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); + + // Call the kernel + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : checked_inputs) { + 
encoder.set_input_array(in); + } + for (const auto& out : outputs) { + encoder.set_output_array(out); + } + for (const auto& t : copies) { + encoder.add_temporary(t); + } + auto kernel = + mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) { + if (smem > 48000) { // surface a rejected attribute here instead of ignoring the CUresult + CHECK_CUDA_ERROR(cuFuncSetAttribute( + kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem)); + } + }); + encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args()); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index dd524a72d..829529609 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { large ? "int64_t" : "int32_t")); } } - return std::make_pair(jit_source_gather, std::move(kernel_names)); + return std::make_tuple(false, jit_source_gather, std::move(kernel_names)); }); cu::KernelArgs args; @@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { large ? 
"int64_t" : "int32_t")); } } - return std::make_pair(jit_source_scatter, std::move(kernel_names)); + return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); }); cu::KernelArgs args; @@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { } } } - return std::make_pair(jit_source_gather_axis, std::move(kernel_names)); + return std::make_tuple( + false, jit_source_gather_axis, std::move(kernel_names)); }); size_t idx_size_pre = 1; @@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { } } } - return std::make_pair(jit_source_scatter_axis, std::move(kernel_names)); + return std::make_tuple( + false, jit_source_scatter_axis, std::move(kernel_names)); }); size_t idx_size_pre = 1; diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 25db207e3..531052d46 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() { bool read_cached_ptx( const std::filesystem::path& cache_dir, const std::string& module_name, - std::vector* ptx, - std::vector>* ptx_kernels) { + std::string& ptx, + std::vector>& ptx_kernels) { if (cache_dir.empty()) { return false; } @@ -117,15 +117,15 @@ bool read_cached_ptx( if (!ptx_file.good()) { return false; } - ptx->resize(ptx_size); - ptx_file.read(ptx->data(), ptx_size); + ptx.resize(ptx_size); + ptx_file.read(ptx.data(), ptx_size); std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); std::string line; while (std::getline(txt_file, line)) { auto tab = line.find('\t'); if (tab != std::string::npos) { - ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1)); + ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1)); } } return true; @@ -135,7 +135,7 @@ bool read_cached_ptx( void write_cached_ptx( const std::filesystem::path& cache_dir, const std::string& module_name, - const std::vector& ptx, + const 
std::string& ptx, const std::vector>& ptx_kernels, const std::string& source_code) { if (cache_dir.empty()) { @@ -217,85 +217,85 @@ constexpr const char* g_headers[] = { jit_source_utils, }; -} // namespace - -JitModule::JitModule( +void compile( Device& device, const std::string& module_name, - const KernelBuilder& builder) { - // Check cache. - std::vector ptx; - std::vector> ptx_kernels; - if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) { - // Create program. - auto [source_code, kernel_names] = builder(); - nvrtcProgram prog; - CHECK_NVRTC_ERROR(nvrtcCreateProgram( - &prog, - source_code.c_str(), - (module_name + ".cu").c_str(), - std::size(g_headers), - g_headers, - g_include_names)); - std::unique_ptr prog_freer( - &prog, - [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); - for (const auto& name : kernel_names) { - CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); - } - - // Compile program. - std::vector args; - bool use_sass = compiler_supports_device_sass(device); - std::string compute = fmt::format( - "--gpu-architecture={}_{}{}", - use_sass ? "sm" : "compute", - device.compute_capability_major(), - device.compute_capability_minor()); - args.push_back(compute.c_str()); - std::string cccl_include = cccl_dir(); - if (!cccl_include.empty()) { - cccl_include = fmt::format("--include-path={}", cccl_include); - args.push_back(cccl_include.c_str()); - } - std::string cuda_include = - fmt::format("--include-path={}/include", cuda_home()); - args.push_back(cuda_include.c_str()); - nvrtcResult compile_result = - nvrtcCompileProgram(prog, args.size(), args.data()); - if (compile_result != NVRTC_SUCCESS) { - size_t log_size; - CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); - std::vector log(log_size + 1, 0); - CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data())); - throw std::runtime_error( - fmt::format("Failed to compile kernel: {}.", log.data())); - } - - // Get mangled names of kernel names. 
- for (const auto& name : kernel_names) { - const char* mangled; - CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled)); - ptx_kernels.emplace_back(name, mangled); - } - - // Get ptx data. - size_t ptx_size; - if (use_sass) { - CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size)); - } else { - CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size)); - } - ptx.resize(ptx_size, 0); - if (use_sass) { - CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data())); - } else { - CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); - } - write_cached_ptx( - ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code); + const std::string& source, + const std::vector& kernel_names, + std::string& ptx, + std::vector>& ptx_kernels) { + // Create the program + nvrtcProgram prog; + CHECK_NVRTC_ERROR(nvrtcCreateProgram( + &prog, + source.c_str(), + (module_name + ".cu").c_str(), + std::size(g_headers), + g_headers, + g_include_names)); + std::unique_ptr prog_freer( + &prog, + [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); + for (const auto& name : kernel_names) { + CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); } + // Compile program. + std::vector args; + bool use_sass = compiler_supports_device_sass(device); + std::string compute = fmt::format( + "--gpu-architecture={}_{}{}", + use_sass ? 
"sm" : "compute", + device.compute_capability_major(), + device.compute_capability_minor()); + args.push_back(compute.c_str()); + std::string cccl_include = cccl_dir(); + if (!cccl_include.empty()) { + cccl_include = fmt::format("--include-path={}", cccl_include); + args.push_back(cccl_include.c_str()); + } + std::string cuda_include = + fmt::format("--include-path={}/include", cuda_home()); + args.push_back(cuda_include.c_str()); + nvrtcResult compile_result = + nvrtcCompileProgram(prog, args.size(), args.data()); + if (compile_result != NVRTC_SUCCESS) { + size_t log_size; + CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); + std::vector log(log_size + 1, 0); + CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data())); + throw std::runtime_error( + fmt::format("Failed to compile kernel: {}.", log.data())); + } + + // Get mangled names of kernel names. + for (const auto& name : kernel_names) { + const char* mangled; + CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled)); + ptx_kernels.emplace_back(name, mangled); + } + + // Get ptx data. + size_t ptx_size; + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size)); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size)); + } + ptx.resize(ptx_size); + if (use_sass) { + CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data())); + } else { + CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); + } +} + +void load_module( + const std::string& module_name, + const std::string& ptx, + const std::vector>& ptx_kernels, + CUmodule& module_, + std::unordered_map>& kernels) { // Load module. 
char jit_log[4089] = {}; CUjit_option options[] = { @@ -312,21 +312,69 @@ JitModule::JitModule( for (const auto& [name, mangled] : ptx_kernels) { CUfunction kernel; CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); - kernels_[name] = kernel; + kernels[name] = std::make_pair(kernel, false); } } +} // namespace + +JitModule::JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder, + bool use_disk_cache) { + // Will hold the actual device executable source code and kernel names + std::string ptx; + std::vector> ptx_kernels; + + // Try to load them from the file cache (skipped when the disk cache is off) + if (!use_disk_cache || !read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) { + auto [precompiled, source_code, kernel_names] = builder(); + + // Get the PTX or cubin (copy, not move: source_code is read again below) + if (precompiled) { + ptx = source_code; + for (auto& name : kernel_names) { + ptx_kernels.emplace_back(name, name); + } + } else { + compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels); + } + + // If requested save them in the file cache for the next launch + if (use_disk_cache) { + write_cached_ptx( + ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code); + } + } + + // Load the module + load_module(module_name, ptx, ptx_kernels, module_, kernels_); +} + JitModule::~JitModule() { CHECK_CUDA_ERROR(cuModuleUnload(module_)); } -CUfunction JitModule::get_kernel(const std::string& kernel_name) { +CUfunction JitModule::get_kernel( + const std::string& kernel_name, + std::function configure_kernel) { auto it = kernels_.find(kernel_name); if (it == kernels_.end()) { throw std::runtime_error( fmt::format("There is no kernel named {}.", kernel_name)); } - return it->second; + + // If it is the first time we run this kernel then configure it. Do it only + // once! 
+ if (!it->second.second) { + if (configure_kernel) { + configure_kernel(it->second.first); + } + it->second.second = true; + } + + return it->second.first; } std::unordered_map& get_jit_module_cache() { @@ -337,11 +385,12 @@ std::unordered_map& get_jit_module_cache() { JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, - const KernelBuilder& builder) { + const KernelBuilder& builder, + bool cache) { auto& map = get_jit_module_cache(); auto it = map.find(name); if (it == map.end()) { - it = map.try_emplace(name, cu::device(device), name, builder).first; + it = map.try_emplace(name, cu::device(device), name, builder, cache).first; } return it->second; } diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index df3f58352..d919f9bc0 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -19,7 +19,8 @@ namespace mlx::core::cu { class Device; -using KernelBuilderResult = std::pair< +using KernelBuilderResult = std::tuple< + /* precompiled */ bool, /* source code */ std::string, /* kernel names */ std::vector>; using KernelBuilder = std::function; @@ -63,14 +64,16 @@ struct KernelArgs { private: std::vector args_; - // The cuLaunchKernel API requires passing pointers to arguments so store - // temporary values untill kernel is launched. + // The cuGraphAddKernelNode API requires passing pointers to arguments so + // store temporary values until the node is created. 
using Arg = std::variant< std::monostate, CUdeviceptr, + bool, int32_t, uint32_t, int64_t, + float, SmallVector, SmallVector, SmallVector>; @@ -82,16 +85,19 @@ class JitModule { JitModule( Device& device, const std::string& module_name, - const KernelBuilder& builder); + const KernelBuilder& builder, + bool cache); ~JitModule(); JitModule(const JitModule&) = delete; JitModule& operator=(const JitModule&) = delete; - CUfunction get_kernel(const std::string& kernel_name); + CUfunction get_kernel( + const std::string& kernel_name, + std::function configure_kernel = nullptr); private: CUmodule module_{nullptr}; - std::unordered_map kernels_; + std::unordered_map> kernels_; }; std::unordered_map& get_jit_module_cache(); @@ -99,6 +105,7 @@ std::unordered_map& get_jit_module_cache(); JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, - const KernelBuilder& builder); + const KernelBuilder& builder, + bool use_disk_cache = true); } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/no_cuda.cpp b/mlx/backend/cuda/no_cuda.cpp index 8a394c9e3..175a505b4 100644 --- a/mlx/backend/cuda/no_cuda.cpp +++ b/mlx/backend/cuda/no_cuda.cpp @@ -1,11 +1,47 @@ // Copyright © 2025 Apple Inc. 
#include "mlx/backend/cuda/cuda.h" +#include "mlx/fast.h" -namespace mlx::core::cu { +namespace mlx::core { + +namespace cu { bool is_available() { return false; } -} // namespace mlx::core::cu +} // namespace cu + +namespace fast { + +CustomKernelFunction cuda_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool, + int) { + throw std::runtime_error("[cuda_kernel] No CUDA back-end."); +} + +std::vector precompiled_cuda_kernel( + const std::string&, + const std::string&, + const std::vector&, + const std::vector&, + const std::vector&, + const std::vector&, + std::tuple, + std::tuple, + int shared_memory, + std::optional init_value, + bool ensure_row_contiguous, + StreamOrDevice) { + throw std::runtime_error("[cuda_kernel] No CUDA back-end."); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index f5a61366c..aa20f0128 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -41,10 +41,6 @@ NO_GPU(Cholesky) NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) -namespace fast { -NO_GPU_MULTI(CustomKernel) -} // namespace fast - namespace distributed { NO_GPU_MULTI(AllReduce) NO_GPU_MULTI(AllGather) diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 161503a0e..41e399ce3 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -172,7 +172,7 @@ std::string write_template( return template_def.str(); } -MetalKernelFunction metal_kernel( +CustomKernelFunction metal_kernel( const std::string& name, const std::vector& input_names, const std::vector& output_names, @@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel( threadgroup, shape_infos, ensure_row_contiguous, - init_value), + init_value, + std::vector{}, + false, + 0), std::move(inputs)); }; } diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp index 
9785e07c2..9d6a73461 100644 --- a/mlx/backend/metal/no_metal.cpp +++ b/mlx/backend/metal/no_metal.cpp @@ -26,15 +26,15 @@ device_info() { namespace fast { -MetalKernelFunction metal_kernel( +CustomKernelFunction metal_kernel( const std::string&, const std::vector&, const std::vector&, const std::string&, const std::string&, - bool ensure_row_contiguous, - bool atomic_outputs) { - throw std::runtime_error("[metal_kernel] No GPU back-end."); + bool, + bool) { + throw std::runtime_error("[metal_kernel] No Metal back-end."); } } // namespace fast diff --git a/mlx/fast.h b/mlx/fast.h index 7aebe3863..d154e4753 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -66,9 +66,10 @@ array affine_dequantize( int bits = 4, StreamOrDevice s = {}); -typedef std::variant TemplateArg; +using TemplateArg = std::variant; +using ScalarArg = std::variant; -typedef std::function( +using CustomKernelFunction = std::function( const std::vector&, const std::vector&, const std::vector&, @@ -77,10 +78,9 @@ typedef std::function( std::vector>, std::optional, bool, - StreamOrDevice)> - MetalKernelFunction; + StreamOrDevice)>; -MetalKernelFunction metal_kernel( +CustomKernelFunction metal_kernel( const std::string& name, const std::vector& input_names, const std::vector& output_names, @@ -89,4 +89,27 @@ MetalKernelFunction metal_kernel( bool ensure_row_contiguous = true, bool atomic_outputs = false); +CustomKernelFunction cuda_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + int shared_memory = 0); + +std::vector precompiled_cuda_kernel( + const std::string& name, + const std::string& compiled_source, + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars, + std::tuple grid, + std::tuple threadgroup, + int shared_memory = 0, + std::optional init_value = std::nullopt, + bool 
ensure_row_contiguous = false, + StreamOrDevice s = {}); + } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 52135adad..e0e83f726 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include +#include #include "mlx/primitives.h" @@ -283,6 +284,8 @@ struct CustomKernelShapeInfo { bool ndim = false; }; +using ScalarArg = std::variant; + class CustomKernel : public Primitive { public: CustomKernel( @@ -293,7 +296,10 @@ class CustomKernel : public Primitive { std::tuple threadgroup, std::vector shape_infos, bool ensure_row_contiguous, - std::optional init_value) + std::optional init_value, + std::vector scalar_arguments, + bool is_precompiled, + int shared_memory) : Primitive(stream), source_(std::move(source)), name_(std::move(name)), @@ -301,11 +307,14 @@ class CustomKernel : public Primitive { threadgroup_(threadgroup), shape_infos_(std::move(shape_infos)), ensure_row_contiguous_(ensure_row_contiguous), - init_value_(init_value) {} + init_value_(init_value), + scalar_arguments_(std::move(scalar_arguments)), + is_precompiled_(is_precompiled), + shared_memory_(shared_memory) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { - throw std::runtime_error("Custom Metal kernels only run on GPU."); + throw std::runtime_error("Custom kernels only run on GPU."); } void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -321,6 +330,9 @@ class CustomKernel : public Primitive { std::vector shape_infos_; bool ensure_row_contiguous_; std::optional init_value_; + std::vector scalar_arguments_; + bool is_precompiled_; + int shared_memory_; }; } // namespace mlx::core::fast diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 29beca859..f094fdfe8 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -17,6 +17,7 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp diff --git a/python/src/cuda.cpp b/python/src/cuda.cpp new file mode 100644 index 000000000..13b3a0154 --- /dev/null +++ b/python/src/cuda.cpp @@ -0,0 +1,19 @@ +// Copyright © 2023-2025 Apple Inc. + +#include + +#include "mlx/backend/cuda/cuda.h" + +namespace mx = mlx::core; +namespace nb = nanobind; + +void init_cuda(nb::module_& m) { + nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda"); + + cuda.def( + "is_available", + &mx::cu::is_available, + R"pbdoc( + Check if the CUDA back-end is available. + )pbdoc"); +} diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 3d0bc4147..12d6de358 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -17,6 +17,66 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; +namespace { + +struct PyCustomKernelFunction { + PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag) + : kernel_(std::move(kernel)), tag_(tag) {} + + std::vector operator()( + const std::vector& inputs_, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::optional>>& + template_args_ = std::nullopt, + std::optional init_value = std::nullopt, + bool verbose = false, + mx::StreamOrDevice s = {}) const { + std::vector inputs; + for (const auto& value : inputs_) { + inputs.push_back(to_array(value, std::nullopt)); + } + std::vector> template_args; + if (template_args_) { + for (const auto& [name, value] : template_args_.value()) { + // Handle bool, int and dtype template args + if (nb::isinstance(value)) { + bool bool_val = nb::cast(value); + template_args.emplace_back(name, bool_val); + } else if (nb::isinstance(value)) { + int int_val = nb::cast(value); + template_args.emplace_back(name, 
int_val); + } else if (nb::isinstance(value)) { + mx::Dtype dtype = nb::cast(value); + template_args.emplace_back(name, dtype); + } else { + std::ostringstream msg; + msg << tag_ + << " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`."; + throw std::invalid_argument(msg.str()); + } + } + } + return kernel_( + inputs, + output_shapes, + output_dtypes, + grid, + threadgroup, + template_args, + init_value, + verbose, + s); + } + + mx::fast::CustomKernelFunction kernel_; + const char* tag_; +}; + +} // namespace + void init_fast(nb::module_& parent_module) { auto m = parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); @@ -240,53 +300,7 @@ void init_fast(nb::module_& parent_module) { ensure_row_contiguous, atomic_outputs); return nb::cpp_function( - [kernel = std::move(kernel)]( - const std::vector& inputs_, - const std::vector& output_shapes, - const std::vector& output_dtypes, - std::tuple grid, - std::tuple threadgroup, - const std::optional< - std::vector>>& - template_args_ = std::nullopt, - std::optional init_value = std::nullopt, - bool verbose = false, - mx::StreamOrDevice s = {}) { - std::vector inputs; - for (const auto& value : inputs_) { - inputs.push_back(to_array(value, std::nullopt)); - } - std::vector> - template_args; - if (template_args_) { - for (const auto& [name, value] : template_args_.value()) { - // Handle bool, int and dtype template args - if (nb::isinstance(value)) { - bool bool_val = nb::cast(value); - template_args.emplace_back(name, bool_val); - } else if (nb::isinstance(value)) { - int int_val = nb::cast(value); - template_args.emplace_back(name, int_val); - } else if (nb::isinstance(value)) { - mx::Dtype dtype = nb::cast(value); - template_args.emplace_back(name, dtype); - } else { - throw std::invalid_argument( - "[metal_kernel] Invalid template argument. 
Must be `mlx.core.Dtype`, `int` or `bool`."); - } - } - } - return kernel( - inputs, - output_shapes, - output_dtypes, - grid, - threadgroup, - template_args, - init_value, - verbose, - s); - }, + PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"), nb::kw_only(), "inputs"_a, "output_shapes"_a, @@ -384,4 +398,216 @@ void init_fast(nb::module_& parent_module) { b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) )pbdoc"); + + m.def( + "cuda_kernel", + [](const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header, + bool ensure_row_contiguous, + int shared_mem) { + auto kernel = mx::fast::cuda_kernel( + name, + input_names, + output_names, + source, + header, + ensure_row_contiguous, + shared_mem); + return nb::cpp_function( + PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"), + nb::kw_only(), + "inputs"_a, + "output_shapes"_a, + "output_dtypes"_a, + "grid"_a, + "threadgroup"_a, + "template"_a = nb::none(), + "init_value"_a = nb::none(), + "verbose"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), + R"pbdoc( + Run the kernel. + + Args: + inputs (List[array]): The inputs passed to the CUDA kernel. + output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``. + output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``. + grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. + For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups. 
+ threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. + template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments. + These will be added as template arguments to the kernel definition. Default: ``None``. + init_value (float, optional): Optional value to use to initialize all of the output arrays. + By default, output arrays are uninitialized. Default: ``None``. + verbose (bool, optional): Whether to print the full generated source code of the kernel + when it is run. Default: ``False``. + stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. + + Returns: + List[array]: The list of output arrays.)pbdoc"); + }, + "name"_a, + "input_names"_a, + "output_names"_a, + "source"_a, + "header"_a = "", + "ensure_row_contiguous"_a = true, + "shared_memory"_a = 0, + R"pbdoc( + A jit-compiled custom CUDA kernel defined from a source string. + + This is the CUDA equivalent of :ref:`custom_metal_kernels`. + + Args: + name (str): Name for the kernel. + input_names (List[str]): The parameter names of the inputs in the + function signature. + output_names (List[str]): The parameter names of the outputs in the + function signature. + source (str): Source code. This is the body of a function in CUDA, + the function signature will be automatically generated. + header (str): Header source code to include before the main function. + Useful for helper functions or includes that should live outside of + the main function body. + ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous + before the kernel runs. Default: ``True``. + shared_memory (int): The dynamic shared memory to request for the + kernel. A value of 0 means no dynamic shared memory. Default: ``0``. + + Returns: + Callable ``cuda_kernel``. + + Example: + + .. 
code-block:: python + + def exp_elementwise(a: mx.array): + source = ''' + auto elem = cooperative_groups::this_grid().thread_rank(); + T tmp = inp[elem]; + out[elem] = exp(tmp); + ''' + + kernel = mx.fast.cuda_kernel( + name="myexp", + input_names=["inp"], + output_names=["out"], + source=source + ) + outputs = kernel( + inputs=[a], + template=[("T", mx.float32)], + grid=(a.size, 1, 1), + threadgroup=(256, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + verbose=True, + ) + return outputs[0] + + a = mx.random.normal(shape=(16, 16)).astype(mx.float16) + b = exp_elementwise(a) + assert mx.allclose(b, mx.exp(a)) + )pbdoc"); + + m.def( + "precompiled_cuda_kernel", + [](const std::string& name, + const nb::bytes compiled_source, + const std::vector& inputs_, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars_, + std::tuple grid, + std::tuple threadgroup, + int shared_memory, + std::optional init_value = std::nullopt, + bool ensure_row_contiguous = false, + mx::StreamOrDevice s = {}) { + // Collect the inputs and cast them to array + std::vector inputs; + for (const auto& value : inputs_) { + inputs.push_back(to_array(value, std::nullopt)); + } + + // Collect the scalar inputs + std::vector scalars; + scalars.reserve(scalars_.size()); + for (const auto& v : scalars_) { + if (nb::isinstance(v)) { + scalars.push_back(nb::cast(v)); + } else if (nb::isinstance(v)) { + scalars.push_back(nb::cast(v)); + } else if (nb::isinstance(v)) { + scalars.push_back(nb::cast(v)); + } else { + nb::object vtype = v.attr("__class__"); + std::string vtype_name = + nb::cast(vtype.attr("__name__")); + std::ostringstream msg; + msg << "[precompiled_cuda_kernel] Invalid scalar argument type. 
" + << "Received " << vtype_name + << " but must be one of bool, int or float"; + throw std::invalid_argument(msg.str()); + } + } + + return mx::fast::precompiled_cuda_kernel( + name, + std::string( + static_cast(compiled_source.data()), + compiled_source.size()), + inputs, + output_shapes, + output_dtypes, + scalars, + grid, + threadgroup, + shared_memory, + init_value, + ensure_row_contiguous, + s); + }, + nb::kw_only(), + "name"_a, + "compiled_source"_a, + "inputs"_a, + "output_shapes"_a, + "output_dtypes"_a, + "scalars"_a, + "grid"_a, + "threadgroup"_a, + "shared_memory"_a = 0, + "init_value"_a = nb::none(), + "ensure_row_contiguous"_a = false, + "stream"_a = nb::none(), + R"pbdoc( + Run a precompiled CUDA kernel defined from PTX or cubin. + + This op is still experimental and various parts of the API may change. + + Args: + name (str): Name for the kernel + compiled_source (bytes): The precompiled kernel in raw bytes. + inputs (List[array]): The inputs passed to the CUDA kernel. + output_shapes (List[Sequence[int]]): The list of shapes for each output. + output_dtypes (List[Dtype]): The list of data types for each output. + scalars (List[Union[bool, int, float]]): A list of scalar arguments to + pass to the kernel. + grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. + For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks. + threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. + shared_memory (int): The dynamic shared memory to request for the + kernel. A value of 0 means no dynamic shared memory. Default: ``0``. + init_value (float, optional): Optional value to use to initialize all of the output arrays. + By default, output arrays are uninitialized. Default: ``None``. + ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous + before the kernel runs. Default: ``False``. + stream (mx.stream, optional): Stream to run the kernel on. 
Default: ``None``. + )pbdoc"); } diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index eaddecb26..d89e48300 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -12,6 +12,7 @@ void init_array(nb::module_&); void init_device(nb::module_&); void init_stream(nb::module_&); void init_metal(nb::module_&); +void init_cuda(nb::module_&); void init_memory(nb::module_&); void init_ops(nb::module_&); void init_transforms(nb::module_&); @@ -35,6 +36,7 @@ NB_MODULE(core, m) { init_stream(m); init_array(m); init_metal(m); + init_cuda(m); init_memory(m); init_ops(m); init_transforms(m); diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index f79a62a15..5aabaf388 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase): )(x) self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available") def test_custom_kernel_basic(self): + if mx.metal.is_available(): + source = """ + uint elem = thread_position_in_grid.x; + out1[elem] = a[elem]; + """ + custom_kernel = mx.fast.metal_kernel + elif mx.cuda.is_available(): + source = """ + auto elem = cooperative_groups::this_grid().thread_rank(); + out1[elem] = a[elem]; + """ + custom_kernel = mx.fast.cuda_kernel + mx.random.seed(7) a = mx.random.normal(shape=(2, 2)) - kernel = mx.fast.metal_kernel( + kernel = custom_kernel( name="basic", input_names=["a"], output_names=["out1"], - source=""" - uint elem = thread_position_in_grid.x; - out1[elem] = a[elem]; - """, + source=source, ) out = kernel( inputs=[a], @@ -604,17 +614,10 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(out[0], a)) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available") def test_custom_kernel_args(self): - mx.random.seed(7) - a = 
mx.random.normal(shape=(3, 6)) - c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16) - - kernel = mx.fast.metal_kernel( - name="arg_test", - input_names=["a", "b", "c", "d"], - output_names=["out1", "out2"], - source=""" + if mx.metal.is_available(): + source = """ uint elem = thread_position_in_grid.x; T tmp = a[0]; if (e) { @@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase): out1[elem] = 1; } out2[elem] = a[1] + b[2] + c[1] - d; - """, + """ + custom_kernel = mx.fast.metal_kernel + elif mx.cuda.is_available(): + source = """ + auto elem = cooperative_groups::this_grid().thread_rank(); + T tmp = a[0]; + if (e) { + out1[elem] = a[1] + b[2] + static_cast(c[3]) + d[0] + f; + } else { + out1[elem] = 1; + } + out2[elem] = a[1] + b[2] + static_cast(c[1]) - d[0]; + """ + custom_kernel = mx.fast.cuda_kernel + + mx.random.seed(7) + a = mx.random.normal(shape=(3, 6)) + c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16) + + kernel = custom_kernel( + name="arg_test", + input_names=["a", "b", "c", "d"], + output_names=["out1", "out2"], + source=source, ) out = kernel( inputs=[ @@ -647,27 +673,43 @@ class TestFast(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484))) self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32))) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available") def test_custom_kernel_strides(self): + if mx.metal.is_available(): + source = """ + uint elem = thread_position_in_grid.x; + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; + """ + source_contig = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; + """ + custom_kernel = mx.fast.metal_kernel + elif mx.cuda.is_available(): + source = """ + auto elem = 
cooperative_groups::this_grid().thread_rank(); + auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim); + T tmp = inp[loc]; + out[elem] = exp(tmp) * WARP_SIZE; + """ + source_contig = """ + auto elem = cooperative_groups::this_grid().thread_rank(); + T tmp = inp[elem]; + out[elem] = exp(tmp) * WARP_SIZE; + """ + custom_kernel = mx.fast.cuda_kernel + mx.random.seed(7) a = mx.random.normal(shape=(3, 6)) - source = """ - uint elem = thread_position_in_grid.x; - uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); - T tmp = inp[loc]; - out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; - """ - source_contig = """ - uint elem = thread_position_in_grid.x; - T tmp = inp[elem]; - out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup; - """ # non contiguous a = mx.tile(a[::2], [4, 1]) for contig in [True, False]: - kernel = mx.fast.metal_kernel( + kernel = custom_kernel( name="myexp" + str(contig), input_names=["inp"], output_names=["out"], @@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0])) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available") def test_custom_kernel_helper(self): - mx.random.seed(7) - a = mx.random.normal(shape=(2, 2)) - kernel = mx.fast.metal_kernel( - name="helper", - input_names=["a"], - output_names=["out1"], - header=""" + if mx.metal.is_available(): + header = """ template T do_exp(T x) { return metal::precise::exp(x); } - """, - source=""" + """ + source = """ uint elem = thread_position_in_grid.x; out1[elem] = do_exp(a[elem]); - """, + """ + custom_kernel = mx.fast.metal_kernel + elif mx.cuda.is_available(): + header = """ + template + __device__ T do_exp(T x) { + return exp(x); + } + """ + source = """ + auto elem = cooperative_groups::this_grid().thread_rank(); + out1[elem] = do_exp(a[elem]); + """ + custom_kernel = mx.fast.cuda_kernel + + 
mx.random.seed(7) + a = mx.random.normal(shape=(2, 2)) + kernel = custom_kernel( + name="helper", + input_names=["a"], + output_names=["out1"], + header=header, + source=source, ) out = kernel( inputs=[a], @@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(out[0], mx.exp(a))) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available") def test_custom_kernel_attributes(self): + if mx.metal.is_available(): + source = "out[0] = threads_per_threadgroup.x;" + custom_kernel = mx.fast.metal_kernel + elif mx.cuda.is_available(): + source = "out[0] = blockDim.x;" + custom_kernel = mx.fast.cuda_kernel + a = mx.zeros(shape=(1, 1)) - kernel = mx.fast.metal_kernel( + kernel = custom_kernel( name="test_fun", input_names=["a"], output_names=["out"], - source=""" - out[0] = threads_per_threadgroup.x; - """, + source=source, ) out = kernel( inputs=[a],