2024-05-22 12:57:13 -07:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
|
2025-02-06 11:11:22 -08:00
|
|
|
#pragma once
|
|
|
|
|
2024-06-12 09:47:12 -07:00
|
|
|
#include <locale>
#include <sstream>
#include <string>
#include <string_view>

#include <fmt/format.h>

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
2024-05-23 16:23:44 -07:00
|
|
|
// Returns the compute pipeline for the arange kernel named `kernel_name` on
// device `d`.
//   out: destination array of the arange op. NOTE(review): presumably only
//        its dtype is used to select/specialize the kernel — confirm in the
//        definition (not visible in this header).
// Ownership/lifetime of the returned pointer is not established here;
// presumably it is owned by `d` — confirm before releasing it.
MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);
|
|
|
|
|
2024-05-22 12:57:13 -07:00
|
|
|
// Element-wise kernel lookups. Each returns the compute pipeline for the
// kernel named `kernel_name` on device `d`.
//
// `op` is the operation name as a C string. NOTE(review): presumably it names
// the Metal functor used to instantiate the kernel when it has to be built —
// confirm in the definition.

// Unary op: one input element type (`in_type`) and one output element type
// (`out_type`).
MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const char* op);

// Binary op: both inputs share `in_type`; the result has `out_type`.
MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const char* op);

// Same signature as get_binary_kernel. NOTE(review): the name suggests a
// binary op producing two outputs (e.g. divmod) — confirm in the definition.
MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype in_type,
    Dtype out_type,
    const char* op);

// Ternary op: unlike the unary/binary variants, only a single element type
// (`type`) is passed.
MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype type,
    const char* op);
|
2024-05-22 12:57:13 -07:00
|
|
|
|
|
|
|
// Returns the compute pipeline for the copy kernel `kernel_name` on device `d`.
// `in` and `out` are the source and destination arrays. NOTE(review):
// presumably their dtypes select the kernel specialization — confirm.
MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

// Same signature as get_copy_kernel, for the "dynamic" copy variant.
// NOTE(review): presumably the variant whose offsets are supplied at run time
// rather than baked into the launch — confirm in the definition.
MTL::ComputePipelineState* get_dynamic_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);
|
|
|
|
|
2024-05-23 16:23:44 -07:00
|
|
|
// Returns the compute pipeline for the softmax kernel `kernel_name` on `d`.
//   precise: boolean kernel option. NOTE(review): presumably selects
//            higher-precision accumulation — confirm in the definition.
//   out:     destination array.
MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool precise,
    const array& out);

// Returns the compute pipeline for the logsumexp kernel `kernel_name` on `d`;
// `out` is the destination array.
MTL::ComputePipelineState* get_logsumexp_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);

// Returns the compute pipeline for a scan kernel `kernel_name` on `d`.
//   reverse:     scan direction flag.
//   inclusive:   inclusive vs. exclusive scan flag.
//   reduce_type: name of the scan's reduction as a string (NOTE(review):
//                presumably e.g. sum/prod/max/min — confirm).
//   in, out:     source and destination arrays.
MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool reverse,
    bool inclusive,
    const std::string& reduce_type,
    const array& in,
    const array& out);
|
|
|
|
|
|
|
|
// Returns the compute pipeline for the sort kernel `kernel_name` on `d`.
//   in, out: source and destination arrays.
//   bn, tn:  integer tiling parameters. NOTE(review): presumably threads per
//            block and elements per thread — confirm in the definition.
MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    int bn,
    int tn);

// Returns the compute pipeline for the "mb" sort kernel `kernel_name` on `d`.
// NOTE(review): "mb" presumably means multi-block — confirm. Unlike
// get_sort_kernel, it takes an index array `idx` instead of an output array.
//   bn, tn: integer tiling parameters, as in get_sort_kernel.
MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& idx,
    int bn,
    int tn);
