MLX
Loading...
Searching...
No Matches
fft.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Text template declaring one explicit instantiation of the `fft` kernel
// template. The braces-delimited fields `{name}`, `{tg_mem_size}`, `{in_T}`
// and `{out_T}` are placeholders left in the string; presumably a string
// formatter substitutes them before the text is compiled (TODO confirm
// against the code that consumes this constant).
//   {name}                      - used as the `[[host_name(...)]]` of the instantiation
//   {tg_mem_size}, {in_T}, {out_T} - template arguments passed to `fft<...>`
// Declared kernel parameters: `in` at buffer 0, `out` at buffer 1, then the
// `n` and `batch_size` constants (no explicit buffer index is written for
// them here), followed by the thread position and grid size.
// NOTE(review): the address-space qualifiers and attributes look like Metal
// Shading Language; confirm.
3constexpr std::string_view fft_kernel = R"(
4template [[host_name("{name}")]] [[kernel]] void
5fft<{tg_mem_size}, {in_T}, {out_T}>(
6 const device {in_T}* in [[buffer(0)]],
7 device {out_T}* out [[buffer(1)]],
8 constant const int& n,
9 constant const int& batch_size,
10 uint3 elem [[thread_position_in_grid]],
11 uint3 grid [[threads_per_grid]]);
12)";
13
// Text template declaring one explicit instantiation of the `rader_fft`
// kernel template. Uses the same `{name}`, `{tg_mem_size}`, `{in_T}` and
// `{out_T}` placeholders as the plain `fft` template; presumably they are
// substituted by a string formatter before compilation (TODO confirm).
// Compared with `fft_kernel`, the declared signature carries three extra
// read-only device buffers (`raders_b_q` as float2 at buffer 2, `raders_g_q`
// and `raders_g_minus_q` as short at buffers 3 and 4) and one extra
// constant, `rader_n`.
// NOTE(review): the names suggest tables for Rader's FFT algorithm; their
// exact contents are not visible from this declaration.
14constexpr std::string_view rader_fft_kernel = R"(
15template [[host_name("{name}")]] [[kernel]] void
16rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
17 const device {in_T}* in [[buffer(0)]],
18 device {out_T}* out [[buffer(1)]],
19 const device float2* raders_b_q [[buffer(2)]],
20 const device short* raders_g_q [[buffer(3)]],
21 const device short* raders_g_minus_q [[buffer(4)]],
22 constant const int& n,
23 constant const int& batch_size,
24 constant const int& rader_n,
25 uint3 elem [[thread_position_in_grid]],
26 uint3 grid [[threads_per_grid]]);
27)";
28
// Text template declaring one explicit instantiation of the `bluestein_fft`
// kernel template. Uses the `{name}`, `{tg_mem_size}`, `{in_T}` and
// `{out_T}` placeholders; presumably they are substituted by a string
// formatter before compilation (TODO confirm).
// Compared with `fft_kernel`, the declared signature carries two extra
// read-only float2 device buffers (`w_q` at buffer 2, `w_k` at buffer 3) and
// an extra `length` constant that precedes `n` and `batch_size`.
// NOTE(review): how `length` relates to `n` is not visible from this
// declaration; check the kernel definition.
29constexpr std::string_view bluestein_fft_kernel = R"(
30template [[host_name("{name}")]] [[kernel]] void
31bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
32 const device {in_T}* in [[buffer(0)]],
33 device {out_T}* out [[buffer(1)]],
34 const device float2* w_q [[buffer(2)]],
35 const device float2* w_k [[buffer(3)]],
36 constant const int& length,
37 constant const int& n,
38 constant const int& batch_size,
39 uint3 elem [[thread_position_in_grid]],
40 uint3 grid [[threads_per_grid]]);
41)";
42
// Text template declaring one explicit instantiation of the `four_step_fft`
// kernel template. In addition to the `{name}`, `{tg_mem_size}`, `{in_T}`
// and `{out_T}` placeholders, this template takes two more template
// arguments, `{step}` and `{real}`; presumably all fields are substituted by
// a string formatter before compilation (TODO confirm).
// The declared signature takes two size constants, `n1` and `n2`, where the
// other FFT templates take a single `n`, plus `batch_size`.
// NOTE(review): the valid values of `{step}` and `{real}` are not visible
// here; check the kernel definition and the caller.
43constexpr std::string_view four_step_fft_kernel = R"(
44template [[host_name("{name}")]] [[kernel]] void
45four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
46 const device {in_T}* in [[buffer(0)]],
47 device {out_T}* out [[buffer(1)]],
48 constant const int& n1,
49 constant const int& n2,
50 constant const int& batch_size,
51 uint3 elem [[thread_position_in_grid]],
52 uint3 grid [[threads_per_grid]]);
53)";
constexpr std::string_view fft_kernel
Definition fft.h:3
constexpr std::string_view bluestein_fft_kernel
Definition fft.h:29
constexpr std::string_view rader_fft_kernel
Definition fft.h:14
constexpr std::string_view four_step_fft_kernel
Definition fft.h:43