mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix fft for integer overflow (#2161)
This commit is contained in:
parent
a7fae8a176
commit
6661387066
@ -632,7 +632,7 @@ void fft_op(
|
|||||||
func_consts.push_back(make_int(&rader_m, 3));
|
func_consts.push_back(make_int(&rader_m, 3));
|
||||||
|
|
||||||
// The overall number of FFTs we're going to compute for this input
|
// The overall number of FFTs we're going to compute for this input
|
||||||
int size = out.dtype() == float32 ? out.size() : in.size();
|
size_t size = out.dtype() == float32 ? out.size() : in.size();
|
||||||
if (real && inverse && four_step_params.required) {
|
if (real && inverse && four_step_params.required) {
|
||||||
size = out.size();
|
size = out.size();
|
||||||
}
|
}
|
||||||
@ -659,8 +659,6 @@ void fft_op(
|
|||||||
// We can perform 2 RFFTs at once so the batch size is halved.
|
// We can perform 2 RFFTs at once so the batch size is halved.
|
||||||
batch_size = (batch_size + 2 - 1) / 2;
|
batch_size = (batch_size + 2 - 1) / 2;
|
||||||
}
|
}
|
||||||
int out_buffer_size = out.size();
|
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
|
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
|
||||||
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
|
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
|
||||||
|
@ -98,7 +98,7 @@ struct ReadWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
METAL_FUNC void load() const {
|
METAL_FUNC void load() const {
|
||||||
int batch_idx = elem.x * grid.y * n;
|
size_t batch_idx = size_t(elem.x * grid.y) * n;
|
||||||
short tg_idx = elem.y * grid.z + elem.z;
|
short tg_idx = elem.y * grid.z + elem.z;
|
||||||
short max_index = grid.y * n - 2;
|
short max_index = grid.y * n - 2;
|
||||||
|
|
||||||
@ -121,7 +121,7 @@ struct ReadWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
METAL_FUNC void write() const {
|
METAL_FUNC void write() const {
|
||||||
int batch_idx = elem.x * grid.y * n;
|
size_t batch_idx = size_t(elem.x * grid.y) * n;
|
||||||
short tg_idx = elem.y * grid.z + elem.z;
|
short tg_idx = elem.y * grid.z + elem.z;
|
||||||
short max_index = grid.y * n - 2;
|
short max_index = grid.y * n - 2;
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ struct ReadWriter {
|
|||||||
|
|
||||||
// Padded IO for Bluestein's algorithm
|
// Padded IO for Bluestein's algorithm
|
||||||
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
|
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
|
||||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
|
||||||
int fft_idx = elem.z;
|
int fft_idx = elem.z;
|
||||||
int m = grid.z;
|
int m = grid.z;
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ struct ReadWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
|
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
|
||||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
|
||||||
int fft_idx = elem.z;
|
int fft_idx = elem.z;
|
||||||
int m = grid.z;
|
int m = grid.z;
|
||||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||||
@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
METAL_FUNC void ReadWriter<float, float2>::load() const {
|
METAL_FUNC void ReadWriter<float, float2>::load() const {
|
||||||
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
|
size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
|
||||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||||
|
|
||||||
// No out of bounds accesses on odd batch sizes
|
// No out of bounds accesses on odd batch sizes
|
||||||
@ -283,7 +283,8 @@ template <>
|
|||||||
METAL_FUNC void ReadWriter<float, float2>::write() const {
|
METAL_FUNC void ReadWriter<float, float2>::write() const {
|
||||||
short n_over_2 = (n / 2) + 1;
|
short n_over_2 = (n / 2) + 1;
|
||||||
|
|
||||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
size_t batch_idx =
|
||||||
|
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||||
|
|
||||||
int grid_index = elem.x * grid.y + elem.y;
|
int grid_index = elem.x * grid.y + elem.y;
|
||||||
@ -317,7 +318,7 @@ template <>
|
|||||||
METAL_FUNC void ReadWriter<float, float2>::load_padded(
|
METAL_FUNC void ReadWriter<float, float2>::load_padded(
|
||||||
int length,
|
int length,
|
||||||
const device float2* w_k) const {
|
const device float2* w_k) const {
|
||||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
|
||||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||||
|
|
||||||
// No out of bounds accesses on odd batch sizes
|
// No out of bounds accesses on odd batch sizes
|
||||||
@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
|
|||||||
int length,
|
int length,
|
||||||
const device float2* w_k) const {
|
const device float2* w_k) const {
|
||||||
int length_over_2 = (length / 2) + 1;
|
int length_over_2 = (length / 2) + 1;
|
||||||
int batch_idx =
|
size_t batch_idx =
|
||||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||||
|
|
||||||
int grid_index = elem.x * grid.y + elem.y;
|
int grid_index = elem.x * grid.y + elem.y;
|
||||||
@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
|
|||||||
template <>
|
template <>
|
||||||
METAL_FUNC void ReadWriter<float2, float>::load() const {
|
METAL_FUNC void ReadWriter<float2, float>::load() const {
|
||||||
short n_over_2 = (n / 2) + 1;
|
short n_over_2 = (n / 2) + 1;
|
||||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
size_t batch_idx =
|
||||||
|
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||||
|
|
||||||
// No out of bounds accesses on odd batch sizes
|
// No out of bounds accesses on odd batch sizes
|
||||||
@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
|
|||||||
int n_over_2 = (n / 2) + 1;
|
int n_over_2 = (n / 2) + 1;
|
||||||
int length_over_2 = (length / 2) + 1;
|
int length_over_2 = (length / 2) + 1;
|
||||||
|
|
||||||
int batch_idx =
|
size_t batch_idx =
|
||||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||||
|
|
||||||
// No out of bounds accesses on odd batch sizes
|
// No out of bounds accesses on odd batch sizes
|
||||||
@ -503,7 +505,7 @@ template <>
|
|||||||
METAL_FUNC void ReadWriter<float2, float>::write_padded(
|
METAL_FUNC void ReadWriter<float2, float>::write_padded(
|
||||||
int length,
|
int length,
|
||||||
const device float2* w_k) const {
|
const device float2* w_k) const {
|
||||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
|
||||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||||
|
|
||||||
int grid_index = elem.x * grid.y + elem.y;
|
int grid_index = elem.x * grid.y + elem.y;
|
||||||
|
Loading…
Reference in New Issue
Block a user