2024-11-22 11:53:00 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
2024-05-23 03:57:13 +08:00
|
|
|
|
|
|
|
#include "mlx/backend/metal/kernels.h"
|
|
|
|
#include "mlx/backend/metal/utils.h"
|
|
|
|
#include "mlx/primitives.h"
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
|
2024-05-24 07:23:44 +08:00
|
|
|
// Fetches the arange compute pipeline by name through Device::get_kernel.
// The array argument is accepted but not read by this implementation.
MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2024-05-23 03:57:13 +08:00
|
|
|
// Fetches the unary-op compute pipeline by name through Device::get_kernel.
// The two Dtype arguments and the trailing string are accepted but not read
// by this implementation.
MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype /* unused */,
    Dtype /* unused */,
    const std::string /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the binary-op compute pipeline by name through Device::get_kernel.
// The two Dtype arguments and the trailing string are accepted but not read
// by this implementation.
MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype /* unused */,
    Dtype /* unused */,
    const std::string /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the two-output binary-op compute pipeline by name through
// Device::get_kernel. The two Dtype arguments and the trailing string are
// accepted but not read by this implementation.
MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype /* unused */,
    Dtype /* unused */,
    const std::string /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the ternary-op compute pipeline by name through Device::get_kernel.
// The Dtype argument and the trailing string are accepted but not read by
// this implementation.
MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    Dtype /* unused */,
    const std::string /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the copy compute pipeline by name through Device::get_kernel.
// Both array arguments are accepted but not read by this implementation.
MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const array& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2025-01-08 06:02:16 +08:00
|
|
|
// Fetches the dynamic-copy compute pipeline by name through
// Device::get_kernel. Both array arguments are accepted but not read by this
// implementation.
MTL::ComputePipelineState* get_dynamic_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const array& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2024-05-24 07:23:44 +08:00
|
|
|
// Fetches the softmax compute pipeline by name through Device::get_kernel.
// The bool flag and the array argument are accepted but not read by this
// implementation.
MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool /* unused */,
    const array& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2025-03-31 22:36:55 +08:00
|
|
|
// Fetches the logsumexp compute pipeline by name through Device::get_kernel.
// The array argument is accepted but not read by this implementation.
MTL::ComputePipelineState* get_logsumexp_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2024-05-24 07:23:44 +08:00
|
|
|
// Fetches the scan compute pipeline by name through Device::get_kernel.
// The two bool flags, the op-name string, and both arrays are accepted but
// not read by this implementation.
MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool /* unused */,
    bool /* unused */,
    const std::string& /* unused */,
    const array& /* unused */,
    const array& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the sort compute pipeline by name through Device::get_kernel.
// The two arrays and the two int arguments are accepted but not read by this
// implementation.
MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const array& /* unused */,
    int /* unused */,
    int /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the multi-block sort compute pipeline by name through
// Device::get_kernel. The two arrays and the two int arguments are accepted
// but not read by this implementation.
MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const array& /* unused */,
    int /* unused */,
    int /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the reduce-init compute pipeline by name through
// Device::get_kernel. The two strings and the Dtype argument are accepted but
// not read by this implementation.
MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& /* unused */,
    const std::string& /* unused */,
    const Dtype& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the reduce compute pipeline by name through Device::get_kernel.
// The string, Dtype, and int arguments after kernel_name are accepted but not
// read by this implementation.
MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& /* unused */,
    const std::string& /* unused */,
    const Dtype& /* unused */,
    const Dtype& /* unused */,
    const std::string& /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2024-05-24 09:07:34 +08:00
|
|
|
// Fetches the fused steel GEMM compute pipeline through Device::get_kernel,
// forwarding hash_name and the function-constant list func_consts alongside
// the kernel name. The array, bool, and int arguments after func_consts are
// accepted but not read by this implementation.
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& /* unused */,
    bool /* unused */,
    bool /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name, hash_name, func_consts);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the split-K steel GEMM compute pipeline by name through
// Device::get_kernel. The array, bool, and int arguments after kernel_name
// are accepted but not read by this implementation.
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const array& /* unused */,
    bool /* unused */,
    bool /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    bool /* unused */,
    bool /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the split-K accumulation compute pipeline by name through
