mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 06:29:09 +08:00
Add gemm_grouped_conv
This commit is contained in:
parent
c81aeedec5
commit
849fee90f3
@ -17,6 +17,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -48,7 +49,32 @@ struct ConvParams {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void gemm_grouped_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s);
|
||||||
|
|
||||||
void gemm_conv(
|
void gemm_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
bool flip,
|
||||||
|
Stream s);
|
||||||
|
|
||||||
|
inline void gemm_conv(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array in,
|
array in,
|
||||||
array wt,
|
array wt,
|
||||||
@ -59,6 +85,42 @@ void gemm_conv(
|
|||||||
const std::vector<int>& input_dilation,
|
const std::vector<int>& input_dilation,
|
||||||
int groups,
|
int groups,
|
||||||
bool flip,
|
bool flip,
|
||||||
Stream s);
|
Stream s) {
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
in = contiguous_copy_gpu(in, s);
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
|
if (!wt.flags().row_contiguous) {
|
||||||
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
|
encoder.add_temporary(wt);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groups == 1) {
|
||||||
|
gemm_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
flip,
|
||||||
|
s);
|
||||||
|
} else {
|
||||||
|
gemm_grouped_conv(
|
||||||
|
encoder,
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
groups,
|
||||||
|
flip,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/cuda/conv/conv.h"
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
@ -137,21 +136,16 @@ void gemm_conv_nd(
|
|||||||
array& out,
|
array& out,
|
||||||
ConvParams<NDIM>& params,
|
ConvParams<NDIM>& params,
|
||||||
Stream s) {
|
Stream s) {
|
||||||
if (params.groups > 1) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[conv] gemm_conv does not support grouped convolution yet.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get gemm shapes.
|
// Get gemm shapes.
|
||||||
int mat_M = out.size() / params.O; // N * H_out * W_out
|
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||||
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
|
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
|
||||||
int mat_N = params.O; // O
|
int mat_N = params.O; // O
|
||||||
|
|
||||||
// Unfold input for gemm.
|
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||||
array in_unfolded =
|
array in_unfolded =
|
||||||
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
|
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
|
||||||
|
|
||||||
// Reshape weight for gemm.
|
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
|
||||||
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
|
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
|
||||||
wt_reshaped.copy_shared_buffer(
|
wt_reshaped.copy_shared_buffer(
|
||||||
wt,
|
wt,
|
||||||
@ -191,14 +185,13 @@ void gemm_conv_nd(
|
|||||||
|
|
||||||
void gemm_conv(
|
void gemm_conv(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array in,
|
const array& in,
|
||||||
array wt,
|
const array& wt,
|
||||||
array& out,
|
array& out,
|
||||||
const std::vector<int>& strides,
|
const std::vector<int>& strides,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding,
|
||||||
const std::vector<int>& kernel_dilation,
|
const std::vector<int>& kernel_dilation,
|
||||||
const std::vector<int>& input_dilation,
|
const std::vector<int>& input_dilation,
|
||||||
int groups,
|
|
||||||
bool flip,
|
bool flip,
|
||||||
Stream s) {
|
Stream s) {
|
||||||
int conv_ndim = in.ndim() - 2;
|
int conv_ndim = in.ndim() - 2;
|
||||||
@ -206,16 +199,6 @@ void gemm_conv(
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!in.flags().row_contiguous) {
|
|
||||||
in = contiguous_copy_gpu(in, s);
|
|
||||||
encoder.add_temporary(in);
|
|
||||||
}
|
|
||||||
if (!wt.flags().row_contiguous) {
|
|
||||||
wt = contiguous_copy_gpu(wt, s);
|
|
||||||
encoder.add_temporary(wt);
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||||
ConvParams<ndim_constant()> params(
|
ConvParams<ndim_constant()> params(
|
||||||
in,
|
in,
|
||||||
@ -225,7 +208,7 @@ void gemm_conv(
|
|||||||
padding,
|
padding,
|
||||||
kernel_dilation,
|
kernel_dilation,
|
||||||
input_dilation,
|
input_dilation,
|
||||||
groups,
|
1, // groups
|
||||||
flip);
|
flip);
|
||||||
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||||
});
|
});
|
||||||
|
231
mlx/backend/cuda/conv/gemm_grouped_conv.cu
Normal file
231
mlx/backend/cuda/conv/gemm_grouped_conv.cu
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/conv/conv.h"
|
||||||
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename T, int NDIM>
|
||||||
|
__global__ void naive_grouped_unfold_transpose_nd(
|
||||||
|
const T* in,
|
||||||
|
T* out,
|
||||||
|
int filter_size,
|
||||||
|
int out_pixels,
|
||||||
|
const __grid_constant__ ConvParams<NDIM> params) {
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto tid = block.group_index();
|
||||||
|
auto lid = block.thread_index();
|
||||||
|
|
||||||
|
int index_batch = tid.z / out_pixels; // [0, N)
|
||||||
|
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||||
|
int index_wt_spatial =
|
||||||
|
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||||
|
|
||||||
|
if (index_wt_spatial >= filter_size / params.C) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
in += tid.y; // [0, C)
|
||||||
|
out += tid.z * filter_size + tid.y * (filter_size / params.C);
|
||||||
|
|
||||||
|
bool valid = index_batch < params.N;
|
||||||
|
|
||||||
|
// Get the coordinates in input.
|
||||||
|
int index_in[NDIM] = {};
|
||||||
|
int wt_stride = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
|
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||||
|
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||||
|
out += index_wt * wt_stride;
|
||||||
|
|
||||||
|
if (params.flip) {
|
||||||
|
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int index = index_out * params.strides[i] - params.padding[i] +
|
||||||
|
index_wt * params.kernel_dilation[i];
|
||||||
|
int index_max =
|
||||||
|
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||||
|
|
||||||
|
valid &= (index >= 0) && (index < index_max) &&
|
||||||
|
(index % params.input_dilation[i] == 0);
|
||||||
|
|
||||||
|
index_in[i] = index / params.input_dilation[i];
|
||||||
|
|
||||||
|
index_out_spatial /= params.out_spatial_dims[i];
|
||||||
|
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||||
|
wt_stride *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (valid) {
|
||||||
|
int in_offset = index_batch * params.in_strides[0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||||
|
}
|
||||||
|
*out = in[in_offset];
|
||||||
|
} else {
|
||||||
|
*out = T{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
array grouped_unfold_transpose_inputs_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
int mat_M,
|
||||||
|
int mat_K,
|
||||||
|
int mat_N,
|
||||||
|
ConvParams<NDIM>& params) {
|
||||||
|
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
|
||||||
|
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||||
|
encoder.add_temporary(unfolded);
|
||||||
|
|
||||||
|
int filter_size = params.C;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
filter_size *= params.wt_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int out_pixels = 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NDIM; ++i) {
|
||||||
|
out_pixels *= params.out_spatial_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int wt_spatial_size = (mat_K * params.groups) / params.C;
|
||||||
|
dim3 block_dims;
|
||||||
|
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||||
|
dim3 num_blocks;
|
||||||
|
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||||
|
num_blocks.y = params.C;
|
||||||
|
num_blocks.z = mat_M;
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(unfolded);
|
||||||
|
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
in.data<DataType>(),
|
||||||
|
unfolded.data<DataType>(),
|
||||||
|
filter_size,
|
||||||
|
out_pixels,
|
||||||
|
params);
|
||||||
|
});
|
||||||
|
|
||||||
|
return unfolded;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
void gemm_grouped_conv_nd(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
ConvParams<NDIM>& params,
|
||||||
|
Stream s) {
|
||||||
|
// Get gemm shapes.
|
||||||
|
int C_per_group = params.C / params.groups;
|
||||||
|
int O_per_group = params.O / params.groups;
|
||||||
|
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||||
|
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
|
||||||
|
int mat_N = O_per_group; // O_per_group
|
||||||
|
|
||||||
|
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||||
|
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
|
||||||
|
encoder, in, mat_M, mat_K, mat_N, params);
|
||||||
|
|
||||||
|
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
|
||||||
|
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
|
||||||
|
array wt_view(
|
||||||
|
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
|
||||||
|
wt_view.copy_shared_buffer(
|
||||||
|
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||||
|
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
|
||||||
|
|
||||||
|
// Batch with size of groups.
|
||||||
|
Shape batch_shape{params.groups};
|
||||||
|
Strides a_batch_strides{mat_K};
|
||||||
|
Strides b_batch_strides{mat_N * mat_K};
|
||||||
|
|
||||||
|
// Run matmul.
|
||||||
|
CublasGemm gemm(
|
||||||
|
encoder.device(),
|
||||||
|
in.dtype(),
|
||||||
|
false, // a_transposed
|
||||||
|
mat_M, // a_rows
|
||||||
|
mat_K, // a_cols
|
||||||
|
mat_K * params.groups, // lda
|
||||||
|
true, // b_transposed
|
||||||
|
mat_K, // b_rows
|
||||||
|
mat_N, // b_cols
|
||||||
|
mat_K, // ldb
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back());
|
||||||
|
gemm.set_out(
|
||||||
|
out.dtype(),
|
||||||
|
false, // out_transposed
|
||||||
|
mat_M, // out_rows
|
||||||
|
mat_N, // out_cols
|
||||||
|
mat_N * params.groups, // out_ld
|
||||||
|
params.groups, // batch_count
|
||||||
|
mat_N); // batch_stride
|
||||||
|
gemm.run(
|
||||||
|
encoder,
|
||||||
|
out,
|
||||||
|
in_unfolded,
|
||||||
|
wt_reshaped,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_grouped_conv(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array& out,
|
||||||
|
const std::vector<int>& strides,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& kernel_dilation,
|
||||||
|
const std::vector<int>& input_dilation,
|
||||||
|
int groups,
|
||||||
|
bool flip,
|
||||||
|
Stream s) {
|
||||||
|
int conv_ndim = in.ndim() - 2;
|
||||||
|
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||||
|
}
|
||||||
|
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||||
|
ConvParams<ndim_constant()> params(
|
||||||
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
strides,
|
||||||
|
padding,
|
||||||
|
kernel_dilation,
|
||||||
|
input_dilation,
|
||||||
|
groups,
|
||||||
|
flip);
|
||||||
|
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
|
|||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CublasGemm::set_out(
|
||||||
|
Dtype dtype,
|
||||||
|
bool transposed,
|
||||||
|
uint64_t rows,
|
||||||
|
uint64_t cols,
|
||||||
|
int64_t ld,
|
||||||
|
int32_t batch_count,
|
||||||
|
int64_t batch_stride) {
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||||
|
out_desc_ = create_matrix_layout(
|
||||||
|
dtype_to_cublas_type(dtype),
|
||||||
|
rows,
|
||||||
|
cols,
|
||||||
|
transposed,
|
||||||
|
ld,
|
||||||
|
batch_count,
|
||||||
|
batch_stride);
|
||||||
|
}
|
||||||
|
|
||||||
void CublasGemm::run(
|
void CublasGemm::run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
|
@ -44,6 +44,17 @@ class CublasGemm {
|
|||||||
|
|
||||||
~CublasGemm();
|
~CublasGemm();
|
||||||
|
|
||||||
|
// The output's descriptor is inferred from inputs by default, use this method
|
||||||
|
// for unusual output.
|
||||||
|
void set_out(
|
||||||
|
Dtype dtype,
|
||||||
|
bool transposed,
|
||||||
|
uint64_t rows,
|
||||||
|
uint64_t cols,
|
||||||
|
int64_t ld,
|
||||||
|
int32_t batch_count,
|
||||||
|
int64_t batch_stride);
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
|
@ -15,8 +15,6 @@ cuda_skip = {
|
|||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
# Convolutions NYI
|
# Convolutions NYI
|
||||||
"TestConv.test_1d_conv_with_2d",
|
"TestConv.test_1d_conv_with_2d",
|
||||||
"TestConv.test_conv_1d_groups_flipped",
|
|
||||||
"TestConv.test_torch_conv_depthwise",
|
|
||||||
# FFTs NYI
|
# FFTs NYI
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
|
Loading…
Reference in New Issue
Block a user