mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Refactored four-step
This commit is contained in:
parent
da98e8bce8
commit
6593281d25
@ -117,6 +117,7 @@ struct OldFFTPlan {
|
|||||||
class FFTPlan {
|
class FFTPlan {
|
||||||
public:
|
public:
|
||||||
enum FFTType {
|
enum FFTType {
|
||||||
|
UNSUPPORTED,
|
||||||
NOOP,
|
NOOP,
|
||||||
STOCKHAM,
|
STOCKHAM,
|
||||||
RADER,
|
RADER,
|
||||||
@ -137,6 +138,8 @@ class FFTPlan {
|
|||||||
type_ = SMALL_FOUR_STEP;
|
type_ = SMALL_FOUR_STEP;
|
||||||
n2_ = n > 65536 ? 1024 : 64;
|
n2_ = n > 65536 ? 1024 : 64;
|
||||||
n1_ = n / n2_;
|
n1_ = n / n2_;
|
||||||
|
steps1_ = stockham_decompose(n1_);
|
||||||
|
steps2_ = stockham_decompose(n2_);
|
||||||
} else {
|
} else {
|
||||||
type_ = LARGE_FOUR_STEP;
|
type_ = LARGE_FOUR_STEP;
|
||||||
}
|
}
|
||||||
@ -156,6 +159,7 @@ class FFTPlan {
|
|||||||
|
|
||||||
// throw for now but we have rader and bluestein to do
|
// throw for now but we have rader and bluestein to do
|
||||||
else {
|
else {
|
||||||
|
type_ = UNSUPPORTED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,12 +175,30 @@ class FFTPlan {
|
|||||||
return steps_;
|
return steps_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int first_size() const {
|
||||||
|
return n1_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<int>& first_steps() const {
|
||||||
|
return steps1_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int second_size() const {
|
||||||
|
return n2_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<int>& second_steps() const {
|
||||||
|
return steps2_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int n_;
|
int n_;
|
||||||
FFTType type_;
|
FFTType type_;
|
||||||
std::vector<int> steps_;
|
std::vector<int> steps_;
|
||||||
int n1_;
|
int n1_;
|
||||||
|
std::vector<int> steps1_;
|
||||||
int n2_;
|
int n2_;
|
||||||
|
std::vector<int> steps2_;
|
||||||
int bluestein_n_;
|
int bluestein_n_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -997,11 +1019,11 @@ void fft_stockham_inplace(
|
|||||||
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||||
auto& steps = plan.steps();
|
auto& steps = plan.steps();
|
||||||
int elems_per_thread = compute_elems_per_thread(n, steps);
|
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||||
int threads_per_fft = ceildiv(plan.size(), elems_per_thread);
|
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||||
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / plan.size(), 1);
|
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / n, 1);
|
||||||
int tg_mem_size = next_power_of_2(tg_batch_size * plan.size());
|
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||||
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||||
batch_size = real ? ceildiv(batch_size, 2) : batch_size;
|
batch_size = real ? ceildiv(batch_size, 2) : batch_size; // 2 RFFTs at once
|
||||||
std::vector<MTLFC> func_consts = {
|
std::vector<MTLFC> func_consts = {
|
||||||
{&inverse, MTL::DataType::DataTypeBool, 0},
|
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||||
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||||
@ -1029,13 +1051,99 @@ void fft_stockham_inplace(
|
|||||||
compute_encoder.set_input_array(in, 0);
|
compute_encoder.set_input_array(in, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
compute_encoder.set_bytes(n, 2);
|
compute_encoder.set_bytes(n, 2);
|
||||||
compute_encoder.set_bytes(batch_size, 3);
|
compute_encoder.set_bytes(total_batch_size, 3);
|
||||||
|
|
||||||
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||||
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void fft_four_step_inplace(
|
||||||
|
const FFTPlan& plan,
|
||||||
|
const array& in_,
|
||||||
|
array& out,
|
||||||
|
size_t axis,
|
||||||
|
bool inverse,
|
||||||
|
bool real,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s) {
|
||||||
|
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||||
|
// Possibly copy the input but never the output as it doesn't have anything
|
||||||
|
// useful in it yet.
|
||||||
|
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||||
|
prepare_output_array(in, out, axis);
|
||||||
|
|
||||||
|
// Also prepare the intermediate array for the four-step fft which is
|
||||||
|
// implemented with 2 kernel calls.
|
||||||
|
array intermediate(
|
||||||
|
(real && inverse) ? out.shape() : in.shape(), complex64, nullptr, {});
|
||||||
|
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||||
|
prepare_output_array(in, intermediate, axis);
|
||||||
|
d.add_temporary(intermediate, s.index);
|
||||||
|
|
||||||
|
// Make the two calls
|
||||||
|
for (int step = 0; step < 2; step++) {
|
||||||
|
// Create the parameters
|
||||||
|
int n1 = plan.first_size();
|
||||||
|
int n2 = plan.second_size();
|
||||||
|
int n = (step == 0) ? n1 : n2;
|
||||||
|
bool power_of_2 = true;
|
||||||
|
int total_batch_size =
|
||||||
|
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||||
|
auto& steps = (step == 0) ? plan.first_steps() : plan.second_steps();
|
||||||
|
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||||
|
int threads_per_fft = ceildiv(n, elems_per_thread);
|
||||||
|
int tg_batch_size =
|
||||||
|
std::max(MIN_THREADGROUP_MEM_SIZE / n, MIN_COALESCE_WIDTH);
|
||||||
|
int tg_mem_size = next_power_of_2(tg_batch_size * n);
|
||||||
|
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||||
|
std::vector<MTLFC> func_consts = {
|
||||||
|
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||||
|
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||||
|
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||||
|
for (int i = 0; i < steps.size(); i++) {
|
||||||
|
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the kernel
|
||||||
|
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||||
|
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||||
|
std::string hash_name;
|
||||||
|
std::string kname;
|
||||||
|
kname.reserve(64);
|
||||||
|
hash_name.reserve(64);
|
||||||
|
concatenate(
|
||||||
|
kname,
|
||||||
|
"four_step_mem_",
|
||||||
|
tg_mem_size,
|
||||||
|
"_",
|
||||||
|
in_type,
|
||||||
|
"_",
|
||||||
|
out_type,
|
||||||
|
"_",
|
||||||
|
step,
|
||||||
|
(real ? "_true" : "_false"));
|
||||||
|
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||||
|
auto template_def = get_template_definition(
|
||||||
|
kname, "four_step_fft", tg_mem_size, in_type, out_type, step, real);
|
||||||
|
auto kernel =
|
||||||
|
get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||||
|
|
||||||
|
// Launch it
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
compute_encoder.set_input_array((step == 0) ? in : intermediate, 0);
|
||||||
|
compute_encoder.set_output_array((step == 0) ? intermediate : out, 1);
|
||||||
|
compute_encoder.set_bytes(n1, 2);
|
||||||
|
compute_encoder.set_bytes(n2, 3);
|
||||||
|
compute_encoder.set_bytes(total_batch_size, 4);
|
||||||
|
|
||||||
|
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||||
|
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||||
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void fft_op_inplace(
|
void fft_op_inplace(
|
||||||
const array& in,
|
const array& in,
|
||||||
array& out,
|
array& out,
|
||||||
@ -1053,10 +1161,17 @@ void fft_op_inplace(
|
|||||||
std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
|
std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
|
||||||
break;
|
break;
|
||||||
case FFTPlan::STOCKHAM:
|
case FFTPlan::STOCKHAM:
|
||||||
fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
return fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||||
break;
|
case FFTPlan::SMALL_FOUR_STEP:
|
||||||
|
return fft_four_step_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||||
|
case FFTPlan::UNSUPPORTED: {
|
||||||
|
std::string msg;
|
||||||
|
concatenate(msg, "FFT of size ", plan.size(), " not supported");
|
||||||
|
throw std::runtime_error(msg);
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
std::cout << "----- NYI ----" << std::endl;
|
std::cout << "----- NYI ----" << std::endl;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user