|
|
|
|
|
|
|
|
// Returns the compute pipeline for a reduction "init" kernel `kernel_name` on
// `d`.
//   func_name: name of the kernel template function to instantiate.
//   op_name:   name of the reduction op.
//   out_type:  output element type.
// NOTE: `out_type` is taken by const reference here, while most helpers in
// this header take Dtype by value. The declaration must keep matching the
// definition exactly, so do not change one without the other.
MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& func_name,
    const std::string& op_name,
    const Dtype& out_type);

// Returns the compute pipeline for a reduction kernel `kernel_name` on `d`.
//   func_name:          kernel template function to instantiate.
//   op_name:            name of the reduction op.
//   in_type, out_type:  input and output element types.
//   idx_t:              index type name as a string (NOTE(review): presumably a
//                       Metal type name such as int32_t/int64_t — confirm).
//   ndim, bm, bn:       optional specialization parameters, defaulting to -1.
//                       NOTE(review): presumably -1 means "not part of this
//                       specialization" — confirm in the definition.
MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& func_name,
    const std::string& op_name,
    const Dtype& in_type,
    const Dtype& out_type,
    const std::string& idx_t,
    int ndim = -1,
    int bm = -1,
    int bn = -1);
|
2024-05-23 16:23:44 -07:00
|
|
|
|
2024-05-23 18:07:34 -07:00
|
|
|
// Steel GEMM pipeline lookups. Each returns the compute pipeline for kernel
// `kernel_name` on device `d`.
//
// Common parameters:
//   transpose_a, transpose_b: transposition flags for the two operands.
//   bm, bn, bk, wm, wn: integer tiling parameters. NOTE(review): presumably
//       block tile sizes along M/N/K and the simdgroup grid — confirm.

// Fused GEMM. `hash_name` and `func_consts` accompany the kernel name.
// NOTE(review): presumably `hash_name` keys the pipeline specialized with
// the Metal function constants `func_consts` — confirm in the definition.
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);

// Split-K GEMM. `mn_aligned` / `k_aligned` are boolean alignment flags
// (NOTE(review): presumably whether M/N and K are multiples of the tile
// sizes — confirm).
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned);

// Accumulation step of split-K GEMM. NOTE(review): `axbpy` presumably selects
// an alpha*x + beta*y epilogue — confirm in the definition.
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool axbpy);
|
|
|
|
|
|
|
|
// Steel GEMM pipeline lookups, continued. Each returns the compute pipeline
// for kernel `kernel_name` on device `d`.
//
// Common parameters:
//   transpose_a, transpose_b: transposition flags for the two operands.
//   bm, bn, bk, wm, wn: integer tiling parameters. NOTE(review): presumably
//       block tile sizes along M/N/K and the simdgroup grid — confirm.

// Block-masked GEMM. Either mask (`mask_out`, `mask_op`) may be absent
// (std::nullopt). `mn_aligned` / `k_aligned` are boolean alignment flags.
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional<array>& mask_out,
    const std::optional<array>& mask_op,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned);

// Gather GEMM, specialized by `hash_name` / `func_consts` (Metal function
// constants). NOTE(review): `rhs` presumably selects whether the gather
// applies to the right-hand operand — confirm in the definition.
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool rhs);

// Segmented GEMM, specialized by `hash_name` / `func_consts`.
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);
|
|
|
|
|
2024-05-23 18:07:34 -07:00
|
|
|
// Returns the compute pipeline for the steel convolution kernel `kernel_name`
// on `d`.
//   bm, bn, bk, wm, wn: integer tiling parameters (NOTE(review): presumably as
//       in the steel GEMM kernels — confirm).
//   n_channel_specialization: integer channel specialization value.
//   small_filter: boolean flag for the small-filter variant.
MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    int n_channel_specialization,
    bool small_filter);

