diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 994307284..2e12c8c3e 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -20,6 +20,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp
new file mode 100644
index 000000000..3aeebbe9a
--- /dev/null
+++ b/mlx/backend/cuda/custom_kernel.cpp
@@ -0,0 +1,320 @@
+// Copyright © 2025 Apple Inc.
+
+#include <iostream>
+#include <sstream>
+
+#include "mlx/backend/common/compiled.h"
+#include "mlx/backend/cuda/jit_module.h"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/fast.h"
+#include "mlx/fast_primitives.h"
+
+#include <fmt/format.h>
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core::fast {
+
+namespace {
+
+constexpr const char* default_header = R"(
+#include "mlx/backend/cuda/device/utils.cuh"
+
+#include <cooperative_groups.h>
+
+#define inf cuda::std::numeric_limits<float>::infinity()
+
+)";
+
+std::string template_arguments_hash(
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
+  if (template_args.empty()) {
+    return "";
+  }
+
+  std::string hash;
+  hash.reserve(512);
+
+  for (const auto& [name, arg] : template_args) {
+    if (std::holds_alternative<int>(arg)) {
+      hash += fmt::format("_{}", std::get<int>(arg));
+    } else if (std::holds_alternative<bool>(arg)) {
+      hash += (std::get<bool>(arg)) ? "_t" : "_f";
+    } else if (std::holds_alternative<Dtype>(arg)) {
+      hash += "_";
+      hash += get_type_string(std::get<Dtype>(arg));
+    }
+  }
+
+  return hash;
+}
+
+std::string build_kernel(
+    std::string func_name,
+    const std::string& header,
+    const std::string& source,
+    const std::vector<std::string>& input_names,
+    const std::vector<array>& inputs,
+    const std::vector<std::string>& output_names,
+    const std::vector<Dtype>& output_dtypes,
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
+    const std::vector<CustomKernelShapeInfo>& shape_infos) {
+  std::string kernel_source;
+  kernel_source.reserve(header.size() + source.size() + 8192);
+  kernel_source += default_header;
+  kernel_source += header;
+  kernel_source +=
+      "namespace mlx::core::cu {\n\n"
+      "namespace cg = cooperative_groups;\n\n";
+
+  kernel_source += "__global__ void ";
+  kernel_source += func_name;
+  kernel_source += "(\n";
+
+  // Add inputs
+  for (int i = 0; i < inputs.size(); ++i) {
+    const auto& name = input_names[i];
+    const auto& arr = inputs[i];
+    kernel_source += "    const ";
+    kernel_source += dtype_to_cuda_type(arr.dtype());
+    kernel_source += "* ";
+    kernel_source += name;
+    kernel_source += ",\n";
+    // Add input shape, strides and ndim if present in the source
+    if (arr.ndim() > 0) {
+      if (shape_infos[i].shape) {
+        kernel_source += "    const __grid_constant__ Shape ";
+        kernel_source += name;
+        kernel_source += "_shape,\n";
+      }
+      if (shape_infos[i].strides) {
+        kernel_source += "    const __grid_constant__ Strides ";
+        kernel_source += name;
+        kernel_source += "_strides,\n";
+      }
+      if (shape_infos[i].ndim) {
+        kernel_source += "    const __grid_constant__ int ";
+        kernel_source += name;
+        kernel_source += "_ndim,\n";
+      }
+    }
+  }
+
+  // Add outputs
+  for (int i = 0; i < output_names.size(); ++i) {
+    const auto& name = output_names[i];
+    const auto& dtype = output_dtypes[i];
+    kernel_source += "    ";
+    kernel_source += dtype_to_cuda_type(dtype);
+    kernel_source += "* ";
+    kernel_source += name;
+    if (i < output_names.size() - 1) {
+      kernel_source += ",\n";
+    } else {
+      kernel_source += ") {\n";
+    }
+  }
+
+  // Set compile time constants
+  if (!template_args.empty()) {
+    for (const auto& [name, arg] : template_args) {
+      if (std::holds_alternative<int>(arg)) {
+        kernel_source +=
+            fmt::format("  constexpr int {} = {};\n", name, std::get<int>(arg));
+      } else if (std::holds_alternative<bool>(arg)) {
+        kernel_source += fmt::format(
+            "  constexpr bool {} = {};\n", name, std::get<bool>(arg));
+      } else {
+        kernel_source += fmt::format(
+            "  using {} = {};\n",
+            name,
+            dtype_to_cuda_type(std::get<Dtype>(arg)));
+      }
+    }
+    kernel_source += "\n";
+  }
+
+  kernel_source += source;
+  kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
+
+  return kernel_source;
+}
+
+} // namespace
+
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header /* = "" */,
+    bool ensure_row_contiguous /* = true */,
+    bool atomic_outputs /* = false */) {
+  if (output_names.empty()) {
+    throw std::invalid_argument(
+        "[metal_kernel] Must specify at least one output.");
+  }
+
+  std::vector<CustomKernelShapeInfo> shape_infos;
+  for (auto& n : input_names) {
+    CustomKernelShapeInfo shape_info;
+    shape_info.shape = source.find(n + "_shape") != std::string::npos;
+    shape_info.strides = source.find(n + "_strides") != std::string::npos;
+    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
+    shape_infos.push_back(shape_info);
+  }
+
+  return [=, shape_infos = std::move(shape_infos)](
+             const std::vector<array>& inputs,
+             const std::vector<Shape>& output_shapes,
+             const std::vector<Dtype>& output_dtypes,
+             std::tuple<int, int, int> grid,
+             std::tuple<int, int, int> threadgroup,
+             const std::vector<std::pair<std::string, TemplateArg>>&
+                 template_args = {},
+             std::optional<float> init_value = std::nullopt,
+             bool verbose = false,
+             StreamOrDevice s_ = {}) {
+    if (inputs.size() != input_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `inputs` to have size "
+          << input_names.size() << " but got size " << inputs.size() << "."
+          << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_shapes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_shapes` to have size "
+          << output_names.size() << " but got size " << output_shapes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_dtypes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_dtypes` to have size "
+          << output_names.size() << " but got size " << output_dtypes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+
+    auto s = to_stream(s_);
+    if (s.device != Device::gpu) {
+      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
+    }
+
+    std::string kernel_name =
+        "custom_kernel_" + name + template_arguments_hash(template_args);
+    std::string kernel_source = build_kernel(
+        kernel_name,
+        header,
+        source,
+        input_names,
+        inputs,
+        output_names,
+        output_dtypes,
+        template_args,
+        shape_infos);
+
+    if (verbose) {
+      std::cout << "Generated source code for `" << kernel_name
+                << "`:" << std::endl
+                << "```" << std::endl
+                << kernel_source << std::endl
+                << "```" << std::endl;
+    }
+
+    return array::make_arrays(
+        std::move(output_shapes),
+        std::move(output_dtypes),
+        std::make_shared<CustomKernel>(
+            s,
+            std::move(kernel_name),
+            std::move(kernel_source),
+            grid,
+            threadgroup,
+            shape_infos,
+            ensure_row_contiguous,
+            init_value),
+        std::move(inputs));
+  };
+}
+
+void CustomKernel::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("CustomKernel::eval_gpu");
+  auto& s = stream();
+
+  std::vector<array> copies;
+
+  // Allocate and initialize the output arrays
+  for (auto& out : outputs) {
+    if (init_value_) {
+      copies.emplace_back(init_value_.value(), out.dtype());
+      fill_gpu(copies.back(), out, s);
+    } else {
+      out.set_data(allocator::malloc(out.nbytes()));
+    }
+  }
+
+  // Create the input arrays and copy if needed
+  auto check_input = [&copies, &s, this](const array& x) -> const array {
+    bool no_copy = x.flags().row_contiguous;
+    if (!ensure_row_contiguous_ || no_copy) {
+      return x;
+    } else {
+      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+      copy_gpu(x, copies.back(), CopyType::General, s);
+      return copies.back();
+    }
+  };
+  std::vector<array> checked_inputs;
+  for (const array& in : inputs) {
+    checked_inputs.push_back(check_input(in));
+  }
+
+  // Compile the custom kernel
+  std::string kernel_name = "mlx::core::cu::" + name_;
+  cu::JitModule& mod = cu::get_jit_module(s.device, name_, [&]() {
+    return std::make_pair(source_, std::vector<std::string>{kernel_name});
+  });
+
+  // Make the arguments
+  cu::KernelArgs args;
+  for (int i = 0; i < checked_inputs.size(); i++) {
+    const array& in = checked_inputs[i];
+    auto& shape_info = shape_infos_[i];
+    args.append(in);
+    if (shape_info.shape) {
+      args.append_ndim(in.shape());
+    }
+    if (shape_info.strides) {
+      args.append_ndim(in.strides());
+    }
+    if (shape_info.ndim) {
+      args.append(in.ndim());
+    }
+  }
+  for (auto& out : outputs) {
+    args.append(out);
+  }
+
+  // Make the grid
+  const auto [tx, ty, tz] = threadgroup_;
+  const auto [gx, gy, gz] = grid_;
+  dim3 block(tx, ty, tz);
+  dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
+
+  // Call the kernel
+  auto& encoder = cu::get_command_encoder(s);
+  for (const auto& in : checked_inputs) {
+    encoder.set_input_array(in);
+  }
+  for (const auto& out : outputs) {
+    encoder.set_output_array(out);
+  }
+  auto kernel = mod.get_kernel(kernel_name);
+  encoder.add_kernel_node(kernel, grid, block, 0, args.args());
+}
+
+} // namespace mlx::core::fast
diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp
index f5a61366c..aa20f0128 100644
--- a/mlx/backend/cuda/primitives.cpp
+++ b/mlx/backend/cuda/primitives.cpp
@@ -41,10 +41,6 @@ NO_GPU(Cholesky)
 NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
-namespace fast {
-NO_GPU_MULTI(CustomKernel)
-} // namespace fast
-
 namespace distributed {
 NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
diff --git a/mlx/backend/metal/no_metal.cpp b/mlx/backend/metal/no_metal.cpp
index 9785e07c2..85e1080d5 100644
--- a/mlx/backend/metal/no_metal.cpp
+++ b/mlx/backend/metal/no_metal.cpp
@@ -24,19 +24,4 @@ device_info() {
 }
 
 } // namespace metal
-namespace fast {
-
-MetalKernelFunction metal_kernel(
-    const std::string&,
-    const std::vector<std::string>&,
-    const std::vector<std::string>&,
-    const std::string&,
-    const std::string&,
-    bool ensure_row_contiguous,
-    bool atomic_outputs) {
-  throw std::runtime_error("[metal_kernel] No GPU back-end.");
-}
-
-} // namespace fast
-
 } // namespace mlx::core
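
Usage sketch (not part of the patch): the CUDA backend reuses the existing `fast::metal_kernel` C++ entry point, so caller code is unchanged. The snippet below is a minimal illustration under stated assumptions: the kernel name `my_exp`, the parameter names `inp`/`out`, and the kernel body are invented for the example, and the nine-argument call matches the `MetalKernelFunction` lambda defined in custom_kernel.cpp above, whose body is pasted verbatim into the generated __global__ function by build_kernel.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Plain CUDA indexing works because the body lands inside a __global__
  // function; cooperative_groups is also available via the default header.
  std::string source = R"(
    auto elem = blockIdx.x * blockDim.x + threadIdx.x;
    out[elem] = exp(inp[elem]);
  )";

  auto kernel = fast::metal_kernel("my_exp", {"inp"}, {"out"}, source);

  auto x = random::normal({1024});
  auto outs = kernel(
      {x}, // inputs
      {x.shape()}, // output_shapes
      {x.dtype()}, // output_dtypes
      {1024, 1, 1}, // grid, in threads; eval_gpu divides by the block size
      {256, 1, 1}, // threadgroup, mapped to the CUDA block
      {}, // template_args
      std::nullopt, // init_value
      false, // verbose
      Device::gpu);
  eval(outs);
}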