Some overhead reductions in mx.fast.metal_kernel (#1437)

* some overhead reductions

* fix

* use +=

* use more +=
Awni Hannun, 2024-09-25 17:25:21 -07:00 (committed by GitHub)
parent 4f9f9ebb6f
commit 0b4a58699e
3 changed files with 130 additions and 101 deletions


@@ -49,7 +49,7 @@ void CustomKernel::eval_gpu(
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto shape_info = shape_infos_[i];
auto& shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index);
index++;
if (in.ndim() > 0) {
@@ -68,7 +68,7 @@ void CustomKernel::eval_gpu(
}
}
}
for (array out : outputs) {
for (auto& out : outputs) {
compute_encoder.set_output_array(out, index);
index++;
}


@@ -916,21 +916,25 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
void write_signature(
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs,
std::ostringstream& kernel_source) {
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source << "template <";
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
@@ -942,44 +946,18 @@ void write_signature(
param_type = "typename";
}
if (i > 0) {
kernel_source << ", ";
kernel_source += ", ";
}
kernel_source << param_type << " " << name;
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source << ">" << std::endl;
}
kernel_source << "[[kernel]] void " << func_name << "(" << std::endl;
// Metal attributes are automatically added to the arguments if present
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::pair<std::string, std::string>> attrs;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attrs.push_back({attr, dtype});
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
@@ -988,69 +966,82 @@ void write_signature(
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
bool is_constant =
arr.is_available() && arr.size() < max_constant_array_size;
std::string location = is_constant ? "constant" : "device";
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source << " const " << location << " " << dtype << ref << " "
<< name << " [[buffer(" << index << ")]]," << std::endl;
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
CustomKernelShapeInfo shape_info;
if (arr.ndim() > 0) {
if (source.find(name + "_shape") != std::string::npos) {
kernel_source << " const constant int* " << name << "_shape [[buffer("
<< index << ")]]," << std::endl;
shape_info.shape = true;
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (source.find(name + "_strides") != std::string::npos) {
kernel_source << " const constant size_t* " << name
<< "_strides [[buffer(" << index << ")]]," << std::endl;
shape_info.strides = true;
if (shape_infos[i].strides) {
kernel_source +=
(" const constant size_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (source.find(name + "_ndim") != std::string::npos) {
kernel_source << " const constant int& " << name << "_ndim [[buffer("
<< index << ")]]," << std::endl;
shape_info.ndim = true;
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
shape_infos.push_back(shape_info);
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source << " device ";
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source << "atomic<" << type_string << ">";
} else {
kernel_source << type_string;
kernel_source += "atomic<";
}
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) {
kernel_source << "," << std::endl;
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source << ") {" << std::endl;
kernel_source += ") {\n";
}
index++;
}
// Add metal attributes e.g. `threadgroup_index_in_grid`
index = 0;
for (const auto& [attr, dtype] : attrs) {
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
if (index < attrs.size() - 1) {
kernel_source << "," << std::endl;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source << ") {" << std::endl;
kernel_source += ") {\n";
}
index++;
}
kernel_source << source << std::endl;
kernel_source << "}" << std::endl;
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
@@ -1087,8 +1078,48 @@ MetalKernelFunction metal_kernel(
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
return [=](const std::vector<array>& inputs,
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
@@ -1126,7 +1157,6 @@ MetalKernelFunction metal_kernel(
}
std::ostringstream func_name;
std::string template_def = "";
std::string hash_key = "";
if (!template_args.empty()) {
@@ -1135,54 +1165,53 @@ MetalKernelFunction metal_kernel(
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
func_name << "custom_kernel_" << name << hash_key;
std::string kernel_name = func_name.str();
std::ostringstream kernel_source;
kernel_source << header << std::endl;
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs,
kernel_source);
atomic_outputs);
if (!template_args.empty()) {
template_def = func_name.str() + template_def;
kernel_source << std::endl
<< "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source.str() << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
output_shapes,
output_dtypes,
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
kernel_name,
kernel_source.str(),
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
inputs);
std::move(inputs));
};
}


@@ -262,11 +262,11 @@ class CustomKernel : public Primitive {
bool ensure_row_contiguous,
std::optional<float> init_value)
: Primitive(stream),
source_(source),
name_(name),
source_(std::move(source)),
name_(std::move(name)),
grid_(grid),
threadgroup_(threadgroup),
shape_infos_(shape_infos),
shape_infos_(std::move(shape_infos)),
ensure_row_contiguous_(ensure_row_contiguous),
init_value_(init_value) {}