// Returns the compute pipeline for the block-masked GEMV kernel `kernel_name`
// on `d`.
//   mask_out, mask_op: optional masks; either may be std::nullopt.
//   transpose_mat:     transposition flag for the matrix operand.
//   bm, bn, sm, sn, tm, tn: integer tiling parameters (NOTE(review): meaning
//       of each level not visible here — confirm in the definition).
//   contiguous:        boolean flag forwarded to the specialization.
MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional<array>& mask_out,
    const std::optional<array>& mask_op,
    bool transpose_mat,
    int bm,
    int bn,
    int sm,
    int sn,
    int tm,
    int tn,
    bool contiguous);

// Returns the compute pipeline for the general steel convolution kernel
// `kernel_name` on `d`, specialized by `hash_name` / `func_consts` (Metal
// function constants).
//   bm, bn, bk, wm, wn: integer tiling parameters.
MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);
|
|
|
|
|
2024-06-06 12:57:25 -07:00
|
|
|
// Returns the compute pipeline for the FFT kernel `kernel_name` on `d`,
// specialized by `hash_name` / `func_consts` (Metal function constants).
//   template_def: kernel template instantiation source. NOTE(review):
//       presumably produced by get_template_definition in this header and
//       used when the kernel must be JIT-compiled — confirm.
MTL::ComputePipelineState* get_fft_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const std::string& template_def);

// Returns the compute pipeline for the quantized kernel `kernel_name` on `d`.
//   template_def: kernel template instantiation source (see get_fft_kernel).
MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& template_def);
|
|
|
|
|
2025-04-17 13:53:11 -07:00
|
|
|
// Returns the compute pipeline for the gather quantized-matmul kernel
// `kernel_name` on `d`, specialized by `hash_name` / `func_consts` (Metal
// function constants).
//   x:                  input array (NOTE(review): presumably its dtype selects
//                       the specialization — confirm).
//   group_size, bits:   quantization parameters.
//   bm, bn, bk, wm, wn: integer tiling parameters.
//   transpose:          boolean transposition flag.
MTL::ComputePipelineState* get_gather_qmm_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& x,
    int group_size,
    int bits,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool transpose);
|
|
|
|
|
2024-06-12 09:47:12 -07:00
|
|
|
// Create a GPU kernel template definition for JIT compilation.
//
// Returns the explicit instantiation
//
//   \ntemplate [[host_name("<name>")]] [[kernel]]
//       decltype(<func><<args...>>) <func><<args...>>;\n
//
// (on one line), which exports the instantiation of `func` under the
// host-visible name `name`.
//
//   name: host name under which the instantiated kernel is exported.
//   func: name of the templated kernel function.
//   args: template arguments. Each is written with operator<< and they are
//         joined with ", " (bools therefore render as 1/0).
//
// The arguments are formatted in the classic "C" locale. A plain
// std::ostringstream uses the global C++ locale, so an application that
// installed e.g. a locale with digit grouping would turn 1024 into "1,024"
// and produce invalid kernel source.
template <typename... Args>
std::string get_template_definition(
    std::string_view name,
    std::string_view func,
    const Args&... args) {
  std::ostringstream s;
  // Make numeric formatting independent of std::locale::global().
  s.imbue(std::locale::classic());
  s << func << "<";
  bool first = true;
  auto add_arg = [&s, &first](const auto& arg) {
    if (!first) {
      s << ", ";
    }
    first = false;
    s << arg;
  };
  (add_arg(args), ...);
  s << ">";
  const std::string instantiation = s.str();

  std::string out;
  out.reserve(name.size() + 2 * instantiation.size() + 64);
  out += "\ntemplate [[host_name(\"";
  out += name;
  out += "\")]] [[kernel]] decltype(";
  out += instantiation;
  out += ") ";
  out += instantiation;
  out += ";\n";
  return out;
}
|
2024-06-06 12:57:25 -07:00
|
|
|
|
2024-05-22 12:57:13 -07:00
|
|
|
} // namespace mlx::core
|