mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Feature complete Metal FFT (#1102)
* feature complete metal fft * fix contiguity bug * jit fft * simplify rader/bluestein constant computation * remove kernel/utils.h dep * remove bf16.h dep * format --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/jit/binary.h"
|
||||
#include "mlx/backend/metal/jit/binary_two.h"
|
||||
#include "mlx/backend/metal/jit/copy.h"
|
||||
#include "mlx/backend/metal/jit/fft.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
@@ -489,4 +490,51 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const int tg_mem_size,
|
||||
const std::string& in_type,
|
||||
const std::string& out_type,
|
||||
int step,
|
||||
bool real,
|
||||
const metal::MTLFCList& func_consts) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
std::string kernel_string;
|
||||
if (lib_name.find("bluestein") != std::string::npos) {
|
||||
kernel_string = bluestein_fft_kernel;
|
||||
} else if (lib_name.find("rader") != std::string::npos) {
|
||||
kernel_string = rader_fft_kernel;
|
||||
} else if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_string = four_step_fft_kernel;
|
||||
} else {
|
||||
kernel_string = fft_kernel;
|
||||
}
|
||||
kernel_source << metal::fft();
|
||||
if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type,
|
||||
"step"_a = step,
|
||||
"real"_a = real);
|
||||
} else {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user