mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-28 12:13:21 +08:00
alternative solution
This commit is contained in:
parent
ac1117b224
commit
6741d15735
@ -397,11 +397,11 @@ below.
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -172,11 +172,11 @@ void Axpby::eval_gpu(
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu(
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
|
@ -1,12 +1,326 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
struct CustomKernelCache {
|
||||
std::unordered_map<std::string, std::string> libraries;
|
||||
};
|
||||
|
||||
static CustomKernelCache& cache() {
|
||||
static CustomKernelCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<std::string>& attributes,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 16384);
|
||||
kernel_source += header;
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (!template_args.empty()) {
|
||||
kernel_source += "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source += ", ";
|
||||
}
|
||||
kernel_source += param_type;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
i++;
|
||||
}
|
||||
kernel_source += ">\n";
|
||||
}
|
||||
kernel_source += "[[kernel]] void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
std::string location =
|
||||
arr.size() < max_constant_array_size ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source += " const ";
|
||||
kernel_source += location;
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype;
|
||||
kernel_source += ref;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]],\n";
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source +=
|
||||
(" const constant int* " + name + "_shape [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source +=
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source +=
|
||||
(" const constant int& " + name + "_ndim [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source += "atomic<";
|
||||
}
|
||||
kernel_source += type_string;
|
||||
if (atomic_outputs) {
|
||||
kernel_source += ">";
|
||||
}
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n";
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
std::string write_template(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header /* = "" */,
|
||||
bool ensure_row_contiguous /* = true */,
|
||||
bool atomic_outputs /* = false */) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"threads_per_threadgroup", "uint3"},
|
||||
};
|
||||
|
||||
std::vector<std::string> attributes;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
||||
}
|
||||
}
|
||||
|
||||
return [=,
|
||||
shape_infos = std::move(shape_infos),
|
||||
attributes = std::move(attributes)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::string kernel_name = "custom_kernel_" + name;
|
||||
std::string template_def = "";
|
||||
if (!template_args.empty()) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args);
|
||||
auto template_hash =
|
||||
std::regex_replace(template_def, disallowed_chars, "_");
|
||||
template_hash.pop_back();
|
||||
kernel_name += "_";
|
||||
kernel_name += template_hash;
|
||||
}
|
||||
|
||||
std::string kernel_source = write_signature(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
attributes,
|
||||
shape_infos,
|
||||
atomic_outputs);
|
||||
|
||||
if (!template_args.empty()) {
|
||||
template_def = kernel_name + template_def;
|
||||
kernel_source += "\ntemplate [[host_name(\"";
|
||||
kernel_source += kernel_name;
|
||||
kernel_source += "\")]] [[kernel]] decltype(";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ") ";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ";\n";
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = name_;
|
||||
auto lib =
|
||||
d.get_library(lib_name, [this] { return metal::utils() + source_; });
|
||||
|
||||
{
|
||||
// Clear kernels from the device library cache if needed
|
||||
auto& kernel_cache = cache();
|
||||
if (auto it = kernel_cache.libraries.find(name_);
|
||||
it != kernel_cache.libraries.end()) {
|
||||
if (it->second != source_) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.clear_library(name_);
|
||||
it->second = source_;
|
||||
}
|
||||
} else {
|
||||
kernel_cache.libraries.emplace(name_, source_);
|
||||
}
|
||||
}
|
||||
|
||||
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_default_library(device_)}};
|
||||
default_library_ = load_default_library(device_);
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
@ -326,11 +326,11 @@ Device::Device() {
|
||||
|
||||
Device::~Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
for (auto& l : library_map_) {
|
||||
l.second->release();
|
||||
for (auto& [l, kernel_map] : library_kernels_) {
|
||||
l->release();
|
||||
for (auto& [_, k] : kernel_map) {
|
||||
k->release();
|
||||
}
|
||||
}
|
||||
stream_map_.clear();
|
||||
device_->release();
|
||||
@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
|
||||
return *stream.encoder;
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& path /* = "" */) {
|
||||
{
|
||||
std::shared_lock rlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto new_lib = load_library(device_, name, path.c_str());
|
||||
library_map_.insert({name, new_lib});
|
||||
return new_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||
@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
void Device::clear_library(const std::string& name) {
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
auto kernel_map_it = library_kernels_.find(it->second);
|
||||
for (auto& [_, kernel] : kernel_map_it->second) {
|
||||
kernel->release();
|
||||
}
|
||||
library_kernels_.erase(kernel_map_it);
|
||||
it->second->release();
|
||||
library_map_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
||||
std::unique_lock wlock(kernel_mtx_);
|
||||
|
||||
// Try loading again to avoid loading twice
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_(lib_name);
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
return get_kernel(
|
||||
base_name, default_library_, hash_name, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
|
||||
|
@ -187,14 +187,16 @@ class Device {
|
||||
CommandEncoder& get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path = "");
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& path = "");
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::function<std::string(void)>& builder);
|
||||
|
||||
void clear_library(const std::string& name);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
@ -204,7 +206,6 @@ class Device {
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
@ -258,10 +259,13 @@ class Device {
|
||||
std::unordered_map<int32_t, DeviceStream> stream_map_;
|
||||
|
||||
std::shared_mutex kernel_mtx_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
|
||||
std::shared_mutex library_mtx_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
MTL::Library* default_library_;
|
||||
std::unordered_map<
|
||||
MTL::Library*,
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*>>
|
||||
library_kernels_;
|
||||
const MTL::ResidencySet* residency_set_{nullptr};
|
||||
std::string arch_;
|
||||
int max_ops_per_buffer_;
|
||||
|
@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
int,
|
||||
int,
|
||||
int) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
int,
|
||||
int,
|
||||
bool) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel(
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string&) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_quantized_kernel(
|
||||
@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
int,
|
||||
int,
|
||||
bool) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu(
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
@ -387,7 +387,7 @@ void LayerNormVJP::eval_gpu(
|
||||
};
|
||||
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
|
@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal(
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
const int NQ = (qL + bq - 1) / bq;
|
||||
@ -180,7 +180,7 @@ void sdpa_vector(
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(kname, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
@ -281,7 +281,7 @@ void sdpa_vector_2pass(
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(kname, hash_name, func_consts);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
|
308
mlx/fast.cpp
308
mlx/fast.cpp
@ -1,11 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@ -1030,308 +1026,4 @@ std::vector<Shape> AffineQuantize::output_shapes(
|
||||
}
|
||||
}
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<std::string>& attributes,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 16384);
|
||||
kernel_source += header;
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (!template_args.empty()) {
|
||||
kernel_source += "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source += ", ";
|
||||
}
|
||||
kernel_source += param_type;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
i++;
|
||||
}
|
||||
kernel_source += ">\n";
|
||||
}
|
||||
kernel_source += "[[kernel]] void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
std::string location =
|
||||
arr.size() < max_constant_array_size ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source += " const ";
|
||||
kernel_source += location;
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype;
|
||||
kernel_source += ref;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]],\n";
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source +=
|
||||
(" const constant int* " + name + "_shape [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source +=
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source +=
|
||||
(" const constant int& " + name + "_ndim [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source += "atomic<";
|
||||
}
|
||||
kernel_source += type_string;
|
||||
if (atomic_outputs) {
|
||||
kernel_source += ">";
|
||||
}
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n";
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
std::string write_template(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header /* = "" */,
|
||||
bool ensure_row_contiguous /* = true */,
|
||||
bool atomic_outputs /* = false */) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"threads_per_threadgroup", "uint3"},
|
||||
};
|
||||
|
||||
std::vector<std::string> attributes;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
||||
}
|
||||
}
|
||||
auto now = std::chrono::system_clock::now();
|
||||
int64_t timestamp = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
now.time_since_epoch())
|
||||
.count();
|
||||
|
||||
return [=,
|
||||
shape_infos = std::move(shape_infos),
|
||||
attributes = std::move(attributes)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::ostringstream func_name;
|
||||
std::string template_def = "";
|
||||
std::string template_hash = "";
|
||||
if (!template_args.empty()) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args);
|
||||
template_hash = std::regex_replace(template_def, disallowed_chars, "_");
|
||||
template_hash.pop_back();
|
||||
}
|
||||
func_name << "custom_kernel_" << name << "_" << template_hash << "_"
|
||||
<< timestamp;
|
||||
std::string kernel_name = func_name.str();
|
||||
|
||||
std::string kernel_source = write_signature(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
attributes,
|
||||
shape_infos,
|
||||
atomic_outputs);
|
||||
|
||||
if (!template_args.empty()) {
|
||||
template_def = kernel_name + template_def;
|
||||
kernel_source += "\ntemplate [[host_name(\"";
|
||||
kernel_source += kernel_name;
|
||||
kernel_source += "\")]] [[kernel]] decltype(";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ") ";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ";\n";
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
Loading…
Reference in New Issue
Block a user