mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-28 20:41:15 +08:00
Refactored stockham
This commit is contained in:
parent
be57a16a80
commit
da98e8bce8
@ -51,6 +51,35 @@ std::vector<int> prime_factors(int n) {
|
|||||||
return factors;
|
return factors;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int next_fast_n(int n) {
|
||||||
|
return next_power_of_2(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> stockham_decompose(int n) {
|
||||||
|
auto radices = supported_radices();
|
||||||
|
std::vector<int> steps(radices.size(), 0);
|
||||||
|
int orig_n = n;
|
||||||
|
|
||||||
|
for (int i = 0; i < radices.size(); i++) {
|
||||||
|
int radix = radices[i];
|
||||||
|
|
||||||
|
// Manually tuned radices for powers of 2
|
||||||
|
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (n % radix == 0) {
|
||||||
|
steps[i] += 1;
|
||||||
|
n /= radix;
|
||||||
|
if (n == 1) {
|
||||||
|
return steps;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
struct FourStepParams {
|
struct FourStepParams {
|
||||||
bool required = false;
|
bool required = false;
|
||||||
bool first_step = true;
|
bool first_step = true;
|
||||||
@ -70,7 +99,7 @@ void fft_op(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
struct FFTPlan {
|
struct OldFFTPlan {
|
||||||
int n = 0;
|
int n = 0;
|
||||||
// Number of steps for each radix in the Stockham decomposition
|
// Number of steps for each radix in the Stockham decomposition
|
||||||
std::vector<int> stockham;
|
std::vector<int> stockham;
|
||||||
@ -85,9 +114,71 @@ struct FFTPlan {
|
|||||||
int n2 = 0;
|
int n2 = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
int next_fast_n(int n) {
|
class FFTPlan {
|
||||||
return next_power_of_2(n);
|
public:
|
||||||
}
|
enum FFTType {
|
||||||
|
NOOP,
|
||||||
|
STOCKHAM,
|
||||||
|
RADER,
|
||||||
|
BLUESTEIN,
|
||||||
|
SMALL_FOUR_STEP,
|
||||||
|
LARGE_FOUR_STEP
|
||||||
|
};
|
||||||
|
|
||||||
|
FFTPlan(int n) : n_(n) {
|
||||||
|
// NOOP
|
||||||
|
if (n == 1) {
|
||||||
|
type_ = NOOP;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Four step fft
|
||||||
|
else if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||||
|
if (n <= 1 << 20) {
|
||||||
|
type_ = SMALL_FOUR_STEP;
|
||||||
|
n2_ = n > 65536 ? 1024 : 64;
|
||||||
|
n1_ = n / n2_;
|
||||||
|
} else {
|
||||||
|
type_ = LARGE_FOUR_STEP;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bluestein fft
|
||||||
|
else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||||
|
type_ = BLUESTEIN;
|
||||||
|
bluestein_n_ = next_fast_n(2 * n - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stockham fft
|
||||||
|
else if (auto steps = stockham_decompose(n); steps.size() > 0) {
|
||||||
|
type_ = STOCKHAM;
|
||||||
|
steps_ = steps;
|
||||||
|
}
|
||||||
|
|
||||||
|
// throw for now but we have rader and bluestein to do
|
||||||
|
else {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FFTType type() const {
|
||||||
|
return type_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() const {
|
||||||
|
return n_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<int>& steps() const {
|
||||||
|
return steps_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int n_;
|
||||||
|
FFTType type_;
|
||||||
|
std::vector<int> steps_;
|
||||||
|
int n1_;
|
||||||
|
int n2_;
|
||||||
|
int bluestein_n_;
|
||||||
|
};
|
||||||
|
|
||||||
std::vector<int> plan_stockham_fft(int n) {
|
std::vector<int> plan_stockham_fft(int n) {
|
||||||
auto radices = supported_radices();
|
auto radices = supported_radices();
|
||||||
@ -113,10 +204,10 @@ std::vector<int> plan_stockham_fft(int n) {
|
|||||||
throw std::runtime_error("Unplannable");
|
throw std::runtime_error("Unplannable");
|
||||||
}
|
}
|
||||||
|
|
||||||
FFTPlan plan_fft(int n) {
|
OldFFTPlan plan_fft(int n) {
|
||||||
auto radices = supported_radices();
|
auto radices = supported_radices();
|
||||||
|
|
||||||
FFTPlan plan;
|
OldFFTPlan plan;
|
||||||
plan.n = n;
|
plan.n = n;
|
||||||
plan.rader = std::vector<int>(radices.size(), 0);
|
plan.rader = std::vector<int>(radices.size(), 0);
|
||||||
|
|
||||||
@ -176,7 +267,7 @@ FFTPlan plan_fft(int n) {
|
|||||||
return plan;
|
return plan;
|
||||||
}
|
}
|
||||||
|
|
||||||
int compute_elems_per_thread(FFTPlan plan) {
|
int compute_elems_per_thread(OldFFTPlan plan) {
|
||||||
// Heuristics for selecting an efficient number
|
// Heuristics for selecting an efficient number
|
||||||
// of threads to use for a particular mixed-radix FFT.
|
// of threads to use for a particular mixed-radix FFT.
|
||||||
auto n = plan.n;
|
auto n = plan.n;
|
||||||
@ -359,7 +450,7 @@ void multi_upload_bluestein_fft(
|
|||||||
size_t axis,
|
size_t axis,
|
||||||
bool inverse,
|
bool inverse,
|
||||||
bool real,
|
bool real,
|
||||||
FFTPlan& plan,
|
OldFFTPlan& plan,
|
||||||
std::vector<array>& copies,
|
std::vector<array>& copies,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
@ -488,7 +579,7 @@ void four_step_fft(
|
|||||||
size_t axis,
|
size_t axis,
|
||||||
bool inverse,
|
bool inverse,
|
||||||
bool real,
|
bool real,
|
||||||
FFTPlan& plan,
|
OldFFTPlan& plan,
|
||||||
std::vector<array>& copies,
|
std::vector<array>& copies,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
bool in_place) {
|
bool in_place) {
|
||||||
@ -771,6 +862,51 @@ void fft_op(
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int compute_elems_per_thread(int n, const std::vector<int>& steps) {
|
||||||
|
auto radices = supported_radices();
|
||||||
|
std::set<int> used_radices;
|
||||||
|
for (int i = 0; i < steps.size(); i++) {
|
||||||
|
if (steps[i] > 0) {
|
||||||
|
used_radices.insert(radices[i % radices.size()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manual tuning for 7/11/13
|
||||||
|
if (used_radices.find(7) != used_radices.end() &&
|
||||||
|
(used_radices.find(11) != used_radices.end() ||
|
||||||
|
used_radices.find(13) != used_radices.end())) {
|
||||||
|
return 7;
|
||||||
|
} else if (
|
||||||
|
used_radices.find(11) != used_radices.end() &&
|
||||||
|
used_radices.find(13) != used_radices.end()) {
|
||||||
|
return 11;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(alexbarron) Some really weird stuff is going on
|
||||||
|
// for certain `elems_per_thread` on large composite n.
|
||||||
|
// Possibly a compiler issue?
|
||||||
|
if (n == 3159)
|
||||||
|
return 13;
|
||||||
|
if (n == 3645)
|
||||||
|
return 5;
|
||||||
|
if (n == 3969)
|
||||||
|
return 7;
|
||||||
|
if (n == 1982)
|
||||||
|
return 5;
|
||||||
|
|
||||||
|
if (used_radices.size() == 1) {
|
||||||
|
return *(used_radices.begin());
|
||||||
|
}
|
||||||
|
if (used_radices.size() == 2 &&
|
||||||
|
(used_radices.find(11) != used_radices.end() ||
|
||||||
|
used_radices.find(13) != used_radices.end())) {
|
||||||
|
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// In all other cases use the second smallest radix.
|
||||||
|
return *(++used_radices.begin());
|
||||||
|
}
|
||||||
|
|
||||||
inline array ensure_fastest_moving_axis(
|
inline array ensure_fastest_moving_axis(
|
||||||
const array& x,
|
const array& x,
|
||||||
int axis,
|
int axis,
|
||||||
@ -840,6 +976,7 @@ inline void prepare_output_array(const array& in, array& out, int axis) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void fft_stockham_inplace(
|
void fft_stockham_inplace(
|
||||||
|
const FFTPlan& plan,
|
||||||
const array& in_,
|
const array& in_,
|
||||||
array& out,
|
array& out,
|
||||||
size_t axis,
|
size_t axis,
|
||||||
@ -847,8 +984,56 @@ void fft_stockham_inplace(
|
|||||||
bool real,
|
bool real,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
|
// Prepare the input and output arrays such that `axis` has stride 1.
|
||||||
|
// Possibly copy the input but never the output as it doesn't have anything
|
||||||
|
// useful in it yet.
|
||||||
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
array in = ensure_fastest_moving_axis(in_, axis, d, s);
|
||||||
prepare_output_array(in, out, axis);
|
prepare_output_array(in, out, axis);
|
||||||
|
|
||||||
|
// Prepare the arguments for stockham fft
|
||||||
|
int n = plan.size();
|
||||||
|
bool power_of_2 = is_power_of_2(n);
|
||||||
|
int total_batch_size =
|
||||||
|
out.dtype() == float32 ? out.size() / n : in.size() / n;
|
||||||
|
auto& steps = plan.steps();
|
||||||
|
int elems_per_thread = compute_elems_per_thread(n, steps);
|
||||||
|
int threads_per_fft = ceildiv(plan.size(), elems_per_thread);
|
||||||
|
int tg_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / plan.size(), 1);
|
||||||
|
int tg_mem_size = next_power_of_2(tg_batch_size * plan.size());
|
||||||
|
int batch_size = ceildiv(total_batch_size, tg_batch_size);
|
||||||
|
batch_size = real ? ceildiv(batch_size, 2) : batch_size;
|
||||||
|
std::vector<MTLFC> func_consts = {
|
||||||
|
{&inverse, MTL::DataType::DataTypeBool, 0},
|
||||||
|
{&power_of_2, MTL::DataType::DataTypeBool, 1},
|
||||||
|
{&elems_per_thread, MTL::DataType::DataTypeInt, 2}};
|
||||||
|
for (int i = 0; i < steps.size(); i++) {
|
||||||
|
func_consts.emplace_back(&steps[i], MTL::DataType::DataTypeInt, 4 + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the kernel
|
||||||
|
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||||
|
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||||
|
std::string hash_name;
|
||||||
|
std::string kname;
|
||||||
|
kname.reserve(64);
|
||||||
|
hash_name.reserve(64);
|
||||||
|
concatenate(kname, "fft_mem_", tg_mem_size, "_", in_type, "_", out_type);
|
||||||
|
concatenate(hash_name, kname, "_n", n, "_inv_", inverse);
|
||||||
|
auto template_def =
|
||||||
|
get_template_definition(kname, "fft", tg_mem_size, in_type, out_type);
|
||||||
|
auto kernel = get_fft_kernel(d, kname, hash_name, func_consts, template_def);
|
||||||
|
|
||||||
|
// Launch it
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
compute_encoder.set_input_array(in, 0);
|
||||||
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
compute_encoder.set_bytes(n, 2);
|
||||||
|
compute_encoder.set_bytes(batch_size, 3);
|
||||||
|
|
||||||
|
MTL::Size group_dims(1, tg_batch_size, threads_per_fft);
|
||||||
|
MTL::Size grid_dims(batch_size, tg_batch_size, threads_per_fft);
|
||||||
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
void fft_op_inplace(
|
void fft_op_inplace(
|
||||||
@ -860,15 +1045,18 @@ void fft_op_inplace(
|
|||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
// Get the FFT size and plan it
|
// Get the FFT size and plan it
|
||||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
auto plan =
|
||||||
auto plan = plan_fft(n);
|
FFTPlan(out.dtype() == float32 ? out.shape(axis) : in.shape(axis));
|
||||||
if (n == 1) {
|
|
||||||
std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (plan.four_step && plan.bluestein_n < 0) {
|
switch (plan.type()) {
|
||||||
// four_step_fft(in, out, axis, inverse, real, plan, inplace, d, s);
|
case FFTPlan::NOOP:
|
||||||
return;
|
std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
|
||||||
|
break;
|
||||||
|
case FFTPlan::STOCKHAM:
|
||||||
|
fft_stockham_inplace(plan, in, out, axis, inverse, real, d, s);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
std::cout << "----- NYI ----" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user