Add grid_sample example to metal_kernel docs (#1352)

* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* add grid sample to docs

* zero_outputs -> init_value

* add missing header for linux
This commit is contained in:
Alex Barron
2024-08-23 18:24:16 -07:00
committed by GitHub
parent 3b4d5484c7
commit b96e105244
6 changed files with 337 additions and 15 deletions

View File

@@ -12,12 +12,17 @@ void CustomKernel::eval_gpu(
     std::vector<array>& outputs) {
   auto& s = stream();
+  std::vector<array> copies;
   for (auto& out : outputs) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    if (init_value_) {
+      array init = array(init_value_.value(), out.dtype());
+      copy_gpu(init, out, CopyType::Scalar, s);
+      copies.push_back(init);
+    }
   }
-  std::vector<array> copies;
   auto check_input = [&copies, &s, this](const array& x) -> const array {
     bool no_copy = x.flags().row_contiguous;
     if (!ensure_row_contiguous_ || no_copy) {

View File

@@ -949,6 +949,7 @@ void write_signature(
     std::map<std::string, Dtype>& output_dtypes,
     std::optional<std::map<std::string, TemplateArg>> template_args,
     std::vector<CustomKernelShapeInfo>& shape_infos,
+    bool atomic_outputs,
    std::ostringstream& kernel_source) {
   // Auto-generate a function signature based on `template_args`
   // and the dtype/shape of the arrays passed as `inputs`.
@@ -1042,8 +1043,14 @@ void write_signature(
   }
   // Add outputs
   for (const auto& [name, dtype] : output_dtypes) {
-    kernel_source << " device " << get_type_string(dtype) << "* " << name
-                  << " [[buffer(" << index << ")]]";
+    kernel_source << " device ";
+    auto type_string = get_type_string(dtype);
+    if (atomic_outputs) {
+      kernel_source << "atomic<" << type_string << ">";
+    } else {
+      kernel_source << type_string;
+    }
+    kernel_source << "* " << name << " [[buffer(" << index << ")]]";
     if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
       kernel_source << "," << std::endl;
     } else {
@@ -1094,6 +1101,7 @@ std::map<std::string, array> MetalKernel::operator()(
     std::tuple<int, int, int> grid,
     std::tuple<int, int, int> threadgroup,
     std::optional<std::map<std::string, TemplateArg>> template_args,
+    std::optional<float> init_value,
     bool verbose,
     StreamOrDevice s_) {
   validate_output_shapes(output_shapes, output_dtypes);
@@ -1129,6 +1137,7 @@ std::map<std::string, array> MetalKernel::operator()(
       output_dtypes,
       template_args,
       shape_infos,
+      atomic_outputs_,
       kernel_source);
   if (needs_template) {
@@ -1174,7 +1183,8 @@ std::map<std::string, array> MetalKernel::operator()(
           grid,
           threadgroup,
           shape_infos,
-          ensure_row_contiguous_),
+          ensure_row_contiguous_,
+          init_value),
       in_arrs);
   int i = 0;

View File

@@ -71,10 +71,12 @@ class MetalKernel {
   MetalKernel(
       const std::string& name,
       const std::string& source,
-      bool ensure_row_contiguous)
+      bool ensure_row_contiguous,
+      bool atomic_outputs)
       : name_(name),
         source_(source),
-        ensure_row_contiguous_(ensure_row_contiguous) {}
+        ensure_row_contiguous_(ensure_row_contiguous),
+        atomic_outputs_(atomic_outputs) {}

   std::map<std::string, array> operator()(
       std::map<std::string, array>& inputs,
@@ -84,6 +86,7 @@ class MetalKernel {
       std::tuple<int, int, int> threadgroup,
       std::optional<std::map<std::string, TemplateArg>> template_args =
          std::nullopt,
+      std::optional<float> init_value = std::nullopt,
       bool verbose = false,
       StreamOrDevice s = {});
@@ -91,5 +94,6 @@ class MetalKernel {
   std::string name_;
   std::string source_;
   bool ensure_row_contiguous_ = true;
+  bool atomic_outputs_ = false;
 };

 } // namespace mlx::core::fast

View File

@@ -1,5 +1,7 @@
 // Copyright © 2024 Apple Inc.

+#include <optional>
+
 #include "mlx/primitives.h"

 namespace mlx::core::fast {
@@ -257,14 +259,16 @@ class CustomKernel : public Primitive {
       std::tuple<int, int, int> grid,
       std::tuple<int, int, int> threadgroup,
       std::vector<CustomKernelShapeInfo> shape_infos,
-      bool ensure_row_contiguous)
+      bool ensure_row_contiguous,
+      std::optional<float> init_value)
       : Primitive(stream),
         source_(source),
         name_(name),
         grid_(grid),
         threadgroup_(threadgroup),
         shape_infos_(shape_infos),
-        ensure_row_contiguous_(ensure_row_contiguous) {}
+        ensure_row_contiguous_(ensure_row_contiguous),
+        init_value_(init_value) {}

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -283,6 +287,7 @@ class CustomKernel : public Primitive {
   std::tuple<int, int, int> threadgroup_;
   std::vector<CustomKernelShapeInfo> shape_infos_;
   bool ensure_row_contiguous_;
+  std::optional<float> init_value_;
 };

 } // namespace mlx::core::fast