// Device::get_kernel. The two arrays and the bool flag are accepted but not
// read by this implementation.
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const array& /* unused */,
    bool /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the masked steel GEMM compute pipeline by name through
// Device::get_kernel. The array, the optional output/op masks, and the bool
// and int arguments are accepted but not read by this implementation.
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const std::optional<array>& /* mask_out, unused */,
    const std::optional<array>& /* mask_op, unused */,
    bool /* unused */,
    bool /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    bool /* unused */,
    bool /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2025-04-15 07:37:36 +08:00
|
|
|
// Fetches the gather steel GEMM compute pipeline through Device::get_kernel,
// forwarding hash_name and the function-constant list func_consts alongside
// the kernel name. The array, bool, and int arguments after func_consts are
// accepted but not read by this implementation.
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& /* unused */,
    bool /* unused */,
    bool /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    bool /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name, hash_name, func_consts);
  return pipeline;
}
|
|
|
|
|
2024-08-08 04:38:07 +08:00
|
|
|
// Fetches the masked GEMV compute pipeline by name through
// Device::get_kernel. The array, the two optional arrays, and the bool and
// int arguments are accepted but not read by this implementation.
MTL::ComputePipelineState* get_gemv_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    const std::optional<array>& /* unused */,
    const std::optional<array>& /* unused */,
    bool /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    bool /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2024-05-24 09:07:34 +08:00
|
|
|
// Fetches the steel convolution compute pipeline by name through
// Device::get_kernel. The array, int, and bool arguments after kernel_name
// are accepted but not read by this implementation.
MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    bool /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
|
|
|
// Fetches the general steel convolution compute pipeline through
// Device::get_kernel, forwarding hash_name and the function-constant list
// func_consts alongside the kernel name. The array and int arguments after
// func_consts are accepted but not read by this implementation.
MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name, hash_name, func_consts);
  return pipeline;
}
|
|
|
|
|
2024-06-07 03:57:25 +08:00
|
|
|
// Fetches the FFT compute pipeline through Device::get_kernel, forwarding
// hash_name and the function-constant list func_consts alongside the kernel
// name. The trailing string argument is accepted but not read by this
// implementation.
MTL::ComputePipelineState* get_fft_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const std::string& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name, hash_name, func_consts);
  return pipeline;
}
|
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
// Fetches the quantized-op compute pipeline by name through
// Device::get_kernel. The trailing string argument is accepted but not read
// by this implementation.
MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name);
  return pipeline;
}
|
|
|
|
|
2025-04-18 04:53:11 +08:00
|
|
|
// Fetches the gather quantized-matmul compute pipeline through
// Device::get_kernel, forwarding hash_name and the function-constant list
// func_consts alongside the kernel name. The array, int, and bool arguments
// after func_consts are accepted but not read by this implementation.
MTL::ComputePipelineState* get_gather_qmm_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    int /* unused */,
    bool /* unused */) {
  auto* pipeline = d.get_kernel(kernel_name, hash_name, func_consts);
  return pipeline;
}
|
|
|
|
|
2025-05-31 21:16:14 +08:00
|
|
|
// Fetches the paged-attention compute pipeline through Device::get_kernel.
// hash_name and the function-constant list func_consts are forwarded in the
// same (kernel_name, hash_name, func_consts) form used by the other
// function-constant lookups in this file (get_steel_gemm_fused_kernel,
// get_steel_gemm_gather_kernel, get_steel_conv_general_kernel, get_fft_kernel,
// get_gather_qmm_kernel). The trailing string argument is accepted but not
// read by this implementation.
//
// NOTE(review): this call used to pass an extra "mlx" library-name argument
// as the second parameter, i.e. get_kernel(kernel_name, "mlx", hash_name,
// func_consts). Every sibling lookup has dropped that argument. Against the
// current overload set the old form shifts hash_name/func_consts into the
// wrong parameters, so it is aligned here. Confirm against
// Device::get_kernel in device.h.
MTL::ComputePipelineState* get_paged_attention_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const std::string& /* unused */) {
  return d.get_kernel(kernel_name, hash_name, func_consts);
}
|
|
|
|
|
2024-05-23 03:57:13 +08:00
|
|
|
} // namespace mlx::core
|