mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add grid_sample example to metal_kernel docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel` * add grid sample to docs * zero_outputs -> init_value * add missing header for linux
This commit is contained in:
16
mlx/fast.cpp
16
mlx/fast.cpp
@@ -949,6 +949,7 @@ void write_signature(
|
||||
std::map<std::string, Dtype>& output_dtypes,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs,
|
||||
std::ostringstream& kernel_source) {
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
@@ -1042,8 +1043,14 @@ void write_signature(
|
||||
}
|
||||
// Add outputs
|
||||
for (const auto& [name, dtype] : output_dtypes) {
|
||||
kernel_source << " device " << get_type_string(dtype) << "* " << name
|
||||
<< " [[buffer(" << index << ")]]";
|
||||
kernel_source << " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source << "atomic<" << type_string << ">";
|
||||
} else {
|
||||
kernel_source << type_string;
|
||||
}
|
||||
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
|
||||
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
|
||||
kernel_source << "," << std::endl;
|
||||
} else {
|
||||
@@ -1094,6 +1101,7 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::optional<float> init_value,
|
||||
bool verbose,
|
||||
StreamOrDevice s_) {
|
||||
validate_output_shapes(output_shapes, output_dtypes);
|
||||
@@ -1129,6 +1137,7 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
output_dtypes,
|
||||
template_args,
|
||||
shape_infos,
|
||||
atomic_outputs_,
|
||||
kernel_source);
|
||||
|
||||
if (needs_template) {
|
||||
@@ -1174,7 +1183,8 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous_),
|
||||
ensure_row_contiguous_,
|
||||
init_value),
|
||||
in_arrs);
|
||||
|
||||
int i = 0;
|
||||
|
||||
Reference in New Issue
Block a user