mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Add grid_sample
example to metal_kernel
docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel` * add grid sample to docs * zero_outputs -> init_value * add missing header for linux
This commit is contained in:
@@ -12,12 +12,17 @@ void CustomKernel::eval_gpu(
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
for (auto& out : outputs) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (init_value_) {
|
||||
array init = array(init_value_.value(), out.dtype());
|
||||
copy_gpu(init, out, CopyType::Scalar, s);
|
||||
copies.push_back(init);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
auto check_input = [&copies, &s, this](const array& x) -> const array {
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (!ensure_row_contiguous_ || no_copy) {
|
||||
|
16
mlx/fast.cpp
16
mlx/fast.cpp
@@ -949,6 +949,7 @@ void write_signature(
|
||||
std::map<std::string, Dtype>& output_dtypes,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs,
|
||||
std::ostringstream& kernel_source) {
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
@@ -1042,8 +1043,14 @@ void write_signature(
|
||||
}
|
||||
// Add outputs
|
||||
for (const auto& [name, dtype] : output_dtypes) {
|
||||
kernel_source << " device " << get_type_string(dtype) << "* " << name
|
||||
<< " [[buffer(" << index << ")]]";
|
||||
kernel_source << " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source << "atomic<" << type_string << ">";
|
||||
} else {
|
||||
kernel_source << type_string;
|
||||
}
|
||||
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
|
||||
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
|
||||
kernel_source << "," << std::endl;
|
||||
} else {
|
||||
@@ -1094,6 +1101,7 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::optional<float> init_value,
|
||||
bool verbose,
|
||||
StreamOrDevice s_) {
|
||||
validate_output_shapes(output_shapes, output_dtypes);
|
||||
@@ -1129,6 +1137,7 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
output_dtypes,
|
||||
template_args,
|
||||
shape_infos,
|
||||
atomic_outputs_,
|
||||
kernel_source);
|
||||
|
||||
if (needs_template) {
|
||||
@@ -1174,7 +1183,8 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous_),
|
||||
ensure_row_contiguous_,
|
||||
init_value),
|
||||
in_arrs);
|
||||
|
||||
int i = 0;
|
||||
|
@@ -71,10 +71,12 @@ class MetalKernel {
|
||||
MetalKernel(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool ensure_row_contiguous)
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs)
|
||||
: name_(name),
|
||||
source_(source),
|
||||
ensure_row_contiguous_(ensure_row_contiguous) {}
|
||||
ensure_row_contiguous_(ensure_row_contiguous),
|
||||
atomic_outputs_(atomic_outputs) {}
|
||||
|
||||
std::map<std::string, array> operator()(
|
||||
std::map<std::string, array>& inputs,
|
||||
@@ -84,6 +86,7 @@ class MetalKernel {
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args =
|
||||
std::nullopt,
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
@@ -91,5 +94,6 @@ class MetalKernel {
|
||||
std::string name_;
|
||||
std::string source_;
|
||||
bool ensure_row_contiguous_ = true;
|
||||
bool atomic_outputs_ = false;
|
||||
};
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -1,5 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
@@ -257,14 +259,16 @@ class CustomKernel : public Primitive {
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::vector<CustomKernelShapeInfo> shape_infos,
|
||||
bool ensure_row_contiguous)
|
||||
bool ensure_row_contiguous,
|
||||
std::optional<float> init_value)
|
||||
: Primitive(stream),
|
||||
source_(source),
|
||||
name_(name),
|
||||
grid_(grid),
|
||||
threadgroup_(threadgroup),
|
||||
shape_infos_(shape_infos),
|
||||
ensure_row_contiguous_(ensure_row_contiguous) {}
|
||||
ensure_row_contiguous_(ensure_row_contiguous),
|
||||
init_value_(init_value) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@@ -283,6 +287,7 @@ class CustomKernel : public Primitive {
|
||||
std::tuple<int, int, int> threadgroup_;
|
||||
std::vector<CustomKernelShapeInfo> shape_infos_;
|
||||
bool ensure_row_contiguous_;
|
||||
std::optional<float> init_value_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
Reference in New Issue
Block a user