mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix fft for integer overflow (#2161)
This commit is contained in:
parent
a7fae8a176
commit
6661387066
@ -632,7 +632,7 @@ void fft_op(
|
||||
func_consts.push_back(make_int(&rader_m, 3));
|
||||
|
||||
// The overall number of FFTs we're going to compute for this input
|
||||
int size = out.dtype() == float32 ? out.size() : in.size();
|
||||
size_t size = out.dtype() == float32 ? out.size() : in.size();
|
||||
if (real && inverse && four_step_params.required) {
|
||||
size = out.size();
|
||||
}
|
||||
@ -659,8 +659,6 @@ void fft_op(
|
||||
// We can perform 2 RFFTs at once so the batch size is halved.
|
||||
batch_size = (batch_size + 2 - 1) / 2;
|
||||
}
|
||||
int out_buffer_size = out.size();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
|
||||
|
@ -98,7 +98,7 @@ struct ReadWriter {
|
||||
}
|
||||
|
||||
METAL_FUNC void load() const {
|
||||
int batch_idx = elem.x * grid.y * n;
|
||||
size_t batch_idx = size_t(elem.x * grid.y) * n;
|
||||
short tg_idx = elem.y * grid.z + elem.z;
|
||||
short max_index = grid.y * n - 2;
|
||||
|
||||
@ -121,7 +121,7 @@ struct ReadWriter {
|
||||
}
|
||||
|
||||
METAL_FUNC void write() const {
|
||||
int batch_idx = elem.x * grid.y * n;
|
||||
size_t batch_idx = size_t(elem.x * grid.y) * n;
|
||||
short tg_idx = elem.y * grid.z + elem.z;
|
||||
short max_index = grid.y * n - 2;
|
||||
|
||||
@ -144,7 +144,7 @@ struct ReadWriter {
|
||||
|
||||
// Padded IO for Bluestein's algorithm
|
||||
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
||||
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
|
||||
int fft_idx = elem.z;
|
||||
int m = grid.z;
|
||||
|
||||
@ -161,7 +161,7 @@ struct ReadWriter {
|
||||
}
|
||||
|
||||
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
||||
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
|
||||
int fft_idx = elem.z;
|
||||
int m = grid.z;
|
||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||
@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::load() const {
|
||||
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
|
||||
size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
@ -283,7 +283,8 @@ template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::write() const {
|
||||
short n_over_2 = (n / 2) + 1;
|
||||
|
||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
size_t batch_idx =
|
||||
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
@ -317,7 +318,7 @@ template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::load_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
||||
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int length_over_2 = (length / 2) + 1;
|
||||
int batch_idx =
|
||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
size_t batch_idx =
|
||||
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::load() const {
|
||||
short n_over_2 = (n / 2) + 1;
|
||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
size_t batch_idx =
|
||||
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
|
||||
int n_over_2 = (n / 2) + 1;
|
||||
int length_over_2 = (length / 2) + 1;
|
||||
|
||||
int batch_idx =
|
||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
size_t batch_idx =
|
||||
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
@ -503,7 +505,7 @@ template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::write_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
||||
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
|
Loading…
Reference in New Issue
Block a user