mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 18:28:11 +08:00
Custom cuda kernel (#2517)
This commit is contained in:

committed by
GitHub

parent
f4c8888cbe
commit
e397177f6e
@@ -172,7 +172,7 @@ std::string write_template(
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
CustomKernelFunction metal_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
@@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value),
|
||||
init_value,
|
||||
std::vector<ScalarArg>{},
|
||||
false,
|
||||
0),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
@@ -26,15 +26,15 @@ device_info() {
|
||||
|
||||
namespace fast {
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
CustomKernelFunction metal_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||
bool,
|
||||
bool) {
|
||||
throw std::runtime_error("[metal_kernel] No Metal back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
Reference in New Issue
Block a user