diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp
index 9e63a269b..419f48789 100644
--- a/mlx/backend/cuda/compiled.cpp
+++ b/mlx/backend/cuda/compiled.cpp
@@ -267,7 +267,8 @@ void Compiled::eval_gpu(
       }
     }
-    return std::make_pair(std::move(builder.os), std::move(kernel_names));
+    return std::make_tuple(
+        false, std::move(builder.os), std::move(kernel_names));
   });

   // Collapse contiguous dims to route to a faster kernel if possible. Also
diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp
index 31e8a0b67..c66fc56d4 100644
--- a/mlx/backend/cuda/custom_kernel.cpp
+++ b/mlx/backend/cuda/custom_kernel.cpp
@@ -142,17 +142,17 @@ std::string build_kernel(

 } // namespace

-MetalKernelFunction metal_kernel(
+CustomKernelFunction cuda_kernel(
     const std::string& name,
     const std::vector<std::string>& input_names,
     const std::vector<std::string>& output_names,
     const std::string& source,
-    const std::string& header /* = "" */,
-    bool ensure_row_contiguous /* = true */,
-    bool atomic_outputs /* = false */) {
+    const std::string& header,
+    bool ensure_row_contiguous,
+    int shared_memory) {
   if (output_names.empty()) {
     throw std::invalid_argument(
-        "[metal_kernel] Must specify at least one output.");
+        "[custom_kernel] Must specify at least one output.");
   }

   std::vector<CustomKernelShapeInfo> shape_infos;
@@ -177,21 +177,21 @@ MetalKernelFunction metal_kernel(
              StreamOrDevice s_ = {}) {
     if (inputs.size() != input_names.size()) {
       std::ostringstream msg;
-      msg << "[metal_kernel] Expected `inputs` to have size "
+      msg << "[custom_kernel] Expected `inputs` to have size "
           << input_names.size() << " but got size " << inputs.size() << "."
           << std::endl;
       throw std::invalid_argument(msg.str());
     }
     if (output_shapes.size() != output_names.size()) {
       std::ostringstream msg;
-      msg << "[metal_kernel] Expected `output_shapes` to have size "
+      msg << "[custom_kernel] Expected `output_shapes` to have size "
          << output_names.size() << " but got size " << output_shapes.size()
          << "." << std::endl;
       throw std::invalid_argument(msg.str());
     }
     if (output_dtypes.size() != output_names.size()) {
       std::ostringstream msg;
-      msg << "[metal_kernel] Expected `output_dtypes` to have size "
+      msg << "[custom_kernel] Expected `output_dtypes` to have size "
          << output_names.size() << " but got size " << output_dtypes.size()
          << "." << std::endl;
       throw std::invalid_argument(msg.str());
@@ -199,7 +199,7 @@ MetalKernelFunction metal_kernel(
     auto s = to_stream(s_);
     if (s.device != Device::gpu) {
-      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
+      throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
     }
     std::string kernel_name =
@@ -237,12 +237,12 @@ MetalKernelFunction metal_kernel(
             init_value,
             std::vector<ScalarArg>{},
             false,
-            0),
+            shared_memory),
         std::move(inputs));
   };
 }
-std::vector<array> precompiled_custom_kernel(
+std::vector<array> precompiled_cuda_kernel(
     const std::string& name,
     const std::string& compiled_source,
     const std::vector<array>& inputs,
@@ -312,17 +312,14 @@ void CustomKernel::eval_gpu(
   // Compile the custom kernel
   std::string kernel_name =
       (is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
-  auto get_module = [&]() -> cu::JitModule& {
-    if (is_precompiled_) {
-      return cu::get_jit_module(
-          s.device, name_, source_, std::vector<std::string>{kernel_name});
-    } else {
-      return cu::get_jit_module(s.device, name_, [&]() {
-        return std::make_pair(source_, std::vector<std::string>{kernel_name});
-      });
-    }
-  };
-  cu::JitModule& mod = get_module();
+  cu::JitModule& mod = cu::get_jit_module(
+      s.device,
+      name_,
+      [&]() {
+        return std::make_tuple(
+            is_precompiled_, source_, std::vector<std::string>{kernel_name});
+      },
+      false);

   // Make the arguments
   cu::KernelArgs args;
diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp
index dd524a72d..829529609 100644
--- a/mlx/backend/cuda/indexing.cpp
+++ b/mlx/backend/cuda/indexing.cpp
@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
           large ? "int64_t" : "int32_t"));
     }
   }
-  return std::make_pair(jit_source_gather, std::move(kernel_names));
+  return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
 });

 cu::KernelArgs args;
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
           large ? "int64_t" : "int32_t"));
     }
   }
-  return std::make_pair(jit_source_scatter, std::move(kernel_names));
+  return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
 });

 cu::KernelArgs args;
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
         }
       }
     }
-    return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
+    return std::make_tuple(
+        false, jit_source_gather_axis, std::move(kernel_names));
   });

   size_t idx_size_pre = 1;
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
         }
       }
     }
-    return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
+    return std::make_tuple(
+        false, jit_source_scatter_axis, std::move(kernel_names));
   });

   size_t idx_size_pre = 1;
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 367c94392..b188a7f0e 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -217,85 +217,85 @@ constexpr const char* g_headers[] = {
     jit_source_utils,
 };

-} // namespace
-
-JitModule::JitModule(
+void compile(
     Device& device,
     const std::string& module_name,
-    const KernelBuilder& builder) {
-  // Check cache.
-  std::vector<char> ptx;
-  std::vector<std::pair<std::string, std::string>> ptx_kernels;
-  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
-    // Create program.
-    auto [source_code, kernel_names] = builder();
-    nvrtcProgram prog;
-    CHECK_NVRTC_ERROR(nvrtcCreateProgram(
-        &prog,
-        source_code.c_str(),
-        (module_name + ".cu").c_str(),
-        std::size(g_headers),
-        g_headers,
-        g_include_names));
-    std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
-        &prog,
-        [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
-    for (const auto& name : kernel_names) {
-      CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
-    }
-
-    // Compile program.
-    std::vector<const char*> args;
-    bool use_sass = compiler_supports_device_sass(device);
-    std::string compute = fmt::format(
-        "--gpu-architecture={}_{}{}",
-        use_sass ? "sm" : "compute",
"sm" : "compute", - device.compute_capability_major(), - device.compute_capability_minor()); - args.push_back(compute.c_str()); - std::string cccl_include = cccl_dir(); - if (!cccl_include.empty()) { - cccl_include = fmt::format("--include-path={}", cccl_include); - args.push_back(cccl_include.c_str()); - } - std::string cuda_include = - fmt::format("--include-path={}/include", cuda_home()); - args.push_back(cuda_include.c_str()); - nvrtcResult compile_result = - nvrtcCompileProgram(prog, args.size(), args.data()); - if (compile_result != NVRTC_SUCCESS) { - size_t log_size; - CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); - std::vector log(log_size + 1, 0); - CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data())); - throw std::runtime_error( - fmt::format("Failed to compile kernel: {}.", log.data())); - } - - // Get mangled names of kernel names. - for (const auto& name : kernel_names) { - const char* mangled; - CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled)); - ptx_kernels.emplace_back(name, mangled); - } - - // Get ptx data. - size_t ptx_size; - if (use_sass) { - CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size)); - } else { - CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size)); - } - ptx.resize(ptx_size, 0); - if (use_sass) { - CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data())); - } else { - CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); - } - write_cached_ptx( - ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code); + const std::string& source, + const std::vector& kernel_names, + std::vector& ptx, + std::vector>& ptx_kernels) { + // Create the program + nvrtcProgram prog; + CHECK_NVRTC_ERROR(nvrtcCreateProgram( + &prog, + source.c_str(), + (module_name + ".cu").c_str(), + std::size(g_headers), + g_headers, + g_include_names)); + std::unique_ptr prog_freer( + &prog, + [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); + for (const auto& name : kernel_names) { + CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); } + // Compile program. + std::vector args; + bool use_sass = compiler_supports_device_sass(device); + std::string compute = fmt::format( + "--gpu-architecture={}_{}{}", + use_sass ? "sm" : "compute", + device.compute_capability_major(), + device.compute_capability_minor()); + args.push_back(compute.c_str()); + std::string cccl_include = cccl_dir(); + if (!cccl_include.empty()) { + cccl_include = fmt::format("--include-path={}", cccl_include); + args.push_back(cccl_include.c_str()); + } + std::string cuda_include = + fmt::format("--include-path={}/include", cuda_home()); + args.push_back(cuda_include.c_str()); + nvrtcResult compile_result = + nvrtcCompileProgram(prog, args.size(), args.data()); + if (compile_result != NVRTC_SUCCESS) { + size_t log_size; + CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size)); + std::vector log(log_size + 1, 0); + CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data())); + throw std::runtime_error( + fmt::format("Failed to compile kernel: {}.", log.data())); + } + + // Get mangled names of kernel names. + for (const auto& name : kernel_names) { + const char* mangled; + CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled)); + ptx_kernels.emplace_back(name, mangled); + } + + // Get ptx data. 
+  size_t ptx_size;
+  if (use_sass) {
+    CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
+  } else {
+    CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
+  }
+  ptx.resize(ptx_size, 0);
+  if (use_sass) {
+    CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
+  } else {
+    CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
+  }
+}
+
+void load_module(
+    const std::string& module_name,
+    const std::vector<char>& ptx,
+    const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
+    CUmodule& module_,
+    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -312,33 +312,44 @@ JitModule::JitModule(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels_[name] = std::make_pair(kernel, false);
+    kernels[name] = std::make_pair(kernel, false);
   }
 }

+} // namespace
+
 JitModule::JitModule(
     Device& device,
     const std::string& module_name,
-    const std::string& ptx,
-    const std::vector<std::string>& kernel_names) {
-  // Load module.
-  char jit_log[4089] = {};
-  CUjit_option options[] = {
-      CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
-  void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
-  CUresult jit_result = cuModuleLoadDataEx(
-      &module_, ptx.data(), std::size(options), options, values);
-  if (jit_result != CUDA_SUCCESS) {
-    throw std::runtime_error(fmt::format(
-        "Failed to load compiled {} kernel: {}.", module_name, jit_log));
+    const KernelBuilder& builder,
+    bool cache) {
+  // Will hold the actual device executable source code and kernel names
+  std::vector<char> ptx;
+  std::vector<std::pair<std::string, std::string>> ptx_kernels;
+
+  // Try to load them from the file cache
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
+    auto [precompiled, source_code, kernel_names] = builder();
+
+    // Get the PTX or cubin
+    if (precompiled) {
+      ptx.insert(ptx.begin(), source_code.begin(), source_code.end());
+      for (auto& name : kernel_names) {
+        ptx_kernels.emplace_back(name, name);
+      }
+    } else {
+      compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
+    }
+
+    // If requested save them in the file cache for the next launch
+    if (cache) {
+      write_cached_ptx(
+          ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
+    }
   }
-  // Load kernels.
-  for (const auto& name : kernel_names) {
-    CUfunction kernel;
-    CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
-    kernels_[name] = std::make_pair(kernel, false);
-  }
+  // Load the module
+  load_module(module_name, ptx, ptx_kernels, module_, kernels_);
 }

 JitModule::~JitModule() {
@@ -372,25 +383,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
-    const KernelBuilder& builder) {
+    const KernelBuilder& builder,
+    bool cache) {
   auto& map = get_jit_module_cache();
   auto it = map.find(name);
   if (it == map.end()) {
-    it = map.try_emplace(name, cu::device(device), name, builder).first;
-  }
-  return it->second;
-}
-
-JitModule& get_jit_module(
-    const mlx::core::Device& device,
-    const std::string& name,
-    const std::string& ptx,
-    const std::vector<std::string>& kernel_names) {
-  auto& map = get_jit_module_cache();
-  auto it = map.find(name);
-  if (it == map.end()) {
-    it = map.try_emplace(name, cu::device(device), name, ptx, kernel_names)
-            .first;
+    it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
   }
   return it->second;
 }
diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h
index 1933bfccb..a4a75bb96 100644
--- a/mlx/backend/cuda/jit_module.h
+++ b/mlx/backend/cuda/jit_module.h
@@ -19,7 +19,8 @@ namespace mlx::core::cu {
 class Device;

-using KernelBuilderResult = std::pair<
+using KernelBuilderResult = std::tuple<
+    /* precompiled */ bool,
     /* source code */ std::string,
     /* kernel names */ std::vector<std::string>>;
 using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -84,12 +85,8 @@ class JitModule {
   JitModule(
       Device& device,
       const std::string& module_name,
-      const KernelBuilder& builder);
-  JitModule(
-      Device& device,
-      const std::string& module_name,
-      const std::string& ptx,
-      const std::vector<std::string>& kernel_names);
+      const KernelBuilder& builder,
+      bool cache);
   ~JitModule();

   JitModule(const JitModule&) = delete;
@@ -111,12 +108,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
 JitModule& get_jit_module(
     const mlx::core::Device& device,
     const std::string& name,
-    const KernelBuilder& builder);
-
-JitModule& get_jit_module(
-    const mlx::core::Device& device,
-    const std::string& name,
-    const std::string& ptx,
-    const std::vector<std::string>& kernel_names);
+    const KernelBuilder& builder,
+    bool cache = true);

 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/no_cuda.cpp b/mlx/backend/cuda/no_cuda.cpp
index 8a394c9e3..175a505b4 100644
--- a/mlx/backend/cuda/no_cuda.cpp
+++ b/mlx/backend/cuda/no_cuda.cpp
@@ -1,11 +1,47 @@
 // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h" +#include "mlx/fast.h" -namespace mlx::core::cu { +namespace mlx::core { + +namespace cu { bool is_available() { return false; } -} // namespace mlx::core::cu +} // namespace cu + +namespace fast { + +CustomKernelFunction cuda_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool, + int) { + throw std::runtime_error("[cuda_kernel] No CUDA back-end."); +} + +std::vector precompiled_cuda_kernel( + const std::string&, + const std::string&, + const std::vector&, + const std::vector&, + const std::vector&, + const std::vector&, + std::tuple, + std::tuple, + int shared_memory, + std::optional init_value, + bool ensure_row_contiguous, + StreamOrDevice) { + throw std::runtime_error("[cuda_kernel] No CUDA back-end."); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp index 85e1080d5..9d6a73461 100644 --- a/mlx/backend/metal/no_metal.cpp +++ b/mlx/backend/metal/no_metal.cpp @@ -24,4 +24,19 @@ device_info() { } // namespace metal +namespace fast { + +CustomKernelFunction metal_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool, + bool) { + throw std::runtime_error("[metal_kernel] No Metal back-end."); +} + +} // namespace fast + } // namespace mlx::core diff --git a/mlx/fast.h b/mlx/fast.h index 05a185ecc..d154e4753 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -69,7 +69,7 @@ array affine_dequantize( using TemplateArg = std::variant; using ScalarArg = std::variant; -using MetalKernelFunction = std::function( +using CustomKernelFunction = std::function( const std::vector&, const std::vector&, const std::vector&, @@ -80,7 +80,7 @@ using MetalKernelFunction = std::function( bool, StreamOrDevice)>; -MetalKernelFunction metal_kernel( +CustomKernelFunction metal_kernel( const std::string& name, const std::vector& input_names, const std::vector& output_names, @@ -89,7 +89,16 @@ MetalKernelFunction metal_kernel( bool ensure_row_contiguous = true, bool atomic_outputs = false); -std::vector precompiled_custom_kernel( +CustomKernelFunction cuda_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + int shared_memory = 0); + +std::vector precompiled_cuda_kernel( const std::string& name, const std::string& compiled_source, const std::vector& inputs, diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 7105d12cb..02e924a94 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -17,6 +17,66 @@ namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; +namespace { + +struct PyCustomKernelFunction { + PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag) + : kernel_(std::move(kernel)), tag_(tag) {} + + std::vector operator()( + const std::vector& inputs_, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::optional>>& + template_args_ = std::nullopt, + std::optional init_value = std::nullopt, + bool verbose = false, + mx::StreamOrDevice s = {}) const { + std::vector inputs; + for (const auto& value : inputs_) { + inputs.push_back(to_array(value, std::nullopt)); + } + std::vector> template_args; + if (template_args_) { + for (const auto& [name, value] : template_args_.value()) 
+        // Handle bool, int and dtype template args
+        if (nb::isinstance<nb::bool_>(value)) {
+          bool bool_val = nb::cast<bool>(value);
+          template_args.emplace_back(name, bool_val);
+        } else if (nb::isinstance<nb::int_>(value)) {
+          int int_val = nb::cast<int>(value);
+          template_args.emplace_back(name, int_val);
+        } else if (nb::isinstance<mx::Dtype>(value)) {
+          mx::Dtype dtype = nb::cast<mx::Dtype>(value);
+          template_args.emplace_back(name, dtype);
+        } else {
+          std::ostringstream msg;
+          msg << tag_
+              << " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.";
+          throw std::invalid_argument(msg.str());
+        }
+      }
+    }
+    return kernel_(
+        inputs,
+        output_shapes,
+        output_dtypes,
+        grid,
+        threadgroup,
+        template_args,
+        init_value,
+        verbose,
+        s);
+  }
+
+  mx::fast::CustomKernelFunction kernel_;
+  const char* tag_;
+};
+
+} // namespace
+
 void init_fast(nb::module_& parent_module) {
   auto m =
       parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@@ -240,53 +300,7 @@ void init_fast(nb::module_& parent_module) {
             ensure_row_contiguous,
             atomic_outputs);
         return nb::cpp_function(
-            [kernel = std::move(kernel)](
-                const std::vector<ScalarOrArray>& inputs_,
-                const std::vector<mx::Shape>& output_shapes,
-                const std::vector<mx::Dtype>& output_dtypes,
-                std::tuple<int, int, int> grid,
-                std::tuple<int, int, int> threadgroup,
-                const std::optional<
-                    std::vector<std::pair<std::string, nb::object>>>&
-                    template_args_ = std::nullopt,
-                std::optional<float> init_value = std::nullopt,
-                bool verbose = false,
-                mx::StreamOrDevice s = {}) {
-              std::vector<mx::array> inputs;
-              for (const auto& value : inputs_) {
-                inputs.push_back(to_array(value, std::nullopt));
-              }
-              std::vector<std::pair<std::string, mx::fast::TemplateArg>>
-                  template_args;
-              if (template_args_) {
-                for (const auto& [name, value] : template_args_.value()) {
-                  // Handle bool, int and dtype template args
-                  if (nb::isinstance<nb::bool_>(value)) {
-                    bool bool_val = nb::cast<bool>(value);
-                    template_args.emplace_back(name, bool_val);
-                  } else if (nb::isinstance<nb::int_>(value)) {
-                    int int_val = nb::cast<int>(value);
-                    template_args.emplace_back(name, int_val);
-                  } else if (nb::isinstance<mx::Dtype>(value)) {
-                    mx::Dtype dtype = nb::cast<mx::Dtype>(value);
-                    template_args.emplace_back(name, dtype);
-                  } else {
-                    throw std::invalid_argument(
-                        "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
-                  }
-                }
-              }
-              return kernel(
-                  inputs,
-                  output_shapes,
-                  output_dtypes,
-                  grid,
-                  threadgroup,
-                  template_args,
-                  init_value,
-                  verbose,
-                  s);
-            },
+            PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
             nb::kw_only(),
             "inputs"_a,
             "output_shapes"_a,
@@ -386,7 +400,123 @@ void init_fast(nb::module_& parent_module) {
       )pbdoc");

   m.def(
-      "precompiled_custom_kernel",
+      "cuda_kernel",
+      [](const std::string& name,
+         const std::vector<std::string>& input_names,
+         const std::vector<std::string>& output_names,
+         const std::string& source,
+         const std::string& header,
+         bool ensure_row_contiguous,
+         int shared_mem) {
+        auto kernel = mx::fast::cuda_kernel(
+            name,
+            input_names,
+            output_names,
+            source,
+            header,
+            ensure_row_contiguous,
+            shared_mem);
+        return nb::cpp_function(
+            PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"),
+            nb::kw_only(),
+            "inputs"_a,
+            "output_shapes"_a,
+            "output_dtypes"_a,
+            "grid"_a,
+            "threadgroup"_a,
+            "template"_a = nb::none(),
+            "init_value"_a = nb::none(),
+            "verbose"_a = false,
+            "stream"_a = nb::none(),
+            nb::sig(
+                "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = False, stream: Union[None, Stream, Device] = None)"),
+            R"pbdoc(
+            Run the kernel.
+
+            Args:
+              inputs (List[array]): The inputs passed to the CUDA kernel.
+              output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
+              output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
+              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+                For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
+              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+              template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
+                These will be added as template arguments to the kernel definition. Default: ``None``.
+              init_value (float, optional): Optional value to use to initialize all of the output arrays.
+                By default, output arrays are uninitialized. Default: ``None``.
+              verbose (bool, optional): Whether to print the full generated source code of the kernel
+                when it is run. Default: ``False``.
+              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
+
+            Returns:
+              List[array]: The list of output arrays.)pbdoc");
+      },
+      "name"_a,
+      "input_names"_a,
+      "output_names"_a,
+      "source"_a,
+      "header"_a = "",
+      "ensure_row_contiguous"_a = true,
+      "shared_memory"_a = 0,
+      R"pbdoc(
+        A jit-compiled custom CUDA kernel defined from a source string.
+
+        This is the CUDA equivalent of :ref:`custom_metal_kernels`.
+
+        Args:
+          name (str): Name for the kernel.
+          input_names (List[str]): The parameter names of the inputs in the
+            function signature.
+          output_names (List[str]): The parameter names of the outputs in the
+            function signature.
+          source (str): Source code. This is the body of a function in CUDA,
+            the function signature will be automatically generated.
+          header (str): Header source code to include before the main function.
+            Useful for helper functions or includes that should live outside of
+            the main function body.
+          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
+            before the kernel runs. Default: ``True``.
+          shared_memory (int): The dynamic shared memory to request for the
+            kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
+
+        Returns:
+          Callable ``cuda_kernel``.
+
+        Example:
+
+          .. code-block:: python
+
+            def exp_elementwise(a: mx.array):
+                source = '''
+                    auto elem = cooperative_groups::this_grid().thread_rank();
+                    T tmp = inp[elem];
+                    out[elem] = exp(tmp);
+                '''
+
+                kernel = mx.fast.cuda_kernel(
+                    name="myexp",
+                    input_names=["inp"],
+                    output_names=["out"],
+                    source=source
+                )
+                outputs = kernel(
+                    inputs=[a],
+                    template=[("T", mx.float32)],
+                    grid=(a.size, 1, 1),
+                    threadgroup=(256, 1, 1),
+                    output_shapes=[a.shape],
+                    output_dtypes=[a.dtype],
+                    verbose=True,
+                )
+                return outputs[0]
+
+            a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
+            b = exp_elementwise(a)
+            assert mx.allclose(b, mx.exp(a))
+      )pbdoc");
+
+  m.def(
+      "precompiled_cuda_kernel",
       [](const std::string& name,
          const nb::bytes compiled_source,
          const std::vector<ScalarOrArray>& inputs_,
@@ -420,14 +550,14 @@ void init_fast(nb::module_& parent_module) {
               std::string vtype_name =
                   nb::cast<std::string>(vtype.attr("__name__"));
               std::ostringstream msg;
-              msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
+              msg << "[precompiled_cuda_kernel] Invalid scalar argument type. "
                   << "Received " << vtype_name
                   << " but must be one of bool, int or float";
               throw std::invalid_argument(msg.str());
             }
           }
-        return mx::fast::precompiled_custom_kernel(
+        return mx::fast::precompiled_cuda_kernel(
             name,
             std::string(
                 static_cast<const char*>(compiled_source.data()),
@@ -457,5 +587,27 @@ void init_fast(nb::module_& parent_module) {
       "ensure_row_contiguous"_a = false,
       "stream"_a = nb::none(),
       R"pbdoc(
+        Run a precompiled CUDA kernel defined from PTX or cubin.
+
+        This op is still experimental and various parts of the API may change.
+
+        Args:
+          name (str): Name for the kernel.
+          compiled_source (bytes): The precompiled kernel in raw bytes.
+          inputs (List[array]): The inputs passed to the CUDA kernel.
+          output_shapes (List[Sequence[int]]): The list of shapes for each output.
+          output_dtypes (List[Dtype]): The list of data types for each output.
+          scalars (List[Union[bool, int, float]]): A list of scalar arguments to
+            pass to the kernel.
+          grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+            For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
+          threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+          shared_memory (int): The dynamic shared memory to request for the
+            kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
+          init_value (float, optional): Optional value to use to initialize all of the output arrays.
+            By default, output arrays are uninitialized. Default: ``None``.
+          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
+            before the kernel runs. Default: ``False``.
+          stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
       )pbdoc");
 }
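The `precompiled_cuda_kernel` docstring above has no usage example, so here is a minimal sketch of how the new binding is intended to be called. It is illustrative only: the PTX file, kernel name, and scalar value are hypothetical, and the keyword arguments simply mirror the docstring.

```python
import mlx.core as mx

# Hypothetical PTX compiled offline, e.g. `nvcc -ptx scale.cu -o scale.ptx`,
# whose module defines a kernel named "scale_kernel" taking
# (const float* inp, float* out, float scale).
with open("scale.ptx", "rb") as f:
    ptx_bytes = f.read()

a = mx.random.normal(shape=(4096,))
(out,) = mx.fast.precompiled_cuda_kernel(
    name="scale_kernel",        # must match the kernel symbol in the PTX/cubin
    compiled_source=ptx_bytes,
    inputs=[a],
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    scalars=[2.0],
    grid=(a.size, 1, 1),        # the grid is given in threads, as for metal_kernel
    threadgroup=(256, 1, 1),
    shared_memory=0,
    ensure_row_contiguous=True,
)
```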