// Explicit instantiation of the fft kernel template, exported to the host
// under the substituted host_name. The brace-delimited fields are
// placeholders, presumably filled in by a host-side formatter before JIT
// compilation (TODO confirm). Keep comments here free of brace characters
// so the formatter does not treat them as fields.
template [[host_name("{name}")]] [[kernel]] void
fft<{tg_mem_size}, {in_T}, {out_T}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    // NOTE(review): the rader_fft and bluestein_fft instantiations in this
    // file take n before batch_size; confirm against the fft template
    // definition.
    constant const int& n,
    constant const int& batch_size,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);
// Explicit instantiation of the rader_fft kernel template, exported to the
// host under the substituted host_name. Buffers 2-4 bind the tables named
// raders_b_q, raders_g_q and raders_g_minus_q. They are presumably
// precomputed on the host (TODO confirm). The brace-delimited fields are
// format placeholders, so keep comments here free of brace characters.
template [[host_name("{name}")]] [[kernel]] void
rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    const device float2* raders_b_q [[buffer(2)]],
    const device short* raders_g_q [[buffer(3)]],
    const device short* raders_g_minus_q [[buffer(4)]],
    constant const int& n,
    constant const int& batch_size,
    constant const int& rader_n,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);
// Explicit instantiation of the bluestein_fft kernel template, exported to
// the host under the substituted host_name. Buffers 2-3 bind the float2
// arrays w_q and w_k, presumably precomputed on the host (TODO confirm).
// The kernel receives both length and n as scalar arguments. The
// brace-delimited fields are format placeholders, so keep comments here
// free of brace characters.
template [[host_name("{name}")]] [[kernel]] void
bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    const device float2* w_q [[buffer(2)]],
    const device float2* w_k [[buffer(3)]],
    constant const int& length,
    constant const int& n,
    constant const int& batch_size,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);
// Explicit instantiation of the four_step_fft kernel template, exported to
// the host under the substituted host_name. Besides the sizes and types
// used by the other kernels, it is also templated on step and real. The
// transform dimensions arrive as the scalars n1 and n2. The brace-delimited
// fields are format placeholders, so keep comments here free of brace
// characters.
template [[host_name("{name}")]] [[kernel]] void
four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
    const device {in_T}* in [[buffer(0)]],
    device {out_T}* out [[buffer(1)]],
    constant const int& n1,
    constant const int& n2,
    constant const int& batch_size,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]);