mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix four step fft
This commit is contained in:
parent
2a41caa00e
commit
83762691ba
@ -1117,8 +1117,11 @@ void fft_four_step_inplace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the kernel
|
// Get the kernel
|
||||||
auto in_type = in.dtype() == float32 ? "float" : "float2";
|
auto to_type = [](const array& x) {
|
||||||
auto out_type = out.dtype() == float32 ? "float" : "float2";
|
return x.dtype() == float32 ? "float" : "float2";
|
||||||
|
};
|
||||||
|
auto in_type = step == 0 ? to_type(in) : to_type(intermediate);
|
||||||
|
auto out_type = step == 0 ? to_type(intermediate) : to_type(out);
|
||||||
std::string hash_name;
|
std::string hash_name;
|
||||||
std::string kname;
|
std::string kname;
|
||||||
kname.reserve(64);
|
kname.reserve(64);
|
||||||
|
Loading…
Reference in New Issue
Block a user