Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-22 11:14:32 +08:00)
Custom Metal Kernels from Python (#1325)
* start
* simple kernels working
* restructure
* inverse example working
* docs + fixes
* missing file
* fix imports
* address comments
* add docs + fix test
* Review comments + refactor to a single function
* update docs
* remove hashing
* fix contig bug in test
* back to a class
* trailing whitespace
* fix tests
* match c++ and python apis
* add link + make args kw_only
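For context, the feature added here is exposed in Python as mx.fast.metal_kernel. The sketch below follows the usage documented in later MLX releases and is illustrative only; the exact keyword arguments in the version introduced by this commit may differ (the history above notes matching the C++ and Python APIs and making arguments keyword-only), and the kernel name "myexp" is just an example.

import mlx.core as mx

# Illustrative sketch: elementwise exp written as a custom Metal kernel.
# Keyword names follow the later documented API and are assumptions here.
source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

a = mx.random.normal(shape=(4, 16))
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
out = outputs[0]  # elementwise exp of a

Each call like this eventually reaches CustomKernel::eval_gpu below, which compiles the source string into a Metal library and dispatches the kernel.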
@@ -131,6 +131,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
mlx/backend/metal/custom_kernel.cpp (new file, 84 lines)
@@ -0,0 +1,84 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

void CustomKernel::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();

  for (auto& out : outputs) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  std::vector<array> copies;

  auto check_input = [&copies, &s, this](const array& x) -> const array {
    bool no_copy = x.flags().row_contiguous;
    if (!ensure_row_contiguous_ || no_copy) {
      return x;
    } else {
      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
      copy_gpu(x, copies.back(), CopyType::General, s);
      return copies.back();
    }
  };
  std::vector<const array> checked_inputs;
  for (const array& in : inputs) {
    checked_inputs.push_back(check_input(in));
  }

  auto& d = metal::device(s.device);
  const auto& lib_name = name_;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    lib = d.get_library(lib_name, metal::utils() + source_);
  }
  auto kernel = d.get_kernel(name_, lib);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  int index = 0;
  for (int i = 0; i < checked_inputs.size(); i++) {
    const array& in = checked_inputs[i];
    auto shape_info = shape_infos_[i];
    compute_encoder.set_input_array(in, index);
    index++;
    if (in.ndim() > 0) {
      int ndim = in.ndim();
      if (shape_info.shape) {
        set_vector_bytes(compute_encoder, in.shape(), ndim, index);
        index++;
      }
      if (shape_info.strides) {
        set_vector_bytes(compute_encoder, in.strides(), ndim, index);
        index++;
      }
      if (shape_info.ndim) {
        compute_encoder->setBytes(&ndim, sizeof(int), index);
        index++;
      }
    }
  }
  for (array out : outputs) {
    compute_encoder.set_output_array(out, index);
    index++;
  }

  const auto [tx, ty, tz] = threadgroup_;
  MTL::Size group_dims = MTL::Size(tx, ty, tz);
  const auto [gx, gy, gz] = grid_;
  MTL::Size grid_dims = MTL::Size(gx, gy, gz);
  compute_encoder->dispatchThreads(grid_dims, group_dims);

  if (!copies.empty()) {
    d.get_command_buffer(s.index)->addCompletedHandler(
        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  }
}

} // namespace mlx::core::fast
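A note on the binding loop above: for every input, eval_gpu binds the data buffer and then, only when the corresponding shape_infos_ entry requests them, the input's shape, strides, and ndim; the output buffers follow. Together with ensure_row_contiguous_, this lets a kernel index a non-contiguous input through its strides instead of forcing a copy. The sketch below is an assumption-laden illustration using the later documented Python API; the inp_shape / inp_strides / inp_ndim names and the elem_to_loc helper follow the conventions described in the MLX custom-kernel docs, and the kernel name "strided_read" is made up for the example.

import mlx.core as mx

# Hypothetical sketch: read a non-row-contiguous input through its strides.
# With ensure_row_contiguous=False no contiguity copy is made, so the kernel
# relies on the shape/strides/ndim buffers that eval_gpu binds for the input.
source = """
    uint elem = thread_position_in_grid.x;
    // elem_to_loc comes from the utilities prepended via metal::utils().
    uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
    out[elem] = inp[loc];
"""

kernel = mx.fast.metal_kernel(
    name="strided_read",
    input_names=["inp"],
    output_names=["out"],
    source=source,
    ensure_row_contiguous=False,
)

a = mx.random.normal(shape=(8, 8)).T  # transposed, so not row-contiguous
outputs = kernel(
    inputs=[a],
    grid=(a.size, 1, 1),
    threadgroup=(64, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)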