Add grid_sample example to metal_kernel docs (#1352)

* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* Add `grid_sample` example to docs

* Rename `zero_outputs` -> `init_value`

* Add missing header for Linux
This commit is contained in:
Alex Barron
2024-08-23 18:24:16 -07:00
committed by GitHub
parent 3b4d5484c7
commit b96e105244
6 changed files with 337 additions and 15 deletions

View File

@@ -949,6 +949,7 @@ void write_signature(
std::map<std::string, Dtype>& output_dtypes,
std::optional<std::map<std::string, TemplateArg>> template_args,
std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs,
std::ostringstream& kernel_source) {
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
@@ -1042,8 +1043,14 @@ void write_signature(
}
// Add outputs
for (const auto& [name, dtype] : output_dtypes) {
kernel_source << " device " << get_type_string(dtype) << "* " << name
<< " [[buffer(" << index << ")]]";
kernel_source << " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source << "atomic<" << type_string << ">";
} else {
kernel_source << type_string;
}
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
kernel_source << "," << std::endl;
} else {
@@ -1094,6 +1101,7 @@ std::map<std::string, array> MetalKernel::operator()(
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s_) {
validate_output_shapes(output_shapes, output_dtypes);
@@ -1129,6 +1137,7 @@ std::map<std::string, array> MetalKernel::operator()(
output_dtypes,
template_args,
shape_infos,
atomic_outputs_,
kernel_source);
if (needs_template) {
@@ -1174,7 +1183,8 @@ std::map<std::string, array> MetalKernel::operator()(
grid,
threadgroup,
shape_infos,
ensure_row_contiguous_),
ensure_row_contiguous_,
init_value),
in_arrs);
int i = 0;