mirror of https://github.com/ml-explore/mlx.git

Working custom kernels jointly

parent 0b309e8edc
commit 3b94e37270
@@ -234,11 +234,47 @@ MetalKernelFunction metal_kernel(
            threadgroup,
            shape_infos,
            ensure_row_contiguous,
            init_value),
            init_value,
            std::vector<ScalarArg>{},
            false,
            0),
        std::move(inputs));
  };
}

std::vector<array> precompiled_custom_kernel(
    const std::string& name,
    const std::string& compiled_source,
    const std::vector<array>& inputs,
    const std::vector<Shape>& output_shapes,
    const std::vector<Dtype>& output_dtypes,
    const std::vector<ScalarArg>& scalars,
    std::tuple<int, int, int> grid,
    std::tuple<int, int, int> threadgroup,
    int shared_memory,
    std::optional<float> init_value,
    bool ensure_row_contiguous,
    StreamOrDevice s) {
  std::vector<CustomKernelShapeInfo> shape_infos(
      inputs.size(), CustomKernelShapeInfo{false, false, false});
  return array::make_arrays(
      output_shapes,
      output_dtypes,
      std::make_shared<CustomKernel>(
          to_stream(s),
          name,
          compiled_source,
          grid,
          threadgroup,
          shape_infos,
          ensure_row_contiguous,
          init_value,
          scalars,
          true,
          shared_memory),
      inputs);
}
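(Not part of the diff — a minimal host-side usage sketch of the new entry point, assuming `ptx` holds a module compiled offline that exports a kernel symbol named "my_kernel"; the names, launch sizes, and scalar values are illustrative only.)

#include "mlx/fast.h"
#include "mlx/mlx.h"

namespace mx = mlx::core;

std::vector<mx::array> run_precompiled(const std::string& ptx, const mx::array& x) {
  // For the precompiled path the kernel name is used as-is (no
  // "mlx::core::cu::" prefix), so it must match the symbol in the PTX.
  return mx::fast::precompiled_custom_kernel(
      "my_kernel",     // kernel symbol exported by the PTX (assumed)
      ptx,             // compiled source, loaded verbatim by the JIT module
      {x},             // array inputs
      {x.shape()},     // one output with the same shape
      {x.dtype()},     // and the same dtype
      {2.0f},          // scalar arguments, appended after the array arguments
      {1024, 1, 1},    // grid (illustrative)
      {256, 1, 1},     // threadgroup (illustrative)
      0,               // shared_memory
      std::nullopt,    // init_value
      false,           // ensure_row_contiguous
      {});             // stream
}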

void CustomKernel::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
@@ -274,10 +310,19 @@ void CustomKernel::eval_gpu(
  }

  // Compile the custom kernel
  std::string kernel_name = "mlx::core::cu::" + name_;
  cu::JitModule& mod = cu::get_jit_module(s.device, name_, [&]() {
    return std::make_pair(source_, std::vector<std::string>{kernel_name});
  });
  std::string kernel_name =
      (is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
  auto get_module = [&]() -> cu::JitModule& {
    if (is_precompiled_) {
      return cu::get_jit_module(
          s.device, name_, source_, std::vector<std::string>{kernel_name});
    } else {
      return cu::get_jit_module(s.device, name_, [&]() {
        return std::make_pair(source_, std::vector<std::string>{kernel_name});
      });
    }
  };
  cu::JitModule& mod = get_module();

  // Make the arguments
  cu::KernelArgs args;
@@ -298,6 +343,15 @@ void CustomKernel::eval_gpu(
  for (auto& out : outputs) {
    args.append(out);
  }
  for (auto& s : scalar_arguments_) {
    if (std::holds_alternative<bool>(s)) {
      args.append(std::get<bool>(s));
    } else if (std::holds_alternative<int>(s)) {
      args.append(std::get<int>(s));
    } else if (std::holds_alternative<float>(s)) {
      args.append(std::get<float>(s));
    }
  }

  // Make the grid
  const auto [tx, ty, tz] = threadgroup_;
@@ -313,8 +367,17 @@ void CustomKernel::eval_gpu(
  for (const auto& out : outputs) {
    encoder.set_output_array(out);
  }
  for (const auto& t : copies) {
    encoder.add_temporary(t);
  }
  auto kernel = mod.get_kernel(kernel_name);
  encoder.add_kernel_node(kernel, grid, block, 0, args.args());
  if (shared_memory_ > 0 && shared_memory_ > 48000) {
    cuFuncSetAttribute(
        kernel,
        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        shared_memory_);
  }
  encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}

} // namespace mlx::core::fast

@@ -10,6 +10,8 @@
#include <future>
#include <unordered_set>

#include <iostream>

namespace mlx::core::cu {

namespace {
@@ -249,9 +251,201 @@ void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
  insert_graph_dependencies(GraphNode{node, 'K'});
}

void debugCuGraphAddKernelNode(
    CUgraphNode* node,
    cudaGraph_t graph,
    const CUDA_KERNEL_NODE_PARAMS* params) {
  std::cout << "=== Debugging cuGraphAddKernelNode ===" << std::endl;

  // Check graph
  if (graph == nullptr) {
    std::cout << "ERROR: graph is NULL" << std::endl;
    return;
  }

  // Check params structure
  if (params == nullptr) {
    std::cout << "ERROR: params is NULL" << std::endl;
    return;
  }

  // Check kernel function
  if (params->func == nullptr) {
    std::cout << "ERROR: kernel function (CUfunction) is NULL" << std::endl;
    return;
  }

  // Validate kernel function and get attributes
  int maxThreadsPerBlock;
  CUresult funcErr = cuFuncGetAttribute(
      &maxThreadsPerBlock,
      CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
      params->func);
  if (funcErr != CUDA_SUCCESS) {
    const char* errStr;
    cuGetErrorString(funcErr, &errStr);
    std::cout << "ERROR: Invalid kernel function - " << errStr << std::endl;
    return;
  }

  // Get more function attributes
  int sharedSize, constSize, localSize, numRegs;
  cuFuncGetAttribute(
      &sharedSize, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, params->func);
  cuFuncGetAttribute(
      &constSize, CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES, params->func);
  cuFuncGetAttribute(
      &localSize, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, params->func);
  cuFuncGetAttribute(&numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, params->func);

  std::cout << "Kernel function attributes:" << std::endl;
  std::cout << " Max threads per block: " << maxThreadsPerBlock << std::endl;
  std::cout << " Shared memory: " << sharedSize << " bytes" << std::endl;
  std::cout << " Const memory: " << constSize << " bytes" << std::endl;
  std::cout << " Local memory: " << localSize << " bytes" << std::endl;
  std::cout << " Num regs: " << numRegs << std::endl;

  // Check dimensions
  std::cout << "\nGrid dimensions: (" << params->gridDimX << ", "
            << params->gridDimY << ", " << params->gridDimZ << ")" << std::endl;

  std::cout << "Block dimensions: (" << params->blockDimX << ", "
            << params->blockDimY << ", " << params->blockDimZ << ")"
            << std::endl;

  if (params->gridDimX * params->gridDimY * params->gridDimZ == 0) {
    std::cout << "ERROR: Grid dimension contains zero!" << std::endl;
    return;
  }

  if (params->blockDimX * params->blockDimY * params->blockDimZ == 0) {
    std::cout << "ERROR: Block dimension contains zero!" << std::endl;
    return;
  }

  // Get current device and check limits
  CUdevice device;
  cuCtxGetDevice(&device);

  int maxGridX, maxGridY, maxGridZ;
  int maxBlockX, maxBlockY, maxBlockZ;
  int maxThreadsPerBlockDevice;
  int maxSharedMemPerBlock;

  cuDeviceGetAttribute(&maxGridX, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device);
  cuDeviceGetAttribute(&maxGridY, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device);
  cuDeviceGetAttribute(&maxGridZ, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device);
  cuDeviceGetAttribute(&maxBlockX, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device);
  cuDeviceGetAttribute(&maxBlockY, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device);
  cuDeviceGetAttribute(&maxBlockZ, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device);
  cuDeviceGetAttribute(
      &maxThreadsPerBlockDevice,
      CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
      device);
  cuDeviceGetAttribute(
      &maxSharedMemPerBlock,
      CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
      device);

  std::cout << "\nDevice limits check:" << std::endl;
  std::cout << " Max grid size: (" << maxGridX << ", " << maxGridY << ", "
            << maxGridZ << ")" << std::endl;

  std::cout << " Max block size: (" << maxBlockX << ", " << maxBlockY << ", "
            << maxBlockZ << ")" << std::endl;

  std::cout << " Max threads per block: " << maxThreadsPerBlockDevice
            << std::endl;

  // Check if dimensions exceed limits
  if (params->gridDimX > (unsigned)maxGridX ||
      params->gridDimY > (unsigned)maxGridY ||
      params->gridDimZ > (unsigned)maxGridZ) {
    std::cout << "ERROR: Grid dimensions exceed device limits!" << std::endl;
  }

  if (params->blockDimX > (unsigned)maxBlockX ||
      params->blockDimY > (unsigned)maxBlockY ||
      params->blockDimZ > (unsigned)maxBlockZ) {
    std::cout << "ERROR: Block dimensions exceed device limits!" << std::endl;
  }

  unsigned totalThreadsPerBlock =
      params->blockDimX * params->blockDimY * params->blockDimZ;
  if (totalThreadsPerBlock > (unsigned)maxThreadsPerBlockDevice) {
    std::cout << "ERROR: Total threads per block (" << totalThreadsPerBlock
              << ") exceeds limit (" << maxThreadsPerBlockDevice << ")"
              << std::endl;
  }

  // Check shared memory
  std::cout << "\nShared memory requested: " << params->sharedMemBytes
            << " bytes" << std::endl;
  std::cout << "Max shared memory per block: " << maxSharedMemPerBlock
            << " bytes" << std::endl;

  if (params->sharedMemBytes > (unsigned)maxSharedMemPerBlock) {
    std::cout << "ERROR: Requested shared memory exceeds limit!" << std::endl;
  }

  // Check kernel parameters
  std::cout << "\nKernel parameters:" << std::endl;
  std::cout << " kernelParams pointer: " << std::hex << params->kernelParams
            << std::dec << std::endl;
  std::cout << " extra pointer: " << std::hex << params->extra << std::dec
            << std::endl;

  if (params->kernelParams == nullptr && params->extra == nullptr) {
    std::cout
        << "WARNING: Both kernelParams and extra are NULL (no arguments to kernel)"
        << std::endl;
  }

  // If using kernelParams, try to print the array
  if (params->kernelParams != nullptr) {
    std::cout << " Kernel parameter pointers:" << std::endl;
    for (int i = 0; i < 10; i++) { // Check first 10 slots (adjust as needed)
      if (params->kernelParams[i] == nullptr) {
        std::cout << " [" << i << "]: NULL (end of params)" << std::endl;
        break;
      }
      std::cout << " [" << i << "]: " << std::hex << params->kernelParams[i]
                << std::dec << std::endl;
    }
  }

  // Try to add the node
  std::cout << "\nAttempting to add kernel node..." << std::endl;
  CUresult err = cuGraphAddKernelNode(node, graph, NULL, 0, params);

  if (err != CUDA_SUCCESS) {
    const char* errStr;
    cuGetErrorString(err, &errStr);
    std::cout << "ERROR: " << errStr << " (code: " << err << ")" << std::endl;

    // Additional hints based on error code
    if (err == CUDA_ERROR_INVALID_VALUE) {
      std::cout << "\nHints for 'invalid argument':" << std::endl;
      std::cout
          << " - Check if CUDA_KERNEL_NODE_PARAMS struct is properly initialized"
          << std::endl;
      std::cout << " - Verify CUfunction handle is valid" << std::endl;
      std::cout << " - Ensure grid/block dimensions are non-zero" << std::endl;
      std::cout << " - Check kernelParams array is properly set up"
                << std::endl;
      std::cout << " - Verify context is current" << std::endl;
    }
  } else {
    std::cout << "SUCCESS: Kernel node added to graph!" << std::endl;
  }

  std::cout << "=== End Debug ===" << std::endl;
}

void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
  CUgraphNode node;
  CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
  // CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
  debugCuGraphAddKernelNode(&node, graph_, &params);
  insert_graph_dependencies(GraphNode{node, 'K'});
}
@@ -316,6 +316,31 @@ JitModule::JitModule(
  }
}

JitModule::JitModule(
    Device& device,
    const std::string& module_name,
    const std::string& ptx,
    const std::vector<std::string>& kernel_names) {
  // Load module.
  char jit_log[4089] = {};
  CUjit_option options[] = {
      CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
  void* values[] = {jit_log, reinterpret_cast<void*>(std::size(jit_log) - 1)};
  CUresult jit_result = cuModuleLoadDataEx(
      &module_, ptx.c_str(), std::size(options), options, values);
  if (jit_result != CUDA_SUCCESS) {
    throw std::runtime_error(fmt::format(
        "Failed to load compiled {} kernel: {}.", module_name, jit_log));
  }

  // Load kernels.
  for (const auto& name : kernel_names) {
    CUfunction kernel;
    CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, name.c_str()));
    kernels_[name] = kernel;
  }
}

JitModule::~JitModule() {
  CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
@@ -346,4 +371,18 @@ JitModule& get_jit_module(
  return it->second;
}

JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const std::string& ptx,
    const std::vector<std::string>& kernel_names) {
  auto& map = get_jit_module_cache();
  auto it = map.find(name);
  if (it == map.end()) {
    it = map.try_emplace(name, cu::device(device), name, ptx, kernel_names)
             .first;
  }
  return it->second;
}

} // namespace mlx::core::cu
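(Also not part of the diff — a sketch of how the new PTX-loading overload might be exercised directly; `device`, `ptx`, and the "my_kernel" names are illustrative. It assumes `ptx` was produced offline, e.g. with `nvcc --ptx`, and exports a kernel symbol named "my_kernel".)

// Loads (and caches) a module from precompiled PTX instead of building a
// source string and compiling it at runtime; the kernel symbol is then
// looked up verbatim by name.
auto& mod = mlx::core::cu::get_jit_module(
    device, "my_kernel_module", ptx, std::vector<std::string>{"my_kernel"});
auto kernel = mod.get_kernel("my_kernel");

Note that the module cache above is keyed only by `name`, so passing different PTX under the same name would reuse whichever module was loaded first.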

@@ -68,9 +68,11 @@ struct KernelArgs {
  using Arg = std::variant<
      std::monostate,
      CUdeviceptr,
      bool,
      int32_t,
      uint32_t,
      int64_t,
      float,
      SmallVector<const void*>,
      SmallVector<int32_t>,
      SmallVector<int64_t>>;
@@ -83,6 +85,11 @@ class JitModule {
      Device& device,
      const std::string& module_name,
      const KernelBuilder& builder);
  JitModule(
      Device& device,
      const std::string& module_name,
      const std::string& ptx,
      const std::vector<std::string>& kernel_names);
  ~JitModule();

  JitModule(const JitModule&) = delete;
@@ -101,4 +108,10 @@ JitModule& get_jit_module(
    const std::string& name,
    const KernelBuilder& builder);

JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const std::string& ptx,
    const std::vector<std::string>& kernel_names);

} // namespace mlx::core::cu

@@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
            threadgroup,
            shape_infos,
            ensure_row_contiguous,
            init_value),
            init_value,
            std::vector<ScalarArg>{},
            false,
            0),
        std::move(inputs));
  };
}

mlx/fast.h
@@ -66,9 +66,10 @@ array affine_dequantize(
    int bits = 4,
    StreamOrDevice s = {});

typedef std::variant<int, bool, Dtype> TemplateArg;
using TemplateArg = std::variant<int, bool, Dtype>;
using ScalarArg = std::variant<bool, int, float>;

typedef std::function<std::vector<array>(
using MetalKernelFunction = std::function<std::vector<array>(
    const std::vector<array>&,
    const std::vector<Shape>&,
    const std::vector<Dtype>&,
@@ -77,8 +78,7 @@ typedef std::function<std::vector<array>(
    std::vector<std::pair<std::string, TemplateArg>>,
    std::optional<float>,
    bool,
    StreamOrDevice)>
    MetalKernelFunction;
    StreamOrDevice)>;

MetalKernelFunction metal_kernel(
    const std::string& name,
@@ -89,4 +89,18 @@ MetalKernelFunction metal_kernel(
    bool ensure_row_contiguous = true,
    bool atomic_outputs = false);

std::vector<array> precompiled_custom_kernel(
    const std::string& name,
    const std::string& compiled_source,
    const std::vector<array>& inputs,
    const std::vector<Shape>& output_shapes,
    const std::vector<Dtype>& output_dtypes,
    const std::vector<ScalarArg>& scalars,
    std::tuple<int, int, int> grid,
    std::tuple<int, int, int> threadgroup,
    int shared_memory = 0,
    std::optional<float> init_value = std::nullopt,
    bool ensure_row_contiguous = false,
    StreamOrDevice s = {});

} // namespace mlx::core::fast
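(A speculative device-side sketch, not part of the diff: the argument packing in CustomKernel::eval_gpu above appends the input arrays, then the output arrays, then the scalar arguments, and the precompiled path fills shape_infos with false so no shapes, strides, or ndim are passed. Under those assumptions, a kernel compiled offline for this API might look like the following; the names and the extra size scalar are illustrative only.)

// Hypothetical kernel compiled ahead of time (e.g. `nvcc --ptx scale.cu`).
// extern "C" keeps the symbol unmangled so it can be passed as `name`.
extern "C" __global__ void my_kernel(
    const float* inp, // one pointer per input array
    float* out,       // one pointer per output array
    int n,            // scalar arguments follow, in the order given in `scalars`
    float scale) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = scale * inp[i];
  }
}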

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

#include <optional>
#include <variant>

#include "mlx/primitives.h"

@@ -283,6 +284,8 @@ struct CustomKernelShapeInfo {
  bool ndim = false;
};

using ScalarArg = std::variant<bool, int, float>;

class CustomKernel : public Primitive {
 public:
  CustomKernel(
@@ -293,7 +296,10 @@ class CustomKernel : public Primitive {
      std::tuple<int, int, int> threadgroup,
      std::vector<CustomKernelShapeInfo> shape_infos,
      bool ensure_row_contiguous,
      std::optional<float> init_value)
      std::optional<float> init_value,
      std::vector<ScalarArg> scalar_arguments,
      bool is_precompiled,
      int shared_memory)
      : Primitive(stream),
        source_(std::move(source)),
        name_(std::move(name)),
@@ -301,11 +307,14 @@ class CustomKernel : public Primitive {
        threadgroup_(threadgroup),
        shape_infos_(std::move(shape_infos)),
        ensure_row_contiguous_(ensure_row_contiguous),
        init_value_(init_value) {}
        init_value_(init_value),
        scalar_arguments_(scalar_arguments),
        is_precompiled_(is_precompiled),
        shared_memory_(shared_memory) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("Custom Metal kernels only run on GPU.");
    throw std::runtime_error("Custom kernels only run on GPU.");
  }

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -321,6 +330,9 @@ class CustomKernel : public Primitive {
  std::vector<CustomKernelShapeInfo> shape_infos_;
  bool ensure_row_contiguous_;
  std::optional<float> init_value_;
  std::vector<ScalarArg> scalar_arguments_;
  bool is_precompiled_;
  int shared_memory_;
};

} // namespace mlx::core::fast

@@ -384,4 +384,76 @@ void init_fast(nb::module_& parent_module) {
            b = exp_elementwise(a)
            assert mx.allclose(b, mx.exp(a))
      )pbdoc");

  m.def(
      "precompiled_custom_kernel",
      [](const std::string& name,
         const std::string& compiled_source,
         const std::vector<ScalarOrArray>& inputs_,
         const std::vector<mx::Shape>& output_shapes,
         const std::vector<mx::Dtype>& output_dtypes,
         const std::vector<nb::object>& scalars_,
         std::tuple<int, int, int> grid,
         std::tuple<int, int, int> threadgroup,
         int shared_memory,
         std::optional<float> init_value = std::nullopt,
         bool ensure_row_contiguous = false,
         mx::StreamOrDevice s = {}) {
        // Collect the inputs and cast them to array
        std::vector<mx::array> inputs;
        for (const auto& value : inputs_) {
          inputs.push_back(to_array(value, std::nullopt));
        }

        // Collect the scalar inputs
        std::vector<mx::fast::ScalarArg> scalars;
        scalars.reserve(scalars_.size());
        for (const auto& v : scalars_) {
          if (nb::isinstance<bool>(v)) {
            scalars.push_back(nb::cast<bool>(v));
          } else if (nb::isinstance<int>(v)) {
            scalars.push_back(nb::cast<int>(v));
          } else if (nb::isinstance<float>(v)) {
            scalars.push_back(nb::cast<float>(v));
          } else {
            nb::object vtype = v.attr("__class__");
            std::string vtype_name =
                nb::cast<std::string>(vtype.attr("__name__"));
            std::ostringstream msg;
            msg << "[precompiled_custom_kernel] Invalid scalar argument type. "
                << "Received " << vtype_name
                << " but must be one of bool, int or float";
            throw std::invalid_argument(msg.str());
          }
        }

        return mx::fast::precompiled_custom_kernel(
            name,
            compiled_source,
            inputs,
            output_shapes,
            output_dtypes,
            scalars,
            grid,
            threadgroup,
            shared_memory,
            init_value,
            ensure_row_contiguous,
            s);
      },
      nb::kw_only(),
      "name"_a,
      "compiled_source"_a,
      "inputs"_a,
      "output_shapes"_a,
      "output_dtypes"_a,
      "scalars"_a,
      "grid"_a,
      "threadgroup"_a,
      "shared_memory"_a = 0,
      "init_value"_a = nb::none(),
      "ensure_row_contiguous"_a = false,
      "stream"_a = nb::none(),
      R"pbdoc(
      )pbdoc");
}