mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Fix four step fft
This commit is contained in:
parent
2a41caa00e
commit
83762691ba
@ -1117,8 +1117,11 @@ void fft_four_step_inplace(
|
||||
}
|
||||
|
||||
// Get the kernel
|
||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
||||
auto to_type = [](const array& x) {
|
||||
return x.dtype() == float32 ? "float" : "float2";
|
||||
};
|
||||
auto in_type = step == 0 ? to_type(in) : to_type(intermediate);
|
||||
auto out_type = step == 0 ? to_type(intermediate) : to_type(out);
|
||||
std::string hash_name;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
|
Loading…
Reference in New Issue
Block a user