Files
TomoATT/include/simd_conf.h

194 lines
4.9 KiB
C
Raw Normal View History

2025-12-17 10:53:43 +08:00
#ifndef SIMD_CONF_H
#define SIMD_CONF_H
#ifdef USE_SIMD
// flags for self controling the SIMD usage
#define USE_AVX __AVX__
#define USE_AVX512 __AVX512F__
#define USE_ARM_SVE __ARM_FEATURE_SVE
#define USE_NEON __ARM_NEON__
// forcely make USE_AVX and USE_AVX512 to be false if USE_ARM_SVE is true
#if USE_ARM_SVE
#undef USE_AVX
#define USE_AVX 0
#undef USE_AVX512
#define USE_AVX512 0
#endif
#if USE_AVX || USE_AVX512
#include <immintrin.h>
#elif USE_ARM_SVE
#include <arm_sve.h>
#elif USE_NEON
#include <arm_neon.h>
#endif
#ifdef SINGLE_PRECISION
#if USE_AVX && !USE_AVX512
const int ALIGN = 32;
const int NSIMD = 8;
#define __mT __m256
#define _mmT_set1_pT _mm256_set1_ps
#define _mmT_loadu_pT _mm256_loadu_ps
#define _mmT_mul_pT _mm256_mul_ps
#define _mmT_sub_pT _mm256_sub_ps
#define _mmT_add_pT _mm256_add_ps
#define _mmT_div_pT _mm256_div_ps
#define _mmT_min_pT _mm256_min_ps
#define _mmT_sqrt_pT _mm256_sqrt_ps
#define _mmT_store_pT _mm256_store_ps
#define _mmT_fmadd_pT _mm256_fmadd_ps
#define _mm256_cmp_pT _mm256_cmp_ps
#elif USE_AVX512
const int ALIGN = 64;
const int NSIMD = 16;
#define __mT __m512
#define _mmT_set1_pT _mm512_set1_ps
#define _mmT_loadu_pT _mm512_loadu_ps
#define _mmT_mul_pT _mm512_mul_ps
#define _mmT_sub_pT _mm512_sub_ps
#define _mmT_add_pT _mm512_add_ps
#define _mmT_div_pT _mm512_div_ps
#define _mmT_min_pT _mm512_min_ps
#define _mmT_max_pT _mm512_max_ps
#define _mmT_sqrt_pT _mm512_sqrt_ps
#define _mmT_store_pT _mm512_store_ps
#define _mmT_fmadd_pT _mm512_fmadd_ps
#define __mmaskT __mmask16
#define _mm512_cmp_pT_mask _mm512_cmp_ps_mask
#define _mm512_mask_blend_pT _mm512_mask_blend_ps
#define _kand_maskT _kand_mask16
#elif USE_ARM_SVE // NOT TESTED YET
const int ALIGN = 8*svcntd(); // use svcntw() for float
const int NSIMD = svcntd(); // Vector Length
#define __mT svfloat32_t
//#define _mmT_set1_pT svdup_f64
//#define _mmT_loadu_pT svld1_f64_f64_f64
//#define _mmT_mul_pT svmul_f64_m
//#define _mmT_sub_pT svsub_f64_m
//#define _mmT_add_pT svadd_f64_m
//#define _mmT_div_pT svdiv_f64_m
//#define _mmT_min_pT svmin_f64_m
//#define _mmT_sqrt_pT svsqrt_f64_m
//#define _mmT_store_pT svst1_f64
#elif USE_NEON
const int ALIGN = 16;
const int NSIMD = 4;
#define __mT float32x4_t
#define _mmT_set1_pT vdupq_n_f32
#define _mmT_loadu_pT vld1q_f32
#define _mmT_mul_pT vmulq_f32
#define _mmT_sub_pT vsubq_f32
#define _mmT_add_pT vaddq_f32
#define _mmT_div_pT vdivq_f32
#define _mmT_min_pT vminq_f32
#define _mmT_sqrt_pT vsqrtq_f32
#define _mmT_store_pT vst1q_f32
#endif // __ARM_FEATURE_SVE
#else // DOUBLE_PRECISION
#if USE_AVX && !USE_AVX512
const int ALIGN = 32;
const int NSIMD = 4;
#define __mT __m256d
#define _mmT_set1_pT _mm256_set1_pd
#define _mmT_loadu_pT _mm256_loadu_pd
#define _mmT_mul_pT _mm256_mul_pd
#define _mmT_sub_pT _mm256_sub_pd
#define _mmT_add_pT _mm256_add_pd
#define _mmT_div_pT _mm256_div_pd
#define _mmT_min_pT _mm256_min_pd
#define _mmT_sqrt_pT _mm256_sqrt_pd
#define _mmT_store_pT _mm256_store_pd
#define _mmT_fmadd_pT _mm256_fmadd_pd
#define _mm256_cmp_pT _mm256_cmp_pd
#elif USE_AVX512
const int ALIGN = 64;
const int NSIMD = 8;
#define __mT __m512d
#define _mmT_set1_pT _mm512_set1_pd
#define _mmT_loadu_pT _mm512_loadu_pd
#define _mmT_mul_pT _mm512_mul_pd
#define _mmT_sub_pT _mm512_sub_pd
#define _mmT_add_pT _mm512_add_pd
#define _mmT_div_pT _mm512_div_pd
#define _mmT_min_pT _mm512_min_pd
#define _mmT_max_pT _mm512_max_pd
#define _mmT_sqrt_pT _mm512_sqrt_pd
#define _mmT_store_pT _mm512_store_pd
#define _mmT_fmadd_pT _mm512_fmadd_pd
#define __mmaskT __mmask8
#define _mm512_cmp_pT_mask _mm512_cmp_pd_mask
#define _mm512_mask_blend_pT _mm512_mask_blend_pd
#define _kand_maskT _kand_mask8
#elif USE_ARM_SVE
const int ALIGN = 8*svcntd(); // use svcntw() for float
const int NSIMD = svcntd(); // Vector Length
#define __mT svfloat64_t
//#define _mmT_set1_pT svdup_f64
//#define _mmT_loadu_pT svld1_f64_f64_f64
//#define _mmT_mul_pT svmul_f64_m
//#define _mmT_sub_pT svsub_f64_m
//#define _mmT_add_pT svadd_f64_m
//#define _mmT_div_pT svdiv_f64_m
//#define _mmT_min_pT svmin_f64_m
//#define _mmT_sqrt_pT svsqrt_f64_m
//#define _mmT_store_pT svst1_f64
#elif USE_NEON
const int ALIGN = 16;
const int NSIMD = 2;
#define __mT float64x2_t
#define _mmT_set1_pT vdupq_n_f64
#define _mmT_loadu_pT vld1q_f64
#define _mmT_mul_pT vmulq_f64
#define _mmT_sub_pT vsubq_f64
#define _mmT_add_pT vaddq_f64
#define _mmT_div_pT vdivq_f64
#define _mmT_min_pT vminq_f64
#define _mmT_sqrt_pT vsqrtq_f64
#define _mmT_store_pT vst1q_f64
#endif // __ARM_FEATURE_SVE
#endif // DOUBLE_PRECISION
inline void print_simd_type() {
std::cout << "SIMD type: ";
#if USE_AVX && !USE_AVX512
std::cout << "AVX" << std::endl;
#elif USE_AVX512
std::cout << "AVX512" << std::endl;
#elif USE_ARM_SVE
std::cout << "ARM SVE" << std::endl;
#elif USE_NEON
std::cout << "NEON" << std::endl;
#endif // __ARM_FEATURE_SVE
}
#else // USE_CUDA but not AVX dummy
const int ALIGN = 32;
const int NSIMD = 4;
#endif // USE_SIMD
#endif // SIMD_CONF_H