mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Add single kernel bluestein
This commit is contained in:
parent
6593281d25
commit
2a41caa00e
@ -122,6 +122,7 @@ class FFTPlan {
|
|||||||
STOCKHAM,
|
STOCKHAM,
|
||||||
RADER,
|
RADER,
|
||||||
BLUESTEIN,
|
BLUESTEIN,
|
||||||
|
MULTIUPLOAD_BLUESTEIN,
|
||||||
SMALL_FOUR_STEP,
|
SMALL_FOUR_STEP,
|
||||||
LARGE_FOUR_STEP
|
LARGE_FOUR_STEP
|
||||||
};
|
};
|
||||||
@ -132,7 +133,7 @@ class FFTPlan {
|
|||||||
type_ = NOOP;
|
type_ = NOOP;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Four step fft
|
// Too large for Stockham so do four step fft for powers of 2
|
||||||
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||||
if (n <= 1 << 20) {
|
if (n <= 1 << 20) {
|
||||||
type_ = SMALL_FOUR_STEP;
|
type_ = SMALL_FOUR_STEP;
|
||||||
@ -145,9 +146,9 @@ class FFTPlan {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bluestein fft
|
// Too large and not power of 2 so do multi-upload Bluestein fft
|
||||||
else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||||
type_ = BLUESTEIN;
|
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||||
bluestein_n_ = next_fast_n(2 * n - 1);
|
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,9 +158,15 @@ class FFTPlan {
|
|||||||
steps_ = steps;
|
steps_ = steps;
|
||||||
}
|
}
|
||||||
|
|
||||||
// throw for now but we have rader and bluestein to do
|
// Add rader but for now simply fall back to bluestein when stockham not
|
||||||
else {
|
// possible
|
||||||
type_ = UNSUPPORTED;
|
else if (n > MAX_BLUESTEIN_FFT_SIZE) {
|
||||||
|
type_ = MULTIUPLOAD_BLUESTEIN;
|
||||||
|
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||||
|
} else {
|
||||||
|
type_ = BLUESTEIN;
|
||||||
|
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||||
|
steps_ = stockham_decompose(bluestein_n_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,6 +198,10 @@ class FFTPlan {
|
|||||||
return steps2_;
|
return steps2_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int bluestein_size() const {
|
||||||
|
return bluestein_n_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int n_;
|
int n_;
|
||||||
FFTType type_;
|
FFTType type_;
|
||||||
@ -1144,6 +1155,145 @@ void fft_four_step_inplace(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void fft_bluestein(
|
||||||
|
const FFTPlan& plan,
|
||||||
|
const array& in_,
|
||||||
|
array& out,
|
||||||
|
size_t axis,
|
||||||
|
bool inverse,
|
||||||
|
bool real,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s) {
|
||||||
|
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||||
|
// Possibly copy the input but never the output as it doesn't have anything
|
||||||
|
// useful in it yet.
|
||||||
|
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||||
|
prepare_output_array(in, out, axis);
|
||||||
|
|
||||||
|
// Prepare the arguments for bluestein fft
|
||||||
|
int n = plan.bluestein_size();
|
||||||
|
bool power_of_2 = true;
|
||||||
|
int total_batch_size = out.dtype() == float32 ? out.size() / plan.size()
|
||||||
|
: in.size() / plan.size();
|
||||||
|
auto& steps = plan.steps();
|
||||||
|
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||||
|
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||||
|
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||||
|
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||||
|
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||||
|
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||||
|
std::vector<MTLFC> func_consts = {
|
||||||
|
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||||
|
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||||
|
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||||
|
for (int i = 0; i < steps.size(); i++) {
|
||||||
|
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the kernel
|
||||||
|
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||||
|
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||||
|
std::string hash_name;
|
||||||
|
std::string kname;
|
||||||
|
kname.reserve(64);
|
||||||
|
hash_name.reserve(64);
|
||||||
|
concatenate(
|
||||||
|
kname, "bluestein_fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||||
|
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||||
|
auto template_def = get_template_definition(
|
||||||
|
kname, "bluestein_fft", tg_mem_size, in_type, out_type);
|
||||||
|
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||||
|
|
||||||
|
// Get the bluestein constants
|
||||||
|
auto [w_k, w_q] =
|
||||||
|
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||||
|
d.add_temporary(w_k, s.index);
|
||||||
|
d.add_temporary(w_q, s.index);
|
||||||
|
|
||||||
|
// Launch it
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
compute_encoder.set_input_array(in, 0);
|
||||||
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
compute_encoder.set_input_array(w_q, 2);
|
||||||
|
compute_encoder.set_input_array(w_k, 3);
|
||||||
|
compute_encoder.set_bytes(plan.size(), 4);
|
||||||
|
compute_encoder.set_bytes(n, 5);
|
||||||
|
compute_encoder.set_bytes(total_batch_size, 6);
|
||||||
|
|
||||||
|
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||||
|
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||||
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void fft_multi_upload_bluestein(
|
||||||
|
const FFTPlan& plan,
|
||||||
|
const array& in_,
|
||||||
|
array& out,
|
||||||
|
size_t axis,
|
||||||
|
bool inverse,
|
||||||
|
bool real,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s) {
|
||||||
|
// Get Bluestein's constants using the CPU (this is done in the submission
|
||||||
|
// thread which is pretty bad).
|
||||||
|
auto [w_k, w_q] =
|
||||||
|
compute_bluestein_constants(plan.size(), plan.bluestein_size());
|
||||||
|
d.add_temporary(w_k, s.index);
|
||||||
|
d.add_temporary(w_q, s.index);
|
||||||
|
|
||||||
|
// Prepare the input
|
||||||
|
auto in_shape = inverse ? out.shape() : in_.shape();
|
||||||
|
array in(std::move(in_shape), complex64, nullptr, {});
|
||||||
|
if (real && !inverse) {
|
||||||
|
copy_gpu(
|
||||||
|
in_,
|
||||||
|
in,
|
||||||
|
in_.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
s);
|
||||||
|
d.add_temporary(in, s.index);
|
||||||
|
} else if (real && inverse) {
|
||||||
|
int back_offset = plan.size() % 2 == 0 ? 2 : 1;
|
||||||
|
auto slice_shape = in.shape();
|
||||||
|
slice_shape[axis] -= back_offset;
|
||||||
|
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||||
|
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||||
|
Shape rstarts(in.ndim(), 0);
|
||||||
|
Shape rstrides(in.ndim(), 1);
|
||||||
|
rstarts[axis] = in.shape(axis) - back_offset;
|
||||||
|
rstrides[axis] = -1;
|
||||||
|
unary_op_gpu({in_}, conj_temp, "Conjugate", s);
|
||||||
|
slice_gpu(in_, slice_temp, rstarts, rstrides, s);
|
||||||
|
concatenate_gpu({conj_temp, slice_temp}, in, (int)axis, s);
|
||||||
|
d.add_temporary(conj_temp, s.index);
|
||||||
|
} else if (inverse) {
|
||||||
|
unary_op_gpu({in_}, in, "Conjugate", s);
|
||||||
|
d.add_temporary(in, s.index);
|
||||||
|
} else {
|
||||||
|
in.copy_shared_buffer(in_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply with
|
||||||
|
Strides b_strides(in.ndim(), 0);
|
||||||
|
b_strides[axis] = 1;
|
||||||
|
array w_k_broadcast(in.shape(), complex64, nullptr, {});
|
||||||
|
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||||
|
array x(in.shape(), complex64, nullptr, {});
|
||||||
|
binary_op_gpu({in, w_k_broadcast}, x, "Multiply", s);
|
||||||
|
d.add_temporary(x, s.index);
|
||||||
|
|
||||||
|
// Pad
|
||||||
|
auto padded_shape = out.shape();
|
||||||
|
padded_shape[axis] = plan.bluestein_size();
|
||||||
|
array padded_x(padded_shape, complex64, nullptr, {});
|
||||||
|
auto zero = array(complex64_t{0.0f, 0.0f});
|
||||||
|
pad_gpu(x, zero, padded_x, {(int)axis}, {0}, s);
|
||||||
|
d.add_temporary(zero, s.index);
|
||||||
|
d.add_temporary(padded_x, s.index);
|
||||||
|
|
||||||
|
// First fft
|
||||||
|
}
|
||||||
|
|
||||||
void fft_op_inplace(
|
void fft_op_inplace(
|
||||||
const array& in,
|
const array& in,
|
||||||
array& out,
|
array& out,
|
||||||
@ -1164,6 +1314,8 @@ void fft_op_inplace(
|
|||||||
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||||
case FFTPlan::SMALL_FOUR_STEP:
|
case FFTPlan::SMALL_FOUR_STEP:
|
||||||
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
|
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||||
|
case FFTPlan::BLUESTEIN:
|
||||||
|
return fft_bluestein(plan, in, out, axis, inverse, real, d, s);
|
||||||
case FFTPlan::UNSUPPORTED: {
|
case FFTPlan::UNSUPPORTED: {
|
||||||
std::string msg;
|
std::string msg;
|
||||||
concatenate(msg, "FFT of size ", plan.size(), " not supported");
|
concatenate(msg, "FFT of size ", plan.size(), " not supported");
|
||||||
|
Loading…
Reference in New Issue
Block a user