Custom cuda kernel (#2517)

Angelos Katharopoulos 2025-08-20 17:20:22 -07:00 committed by GitHub
parent f4c8888cbe
commit e397177f6e
19 changed files with 1042 additions and 211 deletions

docs/src/python/cuda.rst (new file)

@@ -0,0 +1,9 @@
CUDA
=====

.. currentmodule:: mlx.core.cuda

.. autosummary::
   :toctree: _autosummary

   is_available


@@ -13,3 +13,4 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
+cuda_kernel


@@ -20,6 +20,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu


@@ -267,7 +267,8 @@ void Compiled::eval_gpu(
}
}
-return std::make_pair(std::move(builder.os), std::move(kernel_names));
+return std::make_tuple(
+false, std::move(builder.os), std::move(kernel_names));
});
// Collapse contiguous dims to route to a faster kernel if possible. Also


@@ -0,0 +1,379 @@
// Copyright © 2025 Apple Inc.
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::fast {
namespace {
constexpr const char* default_header = R"(
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
std::string template_arguments_hash(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
if (template_args.empty()) {
return "";
}
std::string hash;
hash.reserve(512);
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
hash += fmt::format("_{}", std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
hash += (std::get<bool>(arg)) ? "_t" : "_f";
} else if (std::holds_alternative<Dtype>(arg)) {
hash += "_";
hash += get_type_string(std::get<Dtype>(arg));
}
}
return hash;
}
std::string build_kernel(
const std::string& func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
kernel_source += header;
kernel_source +=
"namespace mlx::core::cu {\n\n"
"namespace cg = cooperative_groups;\n\n";
kernel_source += "__global__ void ";
kernel_source += func_name;
kernel_source += "(\n";
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
kernel_source += " const ";
kernel_source += dtype_to_cuda_type(arr.dtype());
kernel_source += "* ";
kernel_source += name;
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " ";
kernel_source += dtype_to_cuda_type(dtype);
kernel_source += "* ";
kernel_source += name;
if (i < output_names.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
}
// Set compile time constants
if (!template_args.empty()) {
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
kernel_source +=
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
kernel_source += fmt::format(
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
} else {
kernel_source += fmt::format(
" using {} = {};\n",
name,
dtype_to_cuda_type(std::get<Dtype>(arg)));
}
}
kernel_source += "\n";
}
kernel_source += source;
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
return kernel_source;
}
} // namespace
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_memory) {
if (output_names.empty()) {
throw std::invalid_argument(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
return [=, shape_infos = std::move(shape_infos)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
}
std::string kernel_name =
"custom_kernel_" + name + template_arguments_hash(template_args);
std::string kernel_source = build_kernel(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
shape_infos);
if (verbose) {
std::cout << "Generated source code for `" << kernel_name
<< "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
std::vector<ScalarArg>{},
false,
shared_memory),
std::move(inputs));
};
}
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
to_stream(s),
name,
compiled_source,
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
scalars,
true,
shared_memory),
inputs);
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
std::vector<array> copies;
// Allocate and initialize the output arrays
for (auto& out : outputs) {
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
// Create the input arrays and copy if needed
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
// Compile the custom kernel
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(
s.device,
name_,
[&]() {
return std::make_tuple(
is_precompiled_, source_, std::vector{kernel_name});
},
false);
// Make the arguments
cu::KernelArgs args;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
args.append<int32_t>(in.ndim());
}
}
for (auto& out : outputs) {
args.append(out);
}
for (auto& s : scalar_arguments_) {
if (std::holds_alternative<bool>(s)) {
args.append(std::get<bool>(s));
} else if (std::holds_alternative<int>(s)) {
args.append(std::get<int>(s));
} else if (std::holds_alternative<float>(s)) {
args.append(std::get<float>(s));
}
}
// Make the grid
const auto [tx, ty, tz] = threadgroup_;
const auto [gx, gy, gz] = grid_;
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel =
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute(
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
}
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}
} // namespace mlx::core::fast
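A note on the launch configuration computed in CustomKernel::eval_gpu above: the user-facing grid is specified in threads (for parity with metal_kernel), the block is the threadgroup clamped to the grid, and the number of blocks is the grid rounded up by the block size. A small illustrative Python sketch of that arithmetic (the helper name is mine, not part of the API):

def cuda_launch_dims(grid, threadgroup):
    # grid is given in threads (as with metal_kernel); threadgroup is the block size
    block = tuple(min(t, g) for t, g in zip(threadgroup, grid))
    blocks = tuple((g + t - 1) // t for g, t in zip(grid, threadgroup))
    return blocks, block

# 1000 elements with 256-thread blocks -> 4 blocks of 256 threads each
assert cuda_launch_dims((1000, 1, 1), (256, 1, 1)) == ((4, 1, 1), (256, 1, 1))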


@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t"));
}
}
-return std::make_pair(jit_source_gather, std::move(kernel_names));
+return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
});
cu::KernelArgs args;
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t"));
}
}
-return std::make_pair(jit_source_scatter, std::move(kernel_names));
+return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
});
cu::KernelArgs args;
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
}
-return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
+return std::make_tuple(
+false, jit_source_gather_axis, std::move(kernel_names));
});
size_t idx_size_pre = 1;
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
}
-return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
+return std::make_tuple(
+false, jit_source_scatter_axis, std::move(kernel_names));
});
size_t idx_size_pre = 1;


@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
-std::vector<char>* ptx,
-std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+std::string& ptx,
+std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
@@ -117,15 +117,15 @@ bool read_cached_ptx(
if (!ptx_file.good()) {
return false;
}
-ptx->resize(ptx_size);
-ptx_file.read(ptx->data(), ptx_size);
+ptx.resize(ptx_size);
+ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
if (tab != std::string::npos) {
-ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
+ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
}
}
return true;
@@ -135,7 +135,7 @@ bool read_cached_ptx(
void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
-const std::vector<char>& ptx,
+const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
@@ -217,22 +217,18 @@ constexpr const char* g_headers[] = {
jit_source_utils,
};
-} // namespace
-JitModule::JitModule(
+void compile(
Device& device,
const std::string& module_name,
-const KernelBuilder& builder) {
-// Check cache.
-std::vector<char> ptx;
-std::vector<std::pair<std::string, std::string>> ptx_kernels;
-if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
-// Create program.
-auto [source_code, kernel_names] = builder();
+const std::string& source,
+const std::vector<std::string>& kernel_names,
+std::string& ptx,
+std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+// Create the program
nvrtcProgram prog;
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
&prog,
-source_code.c_str(),
+source.c_str(),
(module_name + ".cu").c_str(),
std::size(g_headers),
g_headers,
@@ -286,16 +282,20 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
-ptx.resize(ptx_size, 0);
+ptx.resize(ptx_size);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
-write_cached_ptx(
-ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
void load_module(
const std::string& module_name,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -312,21 +312,69 @@ JitModule::JitModule(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-kernels_[name] = kernel;
+kernels[name] = std::make_pair(kernel, false);
}
}
} // namespace
JitModule::JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder,
bool use_disk_cache) {
// Will hold the actual device executable source code and kernel names
std::string ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
// Try to load them from the file cache
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
auto [precompiled, source_code, kernel_names] = builder();
// Get the PTX or cubin
if (precompiled) {
ptx = std::move(source_code);
for (auto& name : kernel_names) {
ptx_kernels.emplace_back(name, name);
}
} else {
compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
}
// If requested save them in the file cache for the next launch
if (use_disk_cache) {
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
}
// Load the module
load_module(module_name, ptx, ptx_kernels, module_, kernels_);
}
JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
-CUfunction JitModule::get_kernel(const std::string& kernel_name) {
+CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) {
throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name));
}
-return it->second;
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
if (configure_kernel) {
configure_kernel(it->second.first);
}
it->second.second = true;
}
return it->second.first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
@@ -337,11 +385,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
-const KernelBuilder& builder) {
+const KernelBuilder& builder,
+bool cache) {
auto& map = get_jit_module_cache();
auto it = map.find(name);
if (it == map.end()) {
-it = map.try_emplace(name, cu::device(device), name, builder).first;
+it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
}
return it->second;
}


@@ -19,7 +19,8 @@ namespace mlx::core::cu {
class Device;
-using KernelBuilderResult = std::pair<
+using KernelBuilderResult = std::tuple<
+/* precompiled */ bool,
/* source code */ std::string,
/* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -63,14 +64,16 @@ struct KernelArgs {
private:
std::vector<void*> args_;
-// The cuLaunchKernel API requires passing pointers to arguments so store
-// temporary values untill kernel is launched.
+// The cuGraphAddKernelNode API requires passing pointers to arguments so
+// store temporary values until the node is created.
using Arg = std::variant<
std::monostate,
CUdeviceptr,
+bool,
int32_t,
uint32_t,
int64_t,
+float,
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
@@ -82,16 +85,19 @@ class JitModule {
JitModule(
Device& device,
const std::string& module_name,
-const KernelBuilder& builder);
+const KernelBuilder& builder,
+bool cache);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
-CUfunction get_kernel(const std::string& kernel_name);
+CUfunction get_kernel(
+const std::string& kernel_name,
+std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
-std::unordered_map<std::string, CUfunction> kernels_;
+std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
@@ -99,6 +105,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
-const KernelBuilder& builder);
+const KernelBuilder& builder,
+bool use_disk_cache = true);
} // namespace mlx::core::cu


@@ -1,11 +1,47 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
+#include "mlx/fast.h"
-namespace mlx::core::cu {
+namespace mlx::core {
+namespace cu {
bool is_available() {
return false;
}
-} // namespace mlx::core::cu
+} // namespace cu
namespace fast {
CustomKernelFunction cuda_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool,
int) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
std::vector<array> precompiled_cuda_kernel(
const std::string&,
const std::string&,
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
const std::vector<ScalarArg>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
} // namespace fast
} // namespace mlx::core


@@ -41,10 +41,6 @@ NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
-namespace fast {
-NO_GPU_MULTI(CustomKernel)
-} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)


@@ -172,7 +172,7 @@ std::string write_template(
return template_def.str();
}
-MetalKernelFunction metal_kernel(
+CustomKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
@@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
threadgroup,
shape_infos,
ensure_row_contiguous,
-init_value),
+init_value,
std::vector<ScalarArg>{},
false,
0),
std::move(inputs));
};
}


@@ -26,15 +26,15 @@ device_info() {
namespace fast {
-MetalKernelFunction metal_kernel(
+CustomKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
-bool ensure_row_contiguous,
-bool atomic_outputs) {
-throw std::runtime_error("[metal_kernel] No GPU back-end.");
+bool,
+bool) {
+throw std::runtime_error("[metal_kernel] No Metal back-end.");
}
} // namespace fast


@@ -66,9 +66,10 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
-typedef std::variant<int, bool, Dtype> TemplateArg;
+using TemplateArg = std::variant<int, bool, Dtype>;
+using ScalarArg = std::variant<bool, int, float>;
-typedef std::function<std::vector<array>(
+using CustomKernelFunction = std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
@@ -77,10 +78,9 @@ typedef std::function<std::vector<array>(
std::vector<std::pair<std::string, TemplateArg>>,
std::optional<float>,
bool,
-StreamOrDevice)>
-MetalKernelFunction;
+StreamOrDevice)>;
-MetalKernelFunction metal_kernel(
+CustomKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
@@ -89,4 +89,27 @@ MetalKernelFunction metal_kernel(
bool ensure_row_contiguous = true,
bool atomic_outputs = false);
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header = "",
bool ensure_row_contiguous = true,
int shared_memory = 0);
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory = 0,
std::optional<float> init_value = std::nullopt,
bool ensure_row_contiguous = false,
StreamOrDevice s = {});
} // namespace mlx::core::fast
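The shared_memory parameter declared above requests dynamic shared memory for the launch (and, on the backend side, raises the kernel's maximum dynamic shared memory attribute when more than 48 KB is requested). A hedged usage sketch in Python follows; the kernel body, names, and sizes are illustrative only and not taken from this commit:

import mlx.core as mx

source = '''
    extern __shared__ float scratch[];
    auto block = cooperative_groups::this_thread_block();
    auto elem = cooperative_groups::this_grid().thread_rank();
    scratch[block.thread_rank()] = inp[elem];
    block.sync();
    out[elem] = 2.0f * scratch[block.thread_rank()];
'''

kernel = mx.fast.cuda_kernel(
    name="smem_scale",
    input_names=["inp"],
    output_names=["out"],
    source=source,
    shared_memory=256 * 4,  # 256 floats of dynamic shared memory per block
)

a = mx.random.normal(shape=(1024,))
(out,) = kernel(
    inputs=[a],
    grid=(a.size, 1, 1),        # grid is in threads, as with metal_kernel
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)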


@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <optional>
+#include <variant>
#include "mlx/primitives.h"
@@ -283,6 +284,8 @@ struct CustomKernelShapeInfo {
bool ndim = false;
};
+using ScalarArg = std::variant<bool, int, float>;
class CustomKernel : public Primitive {
public:
CustomKernel(
@@ -293,7 +296,10 @@ class CustomKernel : public Primitive {
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
bool ensure_row_contiguous,
-std::optional<float> init_value)
+std::optional<float> init_value,
+std::vector<ScalarArg> scalar_arguments,
+bool is_precompiled,
+int shared_memory)
: Primitive(stream),
source_(std::move(source)),
name_(std::move(name)),
@@ -301,11 +307,14 @@ class CustomKernel : public Primitive {
threadgroup_(threadgroup),
shape_infos_(std::move(shape_infos)),
ensure_row_contiguous_(ensure_row_contiguous),
-init_value_(init_value) {}
+init_value_(init_value),
+scalar_arguments_(std::move(scalar_arguments)),
+is_precompiled_(is_precompiled),
+shared_memory_(shared_memory) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
-throw std::runtime_error("Custom Metal kernels only run on GPU.");
+throw std::runtime_error("Custom kernels only run on GPU.");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -321,6 +330,9 @@ class CustomKernel : public Primitive {
std::vector<CustomKernelShapeInfo> shape_infos_;
bool ensure_row_contiguous_;
std::optional<float> init_value_;
+std::vector<ScalarArg> scalar_arguments_;
+bool is_precompiled_;
+int shared_memory_;
};
} // namespace mlx::core::fast


@@ -17,6 +17,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
+${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp

python/src/cuda.cpp (new file)

@@ -0,0 +1,19 @@
// Copyright © 2023-2025 Apple Inc.
#include <nanobind/nanobind.h>
#include "mlx/backend/cuda/cuda.h"
namespace mx = mlx::core;
namespace nb = nanobind;
void init_cuda(nb::module_& m) {
nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
cuda.def(
"is_available",
&mx::cu::is_available,
R"pbdoc(
Check if the CUDA back-end is available.
)pbdoc");
}
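With this binding, Python code can pick a kernel builder based on the available back-end, mirroring the pattern used in the updated tests below; a minimal sketch:

import mlx.core as mx

if mx.metal.is_available():
    custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
    custom_kernel = mx.fast.cuda_kernel
else:
    raise RuntimeError("No GPU back-end available")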


@@ -17,6 +17,66 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
namespace {
struct PyCustomKernelFunction {
PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag)
: kernel_(std::move(kernel)), tag_(tag) {}
std::vector<mx::array> operator()(
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<std::vector<std::pair<std::string, nb::object>>>&
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
mx::StreamOrDevice s = {}) const {
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.emplace_back(name, bool_val);
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<mx::Dtype>(value)) {
mx::Dtype dtype = nb::cast<mx::Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
std::ostringstream msg;
msg << tag_
<< " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.";
throw std::invalid_argument(msg.str());
}
}
}
return kernel_(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
}
mx::fast::CustomKernelFunction kernel_;
const char* tag_;
};
} // namespace
void init_fast(nb::module_& parent_module) {
auto m =
parent_module.def_submodule("fast", "mlx.core.fast: fast operations");
@@ -240,53 +300,7 @@ void init_fast(nb::module_& parent_module) {
ensure_row_contiguous,
atomic_outputs);
return nb::cpp_function(
-[kernel = std::move(kernel)](
+PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"),
-const std::vector<ScalarOrArray>& inputs_,
-const std::vector<mx::Shape>& output_shapes,
-const std::vector<mx::Dtype>& output_dtypes,
-std::tuple<int, int, int> grid,
-std::tuple<int, int, int> threadgroup,
-const std::optional<
-std::vector<std::pair<std::string, nb::object>>>&
-template_args_ = std::nullopt,
-std::optional<float> init_value = std::nullopt,
-bool verbose = false,
-mx::StreamOrDevice s = {}) {
-std::vector<mx::array> inputs;
-for (const auto& value : inputs_) {
-inputs.push_back(to_array(value, std::nullopt));
-}
-std::vector<std::pair<std::string, mx::fast::TemplateArg>>
-template_args;
-if (template_args_) {
-for (const auto& [name, value] : template_args_.value()) {
-// Handle bool, int and dtype template args
-if (nb::isinstance<bool>(value)) {
-bool bool_val = nb::cast<bool>(value);
-template_args.emplace_back(name, bool_val);
-} else if (nb::isinstance<int>(value)) {
-int int_val = nb::cast<int>(value);
-template_args.emplace_back(name, int_val);
-} else if (nb::isinstance<mx::Dtype>(value)) {
-mx::Dtype dtype = nb::cast<mx::Dtype>(value);
-template_args.emplace_back(name, dtype);
-} else {
-throw std::invalid_argument(
-"[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
-}
-}
-}
-return kernel(
-inputs,
-output_shapes,
-output_dtypes,
-grid,
-threadgroup,
-template_args,
-init_value,
-verbose,
-s);
-},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
@@ -384,4 +398,216 @@ void init_fast(nb::module_& parent_module) {
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc");
m.def(
"cuda_kernel",
[](const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_mem) {
auto kernel = mx::fast::cuda_kernel(
name,
input_names,
output_names,
source,
header,
ensure_row_contiguous,
shared_mem);
return nb::cpp_function(
PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"),
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (List[array]): The inputs passed to the CUDA kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
List[array]: The list of output arrays.)pbdoc");
},
"name"_a,
"input_names"_a,
"output_names"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"shared_memory"_a = 0,
R"pbdoc(
A jit-compiled custom CUDA kernel defined from a source string.
This is the CUDA equivalent of :ref:`custom_metal_kernels`.
Args:
name (str): Name for the kernel.
input_names (List[str]): The parameter names of the inputs in the
function signature.
output_names (List[str]): The parameter names of the outputs in the
function signature.
source (str): Source code. This is the body of a function in CUDA,
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
shared_memory (int): The dynamic shared memory to request for the
kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
Returns:
Callable ``cuda_kernel``.
Example:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = '''
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = inp[elem];
out[elem] = exp(tmp);
'''
kernel = mx.fast.cuda_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
verbose=True,
)
return outputs[0]
a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc");
m.def(
"precompiled_cuda_kernel",
[](const std::string& name,
const nb::bytes compiled_source,
const std::vector<ScalarOrArray>& inputs_,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
const std::vector<nb::object>& scalars_,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value = std::nullopt,
bool ensure_row_contiguous = false,
mx::StreamOrDevice s = {}) {
// Collect the inputs and cast them to array
std::vector<mx::array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
// Collect the scalar inputs
std::vector<mx::fast::ScalarArg> scalars;
scalars.reserve(scalars_.size());
for (const auto& v : scalars_) {
if (nb::isinstance<bool>(v)) {
scalars.push_back(nb::cast<bool>(v));
} else if (nb::isinstance<int>(v)) {
scalars.push_back(nb::cast<int>(v));
} else if (nb::isinstance<float>(v)) {
scalars.push_back(nb::cast<float>(v));
} else {
nb::object vtype = v.attr("__class__");
std::string vtype_name =
nb::cast<std::string>(vtype.attr("__name__"));
std::ostringstream msg;
msg << "[precompiled_cuda_kernel] Invalid scalar argument type. "
<< "Received " << vtype_name
<< " but must be one of bool, int or float";
throw std::invalid_argument(msg.str());
}
}
return mx::fast::precompiled_cuda_kernel(
name,
std::string(
static_cast<const char*>(compiled_source.data()),
compiled_source.size()),
inputs,
output_shapes,
output_dtypes,
scalars,
grid,
threadgroup,
shared_memory,
init_value,
ensure_row_contiguous,
s);
},
nb::kw_only(),
"name"_a,
"compiled_source"_a,
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"scalars"_a,
"grid"_a,
"threadgroup"_a,
"shared_memory"_a = 0,
"init_value"_a = nb::none(),
"ensure_row_contiguous"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
Run a precompiled CUDA kernel defined from PTX or cubin.
This op is still experimental and various parts of the API may change.
Args:
name (str): Name for the kernel
compiled_source (bytes): The precompiled kernel in raw bytes.
inputs (List[array]): The inputs passed to the CUDA kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output.
output_dtypes (List[Dtype]): The list of data types for each output.
scalars (List[Union[bool, int, float]]): A list of scalar arguments to
pass to the kernel.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
shared_memory (int): The dynamic shared memory to request for the
kernel. A value of 0 means no dynamic shared memory. Default: ``0``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
)pbdoc");
}
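For completeness, a hypothetical end-to-end use of precompiled_cuda_kernel, assuming a PTX file built offline (for example with nvcc -ptx) that exports an extern "C" kernel whose symbol matches the name argument. The file name, kernel signature, and values below are assumptions for illustration; the argument layout follows the documented order of inputs, then outputs, then scalars:

import mlx.core as mx

# Hypothetical: "axpy.ptx" contains
#   extern "C" __global__ void axpy(const float* x, const float* y, float* out, float alpha)
with open("axpy.ptx", "rb") as f:
    ptx = f.read()

x = mx.random.normal(shape=(4096,))
y = mx.random.normal(shape=(4096,))

(out,) = mx.fast.precompiled_cuda_kernel(
    name="axpy",                 # must match the (unmangled) symbol in the PTX
    compiled_source=ptx,
    inputs=[x, y],               # input pointers are passed first,
    output_shapes=[x.shape],     # then output pointers,
    output_dtypes=[x.dtype],
    scalars=[2.0],               # then scalar arguments, in order
    grid=(x.size, 1, 1),         # grid is in threads, as with metal_kernel
    threadgroup=(256, 1, 1),
)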


@@ -12,6 +12,7 @@ void init_array(nb::module_&);
void init_device(nb::module_&);
void init_stream(nb::module_&);
void init_metal(nb::module_&);
+void init_cuda(nb::module_&);
void init_memory(nb::module_&);
void init_ops(nb::module_&);
void init_transforms(nb::module_&);
@@ -35,6 +36,7 @@ NB_MODULE(core, m) {
init_stream(m);
init_array(m);
init_metal(m);
+init_cuda(m);
init_memory(m);
init_ops(m);
init_transforms(m);


@@ -581,18 +581,28 @@ class TestFast(mlx_tests.MLXTestCase):
)(x)
self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
-@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_basic(self):
-mx.random.seed(7)
-a = mx.random.normal(shape=(2, 2))
-kernel = mx.fast.metal_kernel(
-name="basic",
-input_names=["a"],
-output_names=["out1"],
+if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = a[elem];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = custom_kernel(
name="basic",
input_names=["a"],
output_names=["out1"],
source=source,
)
out = kernel(
inputs=[a],
@@ -604,16 +614,9 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(out[0], a))
-@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_args(self):
-mx.random.seed(7)
-a = mx.random.normal(shape=(3, 6))
-c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
-kernel = mx.fast.metal_kernel(
-name="arg_test",
-input_names=["a", "b", "c", "d"],
-output_names=["out1", "out2"],
+if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
T tmp = a[0];
@@ -623,7 +626,30 @@ class TestFast(mlx_tests.MLXTestCase):
out1[elem] = 1;
}
out2[elem] = a[1] + b[2] + c[1] - d;
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = a[0];
if (e) {
out1[elem] = a[1] + b[2] + static_cast<float>(c[3]) + d[0] + f;
} else {
out1[elem] = 1;
}
out2[elem] = a[1] + b[2] + static_cast<float>(c[1]) - d[0];
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
kernel = custom_kernel(
name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source=source,
)
out = kernel(
inputs=[
@@ -647,10 +673,9 @@ class TestFast(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484)))
self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
-@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_strides(self):
-mx.random.seed(7)
-a = mx.random.normal(shape=(3, 6))
+if mx.metal.is_available():
source = """
uint elem = thread_position_in_grid.x;
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
@@ -662,12 +687,29 @@ class TestFast(mlx_tests.MLXTestCase):
T tmp = inp[elem];
out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
auto loc = elem_to_loc(elem, inp_shape.data(), inp_strides.data(), inp_ndim);
T tmp = inp[loc];
out[elem] = exp(tmp) * WARP_SIZE;
"""
source_contig = """
auto elem = cooperative_groups::this_grid().thread_rank();
T tmp = inp[elem];
out[elem] = exp(tmp) * WARP_SIZE;
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(3, 6))
# non contiguous
a = mx.tile(a[::2], [4, 1])
for contig in [True, False]:
-kernel = mx.fast.metal_kernel(
+kernel = custom_kernel(
name="myexp" + str(contig),
input_names=["inp"],
output_names=["out"],
@@ -685,24 +727,41 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
-@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_helper(self):
-mx.random.seed(7)
-a = mx.random.normal(shape=(2, 2))
-kernel = mx.fast.metal_kernel(
-name="helper",
-input_names=["a"],
-output_names=["out1"],
+if mx.metal.is_available():
header = """
template <typename T>
T do_exp(T x) {
return metal::precise::exp(x);
}
"""
source = """
uint elem = thread_position_in_grid.x;
out1[elem] = do_exp(a[elem]);
"""
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
header = """
template <typename T>
__device__ T do_exp(T x) {
return exp(x);
}
"""
source = """
auto elem = cooperative_groups::this_grid().thread_rank();
out1[elem] = do_exp(a[elem]);
"""
custom_kernel = mx.fast.cuda_kernel
mx.random.seed(7)
a = mx.random.normal(shape=(2, 2))
kernel = custom_kernel(
name="helper",
input_names=["a"],
output_names=["out1"],
header=header,
source=source,
)
out = kernel(
inputs=[a],
@@ -714,16 +773,21 @@ class TestFast(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(out[0], mx.exp(a)))
-@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+@unittest.skipIf(not mx.is_available(mx.gpu), "No GPU available")
def test_custom_kernel_attributes(self):
if mx.metal.is_available():
source = "out[0] = threads_per_threadgroup.x;"
custom_kernel = mx.fast.metal_kernel
elif mx.cuda.is_available():
source = "out[0] = blockDim.x;"
custom_kernel = mx.fast.cuda_kernel
a = mx.zeros(shape=(1, 1))
-kernel = mx.fast.metal_kernel(
+kernel = custom_kernel(
name="test_fun",
input_names=["a"],
output_names=["out"],
-source="""
-out[0] = threads_per_threadgroup.x;
-""",
+source=source,
)
out = kernel(
inputs=[a],