Fix fft for integer overflow (#2161)

This commit is contained in:
Awni Hannun 2025-05-09 14:25:12 -07:00 committed by GitHub
parent a7fae8a176
commit 6661387066
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 16 deletions

View File

@ -632,7 +632,7 @@ void fft_op(
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size();
size_t size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
@ -659,8 +659,6 @@ void fft_op(
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";

View File

@ -98,7 +98,7 @@ struct ReadWriter {
}
METAL_FUNC void load() const {
int batch_idx = elem.x * grid.y * n;
size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@ -121,7 +121,7 @@ struct ReadWriter {
}
METAL_FUNC void write() const {
int batch_idx = elem.x * grid.y * n;
size_t batch_idx = size_t(elem.x * grid.y) * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
@ -144,7 +144,7 @@ struct ReadWriter {
// Padded IO for Bluestein's algorithm
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
@ -161,7 +161,7 @@ struct ReadWriter {
}
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -283,7 +283,8 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
@ -317,7 +318,7 @@ template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
int length,
const device float2* w_k) const {
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
int n_over_2 = (n / 2) + 1;
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
size_t batch_idx =
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
@ -503,7 +505,7 @@ template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;