mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-21 08:41:13 +08:00)
Some overhead reductions in mx.fast.metal_kernel (#1437)
* some overhead reductions
* fix
* use +=
* use more +=
This commit is contained in:
parent 4f9f9ebb6f, commit 0b4a58699e
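At a glance, the commit cuts host-side overhead in `mx.fast.metal_kernel` in two ways: `write_signature` now builds the generated Metal source into a pre-reserved `std::string` with `+=` instead of streaming through a `std::ostringstream`, and the scans of the kernel source for shape/strides/ndim names and Metal attributes are done once and moved into the returned closure instead of being redone on every call. The snippet below is a minimal, self-contained sketch of the first pattern only; the `build_with_*` helpers are invented for illustration and are not MLX functions.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Before: every `<<` goes through stream formatting, the buffer may
// reallocate repeatedly, and `.str()` makes one more copy at the end.
std::string build_with_stream(const std::vector<std::string>& names) {
  std::ostringstream out;
  for (const auto& n : names) {
    out << "  const device float* " << n << ",\n";
  }
  return out.str();
}

// After: reserve a generous capacity once, then append in place with `+=`.
std::string build_with_append(const std::vector<std::string>& names) {
  std::string out;
  // Rough upper bound, in the spirit of the
  // header.size() + source.size() + 16384 reservation in write_signature.
  out.reserve(names.size() * 64 + 1024);
  for (const auto& n : names) {
    out += "  const device float* ";
    out += n;
    out += ",\n";
  }
  return out;  // moved out of the function, no extra copy
}
```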
@@ -49,7 +49,7 @@ void CustomKernel::eval_gpu(
   int index = 0;
   for (int i = 0; i < checked_inputs.size(); i++) {
     const array& in = checked_inputs[i];
-    auto shape_info = shape_infos_[i];
+    auto& shape_info = shape_infos_[i];
     compute_encoder.set_input_array(in, index);
     index++;
     if (in.ndim() > 0) {
@@ -68,7 +68,7 @@ void CustomKernel::eval_gpu(
       }
     }
   }
-  for (array out : outputs) {
+  for (auto& out : outputs) {
     compute_encoder.set_output_array(out, index);
     index++;
   }
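Both hunks above swap a by-value loop variable for a reference, so each iteration binds to the stored element instead of copy-constructing a temporary. A standalone illustration of the difference, using `std::string` as a stand-in for MLX's `array`:

```cpp
#include <string>
#include <vector>

// Illustrative only; not MLX code.
void copies_vs_references(const std::vector<std::string>& outputs) {
  // By-value loop variable: every iteration copy-constructs a std::string.
  for (std::string out : outputs) {
    (void)out;
  }
  // Reference loop variable: each iteration just binds to the existing element.
  for (const std::string& out : outputs) {
    (void)out;
  }
}
```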
mlx/fast.cpp
@@ -916,21 +916,25 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }
 
-void write_signature(
+std::string write_signature(
     std::string func_name,
     const std::string& header,
     const std::string& source,
     const std::vector<std::string>& input_names,
     const std::vector<array>& inputs,
     const std::vector<std::string>& output_names,
     const std::vector<Dtype>& output_dtypes,
     const std::vector<std::pair<std::string, TemplateArg>>& template_args,
-    std::vector<CustomKernelShapeInfo>& shape_infos,
-    bool atomic_outputs,
-    std::ostringstream& kernel_source) {
+    const std::vector<std::string>& attributes,
+    const std::vector<CustomKernelShapeInfo>& shape_infos,
+    bool atomic_outputs) {
+  std::string kernel_source;
+  kernel_source.reserve(header.size() + source.size() + 16384);
+  kernel_source += header;
   // Auto-generate a function signature based on `template_args`
   // and the dtype/shape of the arrays passed as `inputs`.
   if (!template_args.empty()) {
-    kernel_source << "template <";
+    kernel_source += "template <";
     int i = 0;
     for (const auto& [name, arg] : template_args) {
       std::string param_type;
@@ -942,44 +946,18 @@ void write_signature(
         param_type = "typename";
       }
       if (i > 0) {
-        kernel_source << ", ";
+        kernel_source += ", ";
       }
-      kernel_source << param_type << " " << name;
+      kernel_source += param_type;
+      kernel_source += " ";
+      kernel_source += name;
       i++;
     }
-    kernel_source << ">" << std::endl;
-  }
-  kernel_source << "[[kernel]] void " << func_name << "(" << std::endl;
-
-  // Metal attributes are automatically added to the arguments if present
-  const std::vector<std::pair<std::string, std::string>> metal_attributes = {
-      {"dispatch_quadgroups_per_threadgroup", "uint"},
-      {"dispatch_simdgroups_per_threadgroup", "uint"},
-      {"dispatch_threads_per_threadgroup", "uint3"},
-      {"grid_origin", "uint3"},
-      {"grid_size", "uint3"},
-      {"quadgroup_index_in_threadgroup", "uint"},
-      {"quadgroups_per_threadgroup", "uint"},
-      {"simdgroup_index_in_threadgroup", "uint"},
-      {"simdgroups_per_threadgroup", "uint"},
-      {"thread_execution_width", "uint"},
-      {"thread_index_in_quadgroup", "uint"},
-      {"thread_index_in_simdgroup", "uint"},
-      {"thread_index_in_threadgroup", "uint"},
-      {"thread_position_in_grid", "uint3"},
-      {"thread_position_in_threadgroup", "uint3"},
-      {"threadgroup_position_in_grid", "uint3"},
-      {"threadgroups_per_grid", "uint3"},
-      {"threads_per_grid", "uint3"},
-      {"threads_per_simdgroup", "uint"},
-      {"threads_per_threadgroup", "uint3"},
-  };
-  std::vector<std::pair<std::string, std::string>> attrs;
-  for (const auto& [attr, dtype] : metal_attributes) {
-    if (source.find(attr) != std::string::npos) {
-      attrs.push_back({attr, dtype});
-    }
-  }
+    kernel_source += ">\n";
+  }
+  kernel_source += "[[kernel]] void ";
+  kernel_source += func_name;
+  kernel_source += "(\n";
 
   int index = 0;
   constexpr int max_constant_array_size = 8;
@@ -988,69 +966,82 @@ void write_signature(
     const auto& name = input_names[i];
     const auto& arr = inputs[i];
     auto dtype = get_type_string(arr.dtype());
-    bool is_constant =
-        arr.is_available() && arr.size() < max_constant_array_size;
-    std::string location = is_constant ? "constant" : "device";
+    std::string location =
+        arr.size() < max_constant_array_size ? "constant" : "device";
     std::string ref = arr.ndim() == 0 ? "&" : "*";
-    kernel_source << "  const " << location << " " << dtype << ref << " "
-                  << name << " [[buffer(" << index << ")]]," << std::endl;
+    kernel_source += "  const ";
+    kernel_source += location;
+    kernel_source += " ";
+    kernel_source += dtype;
+    kernel_source += ref;
+    kernel_source += " ";
+    kernel_source += name;
+    kernel_source += " [[buffer(";
+    kernel_source += std::to_string(index);
+    kernel_source += ")]],\n";
     index++;
     // Add input shape, strides and ndim if present in the source
-    CustomKernelShapeInfo shape_info;
     if (arr.ndim() > 0) {
-      if (source.find(name + "_shape") != std::string::npos) {
-        kernel_source << "  const constant int* " << name << "_shape [[buffer("
-                      << index << ")]]," << std::endl;
-        shape_info.shape = true;
+      if (shape_infos[i].shape) {
+        kernel_source +=
+            ("  const constant int* " + name + "_shape [[buffer(" +
+             std::to_string(index) + ")]],\n");
         index++;
       }
-      if (source.find(name + "_strides") != std::string::npos) {
-        kernel_source << "  const constant size_t* " << name
-                      << "_strides [[buffer(" << index << ")]]," << std::endl;
-        shape_info.strides = true;
+      if (shape_infos[i].strides) {
+        kernel_source +=
+            ("  const constant size_t* " + name + "_strides [[buffer(" +
+             std::to_string(index) + ")]],\n");
         index++;
       }
-      if (source.find(name + "_ndim") != std::string::npos) {
-        kernel_source << "  const constant int& " << name << "_ndim [[buffer("
-                      << index << ")]]," << std::endl;
-        shape_info.ndim = true;
+      if (shape_infos[i].ndim) {
+        kernel_source +=
+            ("  const constant int& " + name + "_ndim [[buffer(" +
+             std::to_string(index) + ")]],\n");
         index++;
       }
     }
-    shape_infos.push_back(shape_info);
   }
   // Add outputs
   for (int i = 0; i < output_names.size(); ++i) {
     const auto& name = output_names[i];
     const auto& dtype = output_dtypes[i];
-    kernel_source << "  device ";
+    kernel_source += "  device ";
     auto type_string = get_type_string(dtype);
     if (atomic_outputs) {
-      kernel_source << "atomic<" << type_string << ">";
-    } else {
-      kernel_source << type_string;
+      kernel_source += "atomic<";
     }
-    kernel_source << "* " << name << " [[buffer(" << index << ")]]";
-    if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) {
-      kernel_source << "," << std::endl;
+    kernel_source += type_string;
+    if (atomic_outputs) {
+      kernel_source += ">";
+    }
+    kernel_source += "* ";
+    kernel_source += name;
+    kernel_source += " [[buffer(";
+    kernel_source += std::to_string(index);
+    kernel_source += ")]]";
+    if (index < inputs.size() + output_names.size() - 1 ||
+        attributes.size() > 0) {
+      kernel_source += ",\n";
     } else {
-      kernel_source << ") {" << std::endl;
+      kernel_source += ") {\n";
    }
    index++;
  }
   // Add metal attributes e.g. `threadgroup_index_in_grid`

   index = 0;
-  for (const auto& [attr, dtype] : attrs) {
-    kernel_source << "  " << dtype << " " << attr << " [[" << attr << "]]";
-    if (index < attrs.size() - 1) {
-      kernel_source << "," << std::endl;
+  for (const auto& attr : attributes) {
+    kernel_source += attr;
+    if (index < attributes.size() - 1) {
+      kernel_source += ",\n";
     } else {
-      kernel_source << ") {" << std::endl;
+      kernel_source += ") {\n";
     }
     index++;
   }
-  kernel_source << source << std::endl;
-  kernel_source << "}" << std::endl;
+  kernel_source += source;
+  kernel_source += "\n}\n";
+  return kernel_source;
 }
 
 std::string write_template(
@@ -1087,8 +1078,48 @@ MetalKernelFunction metal_kernel(
     throw std::invalid_argument(
         "[metal_kernel] Must specify at least one output.");
   }
+  std::vector<CustomKernelShapeInfo> shape_infos;
+  for (auto& n : input_names) {
+    CustomKernelShapeInfo shape_info;
+    shape_info.shape = source.find(n + "_shape") != std::string::npos;
+    shape_info.strides = source.find(n + "_strides") != std::string::npos;
+    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
+    shape_infos.push_back(shape_info);
+  }
+  const std::vector<std::pair<std::string, std::string>> metal_attributes = {
+      {"dispatch_quadgroups_per_threadgroup", "uint"},
+      {"dispatch_simdgroups_per_threadgroup", "uint"},
+      {"dispatch_threads_per_threadgroup", "uint3"},
+      {"grid_origin", "uint3"},
+      {"grid_size", "uint3"},
+      {"quadgroup_index_in_threadgroup", "uint"},
+      {"quadgroups_per_threadgroup", "uint"},
+      {"simdgroup_index_in_threadgroup", "uint"},
+      {"simdgroups_per_threadgroup", "uint"},
+      {"thread_execution_width", "uint"},
+      {"thread_index_in_quadgroup", "uint"},
+      {"thread_index_in_simdgroup", "uint"},
+      {"thread_index_in_threadgroup", "uint"},
+      {"thread_position_in_grid", "uint3"},
+      {"thread_position_in_threadgroup", "uint3"},
+      {"threadgroup_position_in_grid", "uint3"},
+      {"threadgroups_per_grid", "uint3"},
+      {"threads_per_grid", "uint3"},
+      {"threads_per_simdgroup", "uint"},
+      {"threads_per_threadgroup", "uint3"},
+  };
+
+  std::vector<std::string> attributes;
+  for (const auto& [attr, dtype] : metal_attributes) {
+    if (source.find(attr) != std::string::npos) {
+      attributes.push_back("  " + dtype + " " + attr + " [[" + attr + "]]");
+    }
+  }
 
-  return [=](const std::vector<array>& inputs,
+  return [=,
+          shape_infos = std::move(shape_infos),
+          attributes = std::move(attributes)](
+             const std::vector<array>& inputs,
              const std::vector<std::vector<int>>& output_shapes,
              const std::vector<Dtype>& output_dtypes,
              std::tuple<int, int, int> grid,
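The hunk above moves all source-string scanning out of the returned callable: shape/strides/ndim usage and Metal attribute usage are detected once in `metal_kernel`, then transferred into the closure with init-captures (`shape_infos = std::move(shape_infos)`), so repeated kernel launches reuse them. Below is a generic sketch of that factory pattern with hypothetical names; it is not the MLX API.

```cpp
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical factory: the expensive scan of `source` happens once here,
// not on every invocation of the returned callable.
std::function<std::string(const std::string&)> make_builder(std::string source) {
  std::vector<std::string> attributes;
  if (source.find("thread_position_in_grid") != std::string::npos) {
    attributes.push_back(
        "  uint3 thread_position_in_grid [[thread_position_in_grid]]");
  }
  // Init-captures move the precomputed state into the closure, so each call
  // only assembles strings and never re-scans `source`.
  return [source = std::move(source), attributes = std::move(attributes)](
             const std::string& name) {
    std::string out;
    out.reserve(source.size() + 256);
    out += "[[kernel]] void " + name + "(\n";
    for (const auto& a : attributes) {
      out += a;
      out += "\n";
    }
    out += ") {\n";
    out += source;
    out += "\n}\n";
    return out;
  };
}
```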
@@ -1126,7 +1157,6 @@ MetalKernelFunction metal_kernel(
     }
 
     std::ostringstream func_name;
-
     std::string template_def = "";
     std::string hash_key = "";
     if (!template_args.empty()) {
@@ -1135,54 +1165,53 @@ MetalKernelFunction metal_kernel(
       hash_key = std::regex_replace(template_def, disallowed_chars, "_");
       hash_key.pop_back();
     }
 
     func_name << "custom_kernel_" << name << hash_key;
     std::string kernel_name = func_name.str();
 
-    std::ostringstream kernel_source;
-    kernel_source << header << std::endl;
-
-    std::vector<CustomKernelShapeInfo> shape_infos;
-    write_signature(
-        func_name.str(),
+    std::string kernel_source = write_signature(
+        kernel_name,
         header,
         source,
         input_names,
         inputs,
         output_names,
         output_dtypes,
         template_args,
+        attributes,
         shape_infos,
-        atomic_outputs,
-        kernel_source);
+        atomic_outputs);
 
     if (!template_args.empty()) {
-      template_def = func_name.str() + template_def;
-      kernel_source << std::endl
-                    << "template [[host_name(\"" << kernel_name
-                    << "\")]] [[kernel]] decltype(" << template_def << ") "
-                    << template_def << ";" << std::endl;
+      template_def = kernel_name + template_def;
+      kernel_source += "\ntemplate [[host_name(\"";
+      kernel_source += kernel_name;
+      kernel_source += "\")]] [[kernel]] decltype(";
+      kernel_source += template_def;
+      kernel_source += ") ";
+      kernel_source += template_def;
+      kernel_source += ";\n";
     }
 
     if (verbose) {
       std::cout << "Generated source code for `" << name << "`:" << std::endl
                 << "```" << std::endl
-                << kernel_source.str() << std::endl
+                << kernel_source << std::endl
                 << "```" << std::endl;
     }
 
     return array::make_arrays(
-        output_shapes,
-        output_dtypes,
+        std::move(output_shapes),
+        std::move(output_dtypes),
         std::make_shared<CustomKernel>(
             s,
-            kernel_name,
-            kernel_source.str(),
+            std::move(kernel_name),
+            std::move(kernel_source),
            grid,
            threadgroup,
            shape_infos,
            ensure_row_contiguous,
            init_value),
-        inputs);
+        std::move(inputs));
   };
 }
@@ -262,11 +262,11 @@ class CustomKernel : public Primitive {
       bool ensure_row_contiguous,
       std::optional<float> init_value)
       : Primitive(stream),
-        source_(source),
-        name_(name),
+        source_(std::move(source)),
+        name_(std::move(name)),
         grid_(grid),
         threadgroup_(threadgroup),
-        shape_infos_(shape_infos),
+        shape_infos_(std::move(shape_infos)),
         ensure_row_contiguous_(ensure_row_contiguous),
         init_value_(init_value) {}
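The constructor change above follows the usual sink-parameter idiom: `name`, `source`, and `shape_infos` are taken by value and moved into the members, so the `std::move(...)` arguments built up in `metal_kernel` transfer their buffers instead of being copied. A small sketch of the idiom with a made-up class, not the MLX `CustomKernel`:

```cpp
#include <string>
#include <utility>
#include <vector>

// Illustrative only.
class KernelDesc {
 public:
  // Sink parameters: take by value, then move into the members. Callers that
  // pass rvalues pay one move per argument instead of a deep copy.
  KernelDesc(std::string name, std::string source, std::vector<int> shape)
      : name_(std::move(name)),
        source_(std::move(source)),
        shape_(std::move(shape)) {}

 private:
  std::string name_;
  std::string source_;
  std::vector<int> shape_;
};

// Usage sketch: both string arguments are moved in, so their buffers are
// transferred rather than duplicated.
// KernelDesc d(std::move(kernel_name), std::move(kernel_source), {4, 4});
```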