mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Add single kernel bluestein
This commit is contained in:
parent
6593281d25
commit
2a41caa00e
@ -122,6 +122,7 @@ class FFTPlan {
|
||||
STOCKHAM,
|
||||
RADER,
|
||||
BLUESTEIN,
|
||||
MULTIUPLOAD_BLUESTEIN,
|
||||
SMALL_FOUR_STEP,
|
||||
LARGE_FOUR_STEP
|
||||
};
|
||||
@ -132,7 +133,7 @@ class FFTPlan {
|
||||
type_ = NOOP;
|
||||
}
|
||||
|
||||
// Four step fft
|
||||
// Too large for Stockham so do four step fft for powers of 2
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
if (n <= 1 << 20) {
|
||||
type_ = SMALL_FOUR_STEP;
|
||||
@ -145,9 +146,9 @@ class FFTPlan {
|
||||
}
|
||||
}
|
||||
|
||||
// Bluestein fft
|
||||
// Too large and not power of 2 so do multi-upload Bluestein fft
|
||||
else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
type_ = BLUESTEIN;
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
}
|
||||
|
||||
@ -157,9 +158,15 @@ class FFTPlan {
|
||||
steps_ = steps;
|
||||
}
|
||||
|
||||
// throw for now but we have rader and bluestein to do
|
||||
else {
|
||||
type_ = UNSUPPORTED;
|
||||
// Add rader but for now simply fall back to bluestein when stockham not
|
||||
// possible
|
||||
else if (n > MAX_BLUESTEIN_FFT_SIZE) {
|
||||
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
} else {
|
||||
type_ = BLUESTEIN;
|
||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||
steps_ = stockham_decompose(bluestein_n_);
|
||||
}
|
||||
}
|
||||
|
||||
@ -191,6 +198,10 @@ class FFTPlan {
|
||||
return steps2_;
|
||||
}
|
||||
|
||||
int bluestein_size() const {
|
||||
return bluestein_n_;
|
||||
}
|
||||
|
||||
private:
|
||||
int n_;
|
||||
FFTType type_;
|
||||
@ -1144,6 +1155,145 @@ void fft_four_step_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
void fft_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||
// Possibly copy the input but never the output as it doesn't have anything
|
||||
// useful in it yet.
|
||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||
prepare_output_array(in, out, axis);
|
||||
|
||||
// Prepare the arguments for bluestein fft
|
||||
int n = plan.bluestein_size();
|
||||
bool power_of_2 = true;
|
||||
int total_batch_size = out.dtype() == float32 ? out.size() / plan.size()
|
||||
: in.size() / plan.size();
|
||||
auto& steps = plan.steps();
|
||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||
std::vector<MTLFC> func_consts = {
|
||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||
auto template_def = get_template_definition(
|
||||
kname, "bluestein_fft", tg_mem_size, in_type, out_type);
|
||||
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||
|
||||
// Get the bluestein constants
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Launch it
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_input_array(w_q, 2);
|
||||
compute_encoder.set_input_array(w_k, 3);
|
||||
compute_encoder.set_bytes(plan.size(), 4);
|
||||
compute_encoder.set_bytes(n, 5);
|
||||
compute_encoder.set_bytes(total_batch_size, 6);
|
||||
|
||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void fft_multi_upload_bluestein(
|
||||
const FFTPlan& plan,
|
||||
const array& in_,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Get Bluestein's constants using the CPU (this is done in the submission
|
||||
// thread which is pretty bad).
|
||||
auto [w_k, w_q] =
|
||||
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||
d.add_temporary(w_k, s.index);
|
||||
d.add_temporary(w_q, s.index);
|
||||
|
||||
// Prepare the input
|
||||
auto in_shape = inverse ? out.shape() : in_.shape();
|
||||
array in(std::move(in_shape), complex64, nullptr, {});
|
||||
if (real && !inverse) {
|
||||
copy_gpu(
|
||||
in_,
|
||||
in,
|
||||
in_.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = plan.size() % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
Shape rstarts(in.ndim(), 0);
|
||||
Shape rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in_}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in_, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, in, (int)axis, s);
|
||||
d.add_temporary(conj_temp, s.index);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in_}, in, "Conjugate", s);
|
||||
d.add_temporary(in, s.index);
|
||||
} else {
|
||||
in.copy_shared_buffer(in_);
|
||||
}
|
||||
|
||||
// Multiply with
|
||||
Strides b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast(in.shape(), complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
array x(in.shape(), complex64, nullptr, {});
|
||||
binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s);
|
||||
d.add_temporary(x, s.index);
|
||||
|
||||
// Pad
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_size();
|
||||
array padded_x(padded_shape, complex64, nullptr, {});
|
||||
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||
pad_gpu(x, zero, padded_x, {(int)axis}, {0}, s);
|
||||
d.add_temporary(zero, s.index);
|
||||
d.add_temporary(padded_x, s.index);
|
||||
|
||||
// First fft
|
||||
}
|
||||
|
||||
void fft_op_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
@ -1164,6 +1314,8 @@ void fft_op_inplace(
|
||||
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::SMALL_FOUR_STEP:
|
||||
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::BLUESTEIN:
|
||||
return fft_bluestein(plan, in, out, axis, inverse, real, d, s);
|
||||
case FFTPlan::UNSUPPORTED: {
|
||||
std::string msg;
|
||||
concatenate(msg, "FFT of size ", plan.size(), " not supported");
|
||||
|
Loading…
Reference in New Issue
Block a user