#ifndef VECTORIZED_SWEEP_H
#define VECTORIZED_SWEEP_H

#include
#include "config.h"

#ifdef USE_SIMD // closed at the end of this file

#include "simd_conf.h"

#if USE_AVX || USE_AVX512

__mT COEF      = _mmT_set1_pT(1.0);
__mT v_1       = _mmT_set1_pT(1.0);
__mT v_0       = _mmT_set1_pT(0.0);
__mT v_half    = _mmT_set1_pT(0.5);
__mT v_2       = _mmT_set1_pT(2.0);
__mT v_m2      = _mmT_set1_pT(-2.0);
__mT v_4       = _mmT_set1_pT(4.0);
__mT v_m3      = _mmT_set1_pT(-3.0);
__mT coe_max   = _mmT_set1_pT(1e19);
__mT v_eps     = _mmT_set1_pT(1e-12);
__mT COEF_TELE = _mmT_set1_pT(1.05);

// square of __mT
inline __mT my_square_v(__mT const& a){
    return _mmT_mul_pT(a, a);
}

// pp = (a - b) * Dinv
inline __mT calc_1d_stencil(__mT const& a, __mT const& b, __mT const& Dinv){
    return _mmT_mul_pT(_mmT_sub_pT(a,b),Dinv);
}

/*
   ww = _1_CR/(_1_CR + _2_CR*my_square_v( (eps + my_square_v(a - _2_CR*b + c))
                                        / (eps + my_square_v(d - _2_CR*a + b)) ));

   pp = sign * ( (_1_CR - ww) * (b - d) * 0.5 * Dinv              <-- sign = +1 or -1
               +          ww  * (-3.0*a + 4.0*b - 1.0*c) * 0.5 * Dinv );
*/
inline __mT calc_3d_stencil(__mT const& a, __mT const& b, __mT const& c, __mT const& d, __mT const& Dinv_half, int const& sign){
#ifdef __FMA__
    // v_eps + square(a - 2.0*b + c)
    __mT tmp1 = _mmT_add_pT(v_eps,my_square_v(_mmT_add_pT(a,_mmT_fmadd_pT(v_m2,b,c))));
    // v_eps + square(d - 2.0*a + b)
    __mT tmp2 = _mmT_add_pT(v_eps,my_square_v(_mmT_add_pT(d,_mmT_fmadd_pT(v_m2,a,b))));
    // ww = 1.0/(1.0 + 2.0 * square(tmp1/tmp2))
    __mT ww = _mmT_div_pT(v_1,_mmT_fmadd_pT(v_2,my_square_v(_mmT_div_pT(tmp1,tmp2)),v_1));

    // pp = sign* ((1.0 - ww) * (b - d) / 2.0 / D
    //           + ww * (-3.0* a + 4.0 * b - c) * Dinv_half)
    return _mmT_mul_pT(_mmT_set1_pT(sign), \
               _mmT_add_pT(\
                   _mmT_mul_pT(_mmT_sub_pT(v_1,ww),_mmT_mul_pT(_mmT_sub_pT(b,d),Dinv_half)),\
                   _mmT_mul_pT(ww,_mmT_mul_pT(_mmT_sub_pT(_mmT_fmadd_pT(v_4,b,_mmT_mul_pT(v_m3,a)),c),Dinv_half))\
               )\
           );
#else
    __mT tmp1 = _mmT_add_pT(v_eps,my_square_v(_mmT_add_pT(a,_mmT_add_pT(_mmT_mul_pT(v_m2,b),c))));
    __mT tmp2 = _mmT_add_pT(v_eps,my_square_v(_mmT_add_pT(d,_mmT_add_pT(_mmT_mul_pT(v_m2,a),b))));
    __mT ww = _mmT_div_pT(v_1,_mmT_add_pT(v_1,_mmT_mul_pT(v_2,my_square_v(_mmT_div_pT(tmp1,tmp2)))));

    return _mmT_mul_pT(_mmT_set1_pT(sign), \
               _mmT_add_pT(\
                   _mmT_mul_pT(_mmT_sub_pT(v_1,ww),_mmT_mul_pT(_mmT_sub_pT(b,d),Dinv_half)),\
                   _mmT_mul_pT(ww,_mmT_mul_pT(_mmT_sub_pT(_mmT_add_pT(_mmT_mul_pT(v_4,b),_mmT_mul_pT(v_m3,a)),c),Dinv_half))\
               )\
           );
#endif
}

#elif USE_ARM_SVE

inline __mT my_square_v(svbool_t const& pg, __mT const& a){
    return svmul_f64_z(pg, a, a);
}

inline __mT calc_1d_stencil(svbool_t const& pg, __mT const& a, __mT const& b, __mT const& Dinv){
    return svmul_f64_z(pg, svsub_f64_z(pg, a, b), Dinv);
}

inline __mT calc_3d_stencil(svbool_t const& pg, __mT const& a, __mT const& b, __mT const& c, __mT const& d, __mT const& Dinv_half, int const& sign){
    __mT v_1   = svdup_f64(1.0);
    __mT v_2   = svdup_f64(2.0);
    __mT v_m2  = svdup_f64(-2.0);
    __mT v_4   = svdup_f64(4.0);
    __mT v_m3  = svdup_f64(-3.0);
    __mT v_eps = svdup_f64(1e-12);

    // v_eps + square(a - 2.0*b + c)
    //__mT tmp1 = svadd_f64_z(pg, v_eps, my_square_v(pg, svadd_f64_z(pg, a, svadd_f64_z(pg, svmul_f64_z(pg, v_m2, b), c))));
    __mT tmp1 = svadd_f64_z(pg, v_eps, my_square_v(pg, svadd_f64_z(pg, a, svmad_f64_z(pg, v_m2, b, c))));
    // v_eps + square(d - 2.0*a + b)
    //__mT tmp2 = svadd_f64_z(pg, v_eps, my_square_v(pg, svadd_f64_z(pg, d, svadd_f64_z(pg, svmul_f64_z(pg, v_m2, a), b))));
    __mT tmp2 = svadd_f64_z(pg, v_eps, my_square_v(pg, svadd_f64_z(pg, d, svmad_f64_z(pg, v_m2, a, b))));
    // ww = 1.0/(1.0 + 2.0 * square(tmp1/tmp2))
    //__mT ww = svdiv_f64_z(pg, v_1, svadd_f64_z(pg, v_1, svmul_f64_z(pg, v_2, my_square_v(pg, svdiv_f64_z(pg, tmp1, tmp2)))));
    __mT ww = svdiv_f64_z(pg, v_1, svmad_f64_z(pg, v_2, my_square_v(pg, svdiv_f64_z(pg, tmp1, tmp2)), v_1));

    // pp = sign* ((1.0 - ww) * (b - d) / 2.0 / D
    //           + ww * (-3.0* a + 4.0 * b - c) * Dinv_half)
    /*
    return svmul_f64_m(pg, svdup_f64(sign), \
               svadd_f64_z(pg, \
                   svmul_f64_z(pg, svsub_f64_z(pg, v_1, ww), svmul_f64_z(pg, svsub_f64_z(pg, b, d), Dinv_half)),\
                   svmul_f64_z(pg, ww, svmul_f64_z(pg, svsub_f64_z(pg, svadd_f64_z(pg, svmul_f64_z(pg, v_4, b), svmul_f64_z(pg, v_m3, a)), c), Dinv_half))\
               )\
           );
    */
    return svmul_f64_m(pg, svdup_f64(sign), \
               svadd_f64_z(pg, \
                   svmul_f64_z(pg, svsub_f64_z(pg, v_1, ww), svmul_f64_z(pg, svsub_f64_z(pg, b, d), Dinv_half)),\
                   svmul_f64_z(pg, ww, svmul_f64_z(pg, svsub_f64_z(pg, svmad_f64_z(pg, v_4, b, svmul_f64_z(pg, v_m3, a)), c), Dinv_half))\
               )\
           );
}

#endif
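// The scalar sketch below restates the two stencils above for a single grid
// value, which may be easier to read than the intrinsics. It is illustrative
// only and assumes nothing beyond the formulas in the comment block above
// calc_3d_stencil; the names calc_1d_stencil_scalar and calc_3d_stencil_scalar
// are ours and are not part of the solver interface.
inline double calc_1d_stencil_scalar(double a, double b, double Dinv){
    return (a - b) * Dinv;                          // one-sided first-order difference
}

inline double calc_3d_stencil_scalar(double a, double b, double c, double d,
                                     double Dinv_half, int sign){
    const double eps = 1e-12;
    double num = a - 2.0*b + c;                     // curvature on the wide side
    double den = d - 2.0*a + b;                     // curvature on the near side
    double r   = (eps + num*num) / (eps + den*den);
    double ww  = 1.0 / (1.0 + 2.0*r*r);             // WENO-type smoothness weight
    return sign * ( (1.0 - ww) * (b - d) * Dinv_half               // centered part
                  + ww * (-3.0*a + 4.0*b - c) * Dinv_half );       // one-sided 3rd-order part
}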
inline void vect_stencil_1st_pre_simd(
#if USE_ARM_SVE
                                      svbool_t const& pg,
#endif
                                      __mT const& v_iip, __mT const& v_jjt, __mT const& v_kkr,
                                      __mT const& v_c__,
                                      __mT const& v_p__, __mT const& v_m__, __mT const& v__p_, __mT const& v__m_, __mT const& v___p, __mT const& v___m,
                                      __mT& v_pp1, __mT& v_pp2, __mT& v_pt1, __mT& v_pt2, __mT& v_pr1, __mT& v_pr2,
                                      __mT const& v_DP_inv,      __mT const& v_DT_inv,      __mT const& v_DR_inv,
                                      __mT const& v_DP_inv_half, __mT const& v_DT_inv_half, __mT const& v_DR_inv_half,
                                      int const& NP, int const& NT, int const& NR){

#if USE_AVX512 || USE_AVX
    v_pp1 = calc_1d_stencil(v_c__, v_m__, v_DP_inv);
    v_pp2 = calc_1d_stencil(v_p__, v_c__, v_DP_inv);
    v_pt1 = calc_1d_stencil(v_c__, v__m_, v_DT_inv);
    v_pt2 = calc_1d_stencil(v__p_, v_c__, v_DT_inv);
    v_pr1 = calc_1d_stencil(v_c__, v___m, v_DR_inv);
    v_pr2 = calc_1d_stencil(v___p, v_c__, v_DR_inv);
#elif USE_ARM_SVE
    v_pp1 = calc_1d_stencil(pg, v_c__, v_m__, v_DP_inv);
    v_pp2 = calc_1d_stencil(pg, v_p__, v_c__, v_DP_inv);
    v_pt1 = calc_1d_stencil(pg, v_c__, v__m_, v_DT_inv);
    v_pt2 = calc_1d_stencil(pg, v__p_, v_c__, v_DT_inv);
    v_pr1 = calc_1d_stencil(pg, v_c__, v___m, v_DR_inv);
    v_pr2 = calc_1d_stencil(pg, v___p, v_c__, v_DR_inv);
#endif
}

inline void vect_stencil_3rd_pre_simd(
#if USE_ARM_SVE
                                      svbool_t const& pg,
#endif
                                      __mT const& v_iip, __mT const& v_jjt, __mT const& v_kkr,
                                      __mT const& v_c__,
                                      __mT const& v_p__,    __mT const& v_m__,    __mT const& v__p_,    __mT const& v__m_,    __mT const& v___p,    __mT const& v___m,
                                      __mT const& v_pp____, __mT const& v_mm____, __mT const& v___pp__, __mT const& v___mm__, __mT const& v_____pp, __mT const& v_____mm,
                                      __mT& v_pp1, __mT& v_pp2, __mT& v_pt1, __mT& v_pt2, __mT& v_pr1, __mT& v_pr2,
                                      __mT const& v_DP_inv,      __mT const& v_DT_inv,      __mT const& v_DR_inv,
                                      __mT const& v_DP_inv_half, __mT const& v_DT_inv_half, __mT const& v_DR_inv_half,
                                      int const& NP, int const& NT, int const& NR){

    const int PLUS  = 1;
    const int MINUS = -1;

#if USE_AVX512
    __mmaskT mask_i_eq_1 = _mm512_cmp_pT_mask(v_iip, v_1,_CMP_EQ_OQ); // if iip == 1
    __mmaskT mask_j_eq_1 = _mm512_cmp_pT_mask(v_jjt, v_1,_CMP_EQ_OQ); // if jjt == 1
    __mmaskT mask_k_eq_1 = _mm512_cmp_pT_mask(v_kkr, v_1,_CMP_EQ_OQ); // if kkr == 1
    __mmaskT mask_i_eq_N_minus_2 = _mm512_cmp_pT_mask(v_iip, _mmT_set1_pT(NP-2),_CMP_EQ_OQ); // if iip == N-2
    __mmaskT mask_j_eq_N_minus_2 = _mm512_cmp_pT_mask(v_jjt, _mmT_set1_pT(NT-2),_CMP_EQ_OQ); // if jjt == N-2
    __mmaskT mask_k_eq_N_minus_2 = _mm512_cmp_pT_mask(v_kkr, _mmT_set1_pT(NR-2),_CMP_EQ_OQ); // if kkr == N-2
    // 1 < iip < N-2
    __mmaskT mask_i_else = _kand_maskT( _mm512_cmp_pT_mask(v_iip, v_1 , _CMP_GT_OQ),
                                        _mm512_cmp_pT_mask(_mmT_set1_pT(NP-2), v_iip, _CMP_GT_OQ) );
    // 1 < jjt < N-2
    __mmaskT mask_j_else = _kand_maskT( _mm512_cmp_pT_mask(v_jjt, v_1 , _CMP_GT_OQ),
                                        _mm512_cmp_pT_mask(_mmT_set1_pT(NT-2), v_jjt, _CMP_GT_OQ) );
    // 1 < kkr < N-2
    __mmaskT mask_k_else = _kand_maskT( _mm512_cmp_pT_mask(v_kkr, v_1 , _CMP_GT_OQ),
                                        _mm512_cmp_pT_mask(_mmT_set1_pT(NR-2), v_kkr, _CMP_GT_OQ) );

    // if _i_eq_1 == true
    v_pp1 = _mm512_mask_blend_pT(mask_i_eq_1, calc_1d_stencil(v_c__, v_m__, v_DP_inv)                            , v_pp1);
    v_pp2 = _mm512_mask_blend_pT(mask_i_eq_1, calc_3d_stencil(v_c__, v_p__, v_pp____, v_m__, v_DP_inv_half, PLUS), v_pp2);
    // if _i_eq_N_minus_2 == true
    v_pp1 = _mm512_mask_blend_pT(mask_i_eq_N_minus_2, calc_3d_stencil(v_c__, v_m__, v_mm____, v_p__, v_DP_inv_half, MINUS), v_pp1);
    v_pp2 = _mm512_mask_blend_pT(mask_i_eq_N_minus_2, calc_1d_stencil(v_p__, v_c__, v_DP_inv)                            , v_pp2);
    // else
    v_pp1 = _mm512_mask_blend_pT(mask_i_else, calc_3d_stencil(v_c__, v_m__, v_mm____, v_p__, v_DP_inv_half, MINUS), v_pp1);
    v_pp2 = _mm512_mask_blend_pT(mask_i_else, calc_3d_stencil(v_c__, v_p__, v_pp____, v_m__, v_DP_inv_half, PLUS) , v_pp2);

    // if _j_eq_1 == true
    v_pt1 = _mm512_mask_blend_pT(mask_j_eq_1, calc_1d_stencil(v_c__, v__m_, v_DT_inv)                            , v_pt1);
    v_pt2 = _mm512_mask_blend_pT(mask_j_eq_1, calc_3d_stencil(v_c__, v__p_, v___pp__, v__m_, v_DT_inv_half, PLUS), v_pt2);
    // if _j_eq_N_minus_2 == true
    v_pt1 = _mm512_mask_blend_pT(mask_j_eq_N_minus_2, calc_3d_stencil(v_c__, v__m_, v___mm__, v__p_, v_DT_inv_half, MINUS), v_pt1);
    v_pt2 = _mm512_mask_blend_pT(mask_j_eq_N_minus_2, calc_1d_stencil(v__p_, v_c__, v_DT_inv)                            , v_pt2);
    // else
    v_pt1 = _mm512_mask_blend_pT(mask_j_else, calc_3d_stencil(v_c__, v__m_, v___mm__, v__p_, v_DT_inv_half, MINUS), v_pt1);
    v_pt2 = _mm512_mask_blend_pT(mask_j_else, calc_3d_stencil(v_c__, v__p_, v___pp__, v__m_, v_DT_inv_half, PLUS) , v_pt2);

    // if _k_eq_1 == true
    v_pr1 = _mm512_mask_blend_pT(mask_k_eq_1, calc_1d_stencil(v_c__, v___m, v_DR_inv)                            , v_pr1);
    v_pr2 = _mm512_mask_blend_pT(mask_k_eq_1, calc_3d_stencil(v_c__, v___p, v_____pp, v___m, v_DR_inv_half, PLUS), v_pr2);
    // if _k_eq_N_minus_2 == true
    v_pr1 = _mm512_mask_blend_pT(mask_k_eq_N_minus_2, calc_3d_stencil(v_c__, v___m, v_____mm, v___p, v_DR_inv_half, MINUS), v_pr1);
    v_pr2 = _mm512_mask_blend_pT(mask_k_eq_N_minus_2, calc_1d_stencil(v___p, v_c__, v_DR_inv)                            , v_pr2);
    // else
    v_pr1 = _mm512_mask_blend_pT(mask_k_else, calc_3d_stencil(v_c__, v___m, v_____mm, v___p, v_DR_inv_half, MINUS), v_pr1);
    v_pr2 = _mm512_mask_blend_pT(mask_k_else, calc_3d_stencil(v_c__, v___p, v_____pp, v___m, v_DR_inv_half, PLUS) , v_pr2);

#elif USE_AVX
    __mT mask_i_eq_1 = _mm256_cmp_pT(v_iip, v_1,_CMP_EQ_OQ); // if iip == 1
    __mT mask_j_eq_1 = _mm256_cmp_pT(v_jjt, v_1,_CMP_EQ_OQ); // if jjt == 1
    __mT mask_k_eq_1 = _mm256_cmp_pT(v_kkr, v_1,_CMP_EQ_OQ); // if kkr == 1
    __mT mask_i_eq_N_minus_2 = _mm256_cmp_pT(v_iip, _mmT_set1_pT(NP-2),_CMP_EQ_OQ); // if iip == N-2
    __mT mask_j_eq_N_minus_2 = _mm256_cmp_pT(v_jjt, _mmT_set1_pT(NT-2),_CMP_EQ_OQ); // if jjt == N-2
    __mT mask_k_eq_N_minus_2 = _mm256_cmp_pT(v_kkr, _mmT_set1_pT(NR-2),_CMP_EQ_OQ); // if kkr == N-2
    // 1 < iip < N-2
    __m256d mask_i_else = _mm256_and_pd( _mm256_cmp_pT(v_iip, v_1 , _CMP_GT_OQ),
                                         _mm256_cmp_pT(_mmT_set1_pT(NP-2), v_iip, _CMP_GT_OQ) );
    // 1 < jjt < N-2
    __m256d mask_j_else = _mm256_and_pd( _mm256_cmp_pT(v_jjt, v_1 , _CMP_GT_OQ),
                                         _mm256_cmp_pT(_mmT_set1_pT(NT-2), v_jjt, _CMP_GT_OQ) );
    // 1 < kkr < N-2
    __m256d mask_k_else = _mm256_and_pd( _mm256_cmp_pT(v_kkr, v_1 , _CMP_GT_OQ),
                                         _mm256_cmp_pT(_mmT_set1_pT(NR-2), v_kkr, _CMP_GT_OQ) );
    // if _i_eq_1 == true
    v_pp1 = _mm256_blendv_pd(v_pp1, calc_1d_stencil(v_c__, v_m__, v_DP_inv), mask_i_eq_1);
    v_pp2 = _mm256_blendv_pd(v_pp2, calc_3d_stencil(v_c__, v_p__, v_pp____, v_m__, v_DP_inv_half, PLUS), mask_i_eq_1);
    // if _i_eq_N_minus_2 == true
    v_pp1 = _mm256_blendv_pd(v_pp1, calc_3d_stencil(v_c__, v_m__, v_mm____, v_p__, v_DP_inv_half, MINUS), mask_i_eq_N_minus_2);
    v_pp2 = _mm256_blendv_pd(v_pp2, calc_1d_stencil(v_p__, v_c__, v_DP_inv), mask_i_eq_N_minus_2);
    // else
    v_pp1 = _mm256_blendv_pd(v_pp1, calc_3d_stencil(v_c__, v_m__, v_mm____, v_p__, v_DP_inv_half, MINUS), mask_i_else);
    v_pp2 = _mm256_blendv_pd(v_pp2, calc_3d_stencil(v_c__, v_p__, v_pp____, v_m__, v_DP_inv_half, PLUS), mask_i_else);

    // if _j_eq_1 == true
    v_pt1 = _mm256_blendv_pd(v_pt1, calc_1d_stencil(v_c__, v__m_, v_DT_inv), mask_j_eq_1);
    v_pt2 = _mm256_blendv_pd(v_pt2, calc_3d_stencil(v_c__, v__p_, v___pp__, v__m_, v_DT_inv_half, PLUS), mask_j_eq_1);
    // if _j_eq_N_minus_2 == true
    v_pt1 = _mm256_blendv_pd(v_pt1, calc_3d_stencil(v_c__, v__m_, v___mm__, v__p_, v_DT_inv_half, MINUS), mask_j_eq_N_minus_2);
    v_pt2 = _mm256_blendv_pd(v_pt2, calc_1d_stencil(v__p_, v_c__, v_DT_inv), mask_j_eq_N_minus_2);
    // else
    v_pt1 = _mm256_blendv_pd(v_pt1, calc_3d_stencil(v_c__, v__m_, v___mm__, v__p_, v_DT_inv_half, MINUS), mask_j_else);
    v_pt2 = _mm256_blendv_pd(v_pt2, calc_3d_stencil(v_c__, v__p_, v___pp__, v__m_, v_DT_inv_half, PLUS), mask_j_else);

    // if _k_eq_1 == true
    v_pr1 = _mm256_blendv_pd(v_pr1, calc_1d_stencil(v_c__, v___m, v_DR_inv), mask_k_eq_1);
    v_pr2 = _mm256_blendv_pd(v_pr2, calc_3d_stencil(v_c__, v___p, v_____pp, v___m, v_DR_inv_half, PLUS), mask_k_eq_1);
    // if _k_eq_N_minus_2 == true
    v_pr1 = _mm256_blendv_pd(v_pr1, calc_3d_stencil(v_c__, v___m, v_____mm, v___p, v_DR_inv_half, MINUS), mask_k_eq_N_minus_2);
    v_pr2 = _mm256_blendv_pd(v_pr2, calc_1d_stencil(v___p, v_c__, v_DR_inv), mask_k_eq_N_minus_2);
    // else
    v_pr1 = _mm256_blendv_pd(v_pr1, calc_3d_stencil(v_c__, v___m, v_____mm, v___p, v_DR_inv_half, MINUS), mask_k_else);
    v_pr2 = _mm256_blendv_pd(v_pr2, calc_3d_stencil(v_c__, v___p, v_____pp, v___m, v_DR_inv_half, PLUS), mask_k_else);

#elif USE_ARM_SVE
    svfloat64_t v_1 = svdup_f64(1.0);
    svfloat64_t v_NP_minus_2 = svdup_f64(NP-2);
    svfloat64_t v_NT_minus_2 = svdup_f64(NT-2);
    svfloat64_t v_NR_minus_2 = svdup_f64(NR-2);

    svbool_t mask_i_eq_1 = svcmpeq_f64(pg, v_iip, v_1); // if iip == 1
    svbool_t mask_j_eq_1 = svcmpeq_f64(pg, v_jjt, v_1); // if jjt == 1
    svbool_t mask_k_eq_1 = svcmpeq_f64(pg, v_kkr, v_1); // if kkr == 1
    svbool_t mask_i_eq_N_minus_2 = svcmpeq_f64(pg, v_iip, v_NP_minus_2); // if iip == N-2
    svbool_t mask_j_eq_N_minus_2 = svcmpeq_f64(pg, v_jjt, v_NT_minus_2); // if jjt == N-2
    svbool_t mask_k_eq_N_minus_2 = svcmpeq_f64(pg, v_kkr, v_NR_minus_2); // if kkr == N-2
    // 1 < iip < N-2
    svbool_t mask_i_else = svand_b_z(pg, svcmpgt_f64(pg, v_iip, v_1), svcmplt_f64(pg, v_iip, v_NP_minus_2) );
    // 1 < jjt < N-2
    svbool_t mask_j_else = svand_b_z(pg, svcmpgt_f64(pg, v_jjt, v_1), svcmplt_f64(pg, v_jjt, v_NT_minus_2) );
    // 1 < kkr < N-2
    svbool_t mask_k_else = svand_b_z(pg, svcmpgt_f64(pg, v_kkr, v_1), svcmplt_f64(pg, v_kkr, v_NR_minus_2) );

    // if _i_eq_1 == true
    v_pp1 = svsel_f64(mask_i_eq_1, calc_1d_stencil(mask_i_eq_1, v_c__, v_m__, v_DP_inv)                            , v_pp1);
    v_pp2 = svsel_f64(mask_i_eq_1, calc_3d_stencil(mask_i_eq_1, v_c__, v_p__, v_pp____, v_m__, v_DP_inv_half, PLUS), v_pp2);
    // if _i_eq_N_minus_2 == true
    v_pp1 = svsel_f64(mask_i_eq_N_minus_2, calc_3d_stencil(mask_i_eq_N_minus_2, v_c__, v_m__, v_mm____, v_p__, v_DP_inv_half, MINUS), v_pp1);
    v_pp2 = svsel_f64(mask_i_eq_N_minus_2, calc_1d_stencil(mask_i_eq_N_minus_2, v_p__, v_c__, v_DP_inv)                             , v_pp2);
    // else
    v_pp1 = svsel_f64(mask_i_else, calc_3d_stencil(mask_i_else, v_c__, v_m__, v_mm____, v_p__, v_DP_inv_half, MINUS), v_pp1);
    v_pp2 = svsel_f64(mask_i_else, calc_3d_stencil(mask_i_else, v_c__, v_p__, v_pp____, v_m__, v_DP_inv_half, PLUS) , v_pp2);

    // if _j_eq_1 == true
    v_pt1 = svsel_f64(mask_j_eq_1, calc_1d_stencil(mask_j_eq_1, v_c__, v__m_, v_DT_inv)                            , v_pt1);
    v_pt2 = svsel_f64(mask_j_eq_1, calc_3d_stencil(mask_j_eq_1, v_c__, v__p_, v___pp__, v__m_, v_DT_inv_half, PLUS), v_pt2);
    // if _j_eq_N_minus_2 == true
    v_pt1 = svsel_f64(mask_j_eq_N_minus_2, calc_3d_stencil(mask_j_eq_N_minus_2, v_c__, v__m_, v___mm__, v__p_, v_DT_inv_half, MINUS), v_pt1);
    v_pt2 = svsel_f64(mask_j_eq_N_minus_2, calc_1d_stencil(mask_j_eq_N_minus_2, v__p_, v_c__, v_DT_inv)                             , v_pt2);
    // else
    v_pt1 = svsel_f64(mask_j_else, calc_3d_stencil(mask_j_else, v_c__, v__m_, v___mm__, v__p_, v_DT_inv_half, MINUS), v_pt1);
    v_pt2 = svsel_f64(mask_j_else, calc_3d_stencil(mask_j_else, v_c__, v__p_, v___pp__, v__m_, v_DT_inv_half, PLUS) , v_pt2);

    // if _k_eq_1 == true
    v_pr1 = svsel_f64(mask_k_eq_1, calc_1d_stencil(mask_k_eq_1, v_c__, v___m, v_DR_inv)                            , v_pr1);
    v_pr2 = svsel_f64(mask_k_eq_1, calc_3d_stencil(mask_k_eq_1, v_c__, v___p, v_____pp, v___m, v_DR_inv_half, PLUS), v_pr2);
    // if _k_eq_N_minus_2 == true
    v_pr1 = svsel_f64(mask_k_eq_N_minus_2, calc_3d_stencil(mask_k_eq_N_minus_2, v_c__, v___m, v_____mm, v___p, v_DR_inv_half, MINUS), v_pr1);
    v_pr2 = svsel_f64(mask_k_eq_N_minus_2, calc_1d_stencil(mask_k_eq_N_minus_2, v___p, v_c__, v_DR_inv)                             , v_pr2);
    // else
    v_pr1 = svsel_f64(mask_k_else, calc_3d_stencil(mask_k_else, v_c__, v___m, v_____mm, v___p, v_DR_inv_half, MINUS), v_pr1);
    v_pr2 = svsel_f64(mask_k_else, calc_3d_stencil(mask_k_else, v_c__, v___p, v_____pp, v___m, v_DR_inv_half, PLUS) , v_pr2);
#endif
}
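// The mask/blend sequences above are a branch-free form of a per-node choice:
// at the first and last interior indices the wide stencil would reach outside
// the grid on one side, so the first-order difference is used there. A scalar
// sketch for the i (phi) direction follows; it is illustrative only, the name
// select_p_stencils_scalar is ours, and p2/m2 stand for the values two nodes
// away (v_pp____ / v_mm____). The j and k directions repeat the same pattern.
inline void select_p_stencils_scalar(int iip, int NP,
                                     double c, double p1, double m1, double p2, double m2,
                                     double DP_inv, double DP_inv_half,
                                     double& pp1, double& pp2){
    if (iip == 1) {                         // one-sided 3rd order only on the + side
        pp1 = calc_1d_stencil_scalar(c, m1, DP_inv);
        pp2 = calc_3d_stencil_scalar(c, p1, p2, m1, DP_inv_half, +1);
    } else if (iip == NP-2) {               // one-sided 3rd order only on the - side
        pp1 = calc_3d_stencil_scalar(c, m1, m2, p1, DP_inv_half, -1);
        pp2 = calc_1d_stencil_scalar(p1, c, DP_inv);
    } else if (iip > 1 && iip < NP-2) {     // interior node: 3rd order on both sides
        pp1 = calc_3d_stencil_scalar(c, m1, m2, p1, DP_inv_half, -1);
        pp2 = calc_3d_stencil_scalar(c, p1, p2, m1, DP_inv_half, +1);
    }                                       // otherwise pp1/pp2 keep their previous
                                            // values, matching the blend semantics
}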
inline void vect_stencil_1st_3rd_apre_simd(
#if USE_ARM_SVE
                                           svbool_t const& pg,
#endif
                                           __mT& v_tau,
                                           __mT const& v_fac_a, __mT const& v_fac_b, __mT const& v_fac_c, __mT const& v_fac_f,
                                           __mT const& v_T0v,   __mT const& v_T0p,   __mT const& v_T0t,   __mT const& v_T0r,
                                           __mT const& v_fun,   __mT const& v_change,
                                           __mT const& v_pp1, __mT const& v_pp2, __mT const& v_pt1, __mT const& v_pt2, __mT const& v_pr1, __mT const& v_pr2,
                                           __mT const& DP_inv, __mT const& DT_inv, __mT const& DR_inv){

#if USE_AVX512 || USE_AVX
    // sigr = COEF * sqrt(v_fac_a)*v_T0v;
    __mT sigr = _mmT_mul_pT(_mmT_mul_pT(COEF,_mmT_sqrt_pT(v_fac_a)),v_T0v);
    // sigt = COEF * sqrt(v_fac_b)*v_T0v;
    __mT sigt = _mmT_mul_pT(_mmT_mul_pT(COEF,_mmT_sqrt_pT(v_fac_b)),v_T0v);
    // sigp = COEF * sqrt(v_fac_c)*v_T0v;
    __mT sigp = _mmT_mul_pT(_mmT_mul_pT(COEF,_mmT_sqrt_pT(v_fac_c)),v_T0v);

    // coe = 1.0 / (sigr/D + sigt/D + sigp/D);
    __mT coe = _mmT_div_pT(v_1,_mmT_add_pT(_mmT_add_pT(_mmT_mul_pT(sigr,DR_inv),_mmT_mul_pT(sigt,DT_inv)),_mmT_mul_pT(sigp,DP_inv)));
    // coe becomes inf as sig* goes to 0
    // if coe > 1e19, set coe = 1e19
    coe = _mmT_min_pT(coe,coe_max);

    // Htau = v_fac_a * square(v_T0r * v_tau + v_T0v * (v_pr1 + v_pr2)*0.5);
    __mT Htau = _mmT_mul_pT(v_fac_a,my_square_v(_mmT_add_pT(_mmT_mul_pT(v_T0r,v_tau),_mmT_mul_pT(v_T0v,_mmT_mul_pT(_mmT_add_pT(v_pr1,v_pr2),v_half)))));
    // Htau += v_fac_b * square(v_T0t * v_tau + v_T0v * (v_pt1 + v_pt2)*0.5)
    Htau = _mmT_add_pT(Htau,_mmT_mul_pT(v_fac_b,my_square_v(_mmT_add_pT(_mmT_mul_pT(v_T0t,v_tau),_mmT_mul_pT(v_T0v,_mmT_mul_pT(_mmT_add_pT(v_pt1,v_pt2), v_half))))));
    // Htau += v_fac_c * square(v_T0p * v_tau + v_T0v * (v_pp1 + v_pp2)*0.5)
    Htau = _mmT_add_pT(Htau,_mmT_mul_pT(v_fac_c,my_square_v(_mmT_add_pT(_mmT_mul_pT(v_T0p,v_tau),_mmT_mul_pT(v_T0v,_mmT_mul_pT(_mmT_add_pT(v_pp1,v_pp2), v_half))))));

    // tmp1 = (v_T0t * v_tau + v_T0v * (v_pt1 + v_pt2)*0.5)
    __mT tmp1 = _mmT_add_pT(_mmT_mul_pT(v_T0t,v_tau),_mmT_mul_pT(v_T0v,_mmT_mul_pT(_mmT_add_pT(v_pt1,v_pt2), v_half)));
    // tmp2 = (v_T0p * v_tau + v_T0v * (v_pp1 + v_pp2)*0.5)
    __mT tmp2 = _mmT_add_pT(_mmT_mul_pT(v_T0p,v_tau),_mmT_mul_pT(v_T0v,_mmT_mul_pT(_mmT_add_pT(v_pp1,v_pp2), v_half)));
    // tmp3 = -2.0 * v_fac_f * tmp1 * tmp2;
    __mT tmp3 = _mmT_mul_pT(v_m2,_mmT_mul_pT(v_fac_f,_mmT_mul_pT(tmp1,tmp2)));

    // Htau = sqrt(Htau + tmp3);  // becomes nan if Htau + tmp3 goes below 0
    Htau = _mmT_sqrt_pT(_mmT_add_pT(Htau,tmp3));

    // tmp = (sigr*(v_pr2 - v_pr1) + sigt*(v_pt2 - v_pt1) + sigp*(v_pp2 - v_pp1))*0.5
    __mT tmp = _mmT_mul_pT(v_half,_mmT_add_pT(_mmT_add_pT(_mmT_mul_pT(sigr,_mmT_sub_pT(v_pr2,v_pr1)),_mmT_mul_pT(sigt,_mmT_sub_pT(v_pt2,v_pt1))),_mmT_mul_pT(sigp,_mmT_sub_pT(v_pp2,v_pp1))));

    // v_tau += coe * ((v_fun - Htau) + tmp)
    v_tau = _mmT_add_pT(v_tau,_mmT_mul_pT(coe,_mmT_add_pT(_mmT_sub_pT(v_fun,Htau),tmp)));

#if USE_AVX512
    // mask if v_change != 1.0
    __mmaskT mask = _mm512_cmp_pT_mask(v_change,v_1,_CMP_NEQ_OQ);
    // set 1 if mask is true
    v_tau = _mm512_mask_blend_pT(mask,v_tau,v_1);
#elif USE_AVX
    __m256d mask = _mm256_cmp_pT(v_change, v_1,_CMP_NEQ_OQ);
    v_tau = _mm256_blendv_pd(v_tau, v_1,mask);
#endif

#elif USE_ARM_SVE
    __mT COEF    = svdup_f64(1.0);
    __mT v_1     = svdup_f64(1.0);
    __mT v_half  = svdup_f64(0.5);
    __mT v_m2    = svdup_f64(-2.0);
    __mT coe_max = svdup_f64(1e19);

    // sigr = COEF * sqrt(v_fac_a)*v_T0v;
    __mT sigr = svmul_f64_z(pg, svmul_f64_z(pg,COEF,svsqrt_f64_z(pg,v_fac_a)), v_T0v);
    // sigt = COEF * sqrt(v_fac_b)*v_T0v;
    __mT sigt = svmul_f64_z(pg, svmul_f64_z(pg,COEF,svsqrt_f64_z(pg,v_fac_b)), v_T0v);
    // sigp = COEF * sqrt(v_fac_c)*v_T0v;
    __mT sigp = svmul_f64_z(pg, svmul_f64_z(pg,COEF,svsqrt_f64_z(pg,v_fac_c)), v_T0v);

    // coe = 1.0 / (sigr/D + sigt/D + sigp/D);
    __mT coe = svdiv_f64_z(pg, v_1, svadd_f64_z(pg, svadd_f64_z(pg, svmul_f64_z(pg, sigr, DR_inv), svmul_f64_z(pg, sigt, DT_inv)), svmul_f64_z(pg, sigp, DP_inv)));
    // coe becomes inf as sig* goes to 0
    // if coe > 1e19, set coe = 1e19
    coe = svmin_f64_z(pg, coe, coe_max);

    // Htau = v_fac_a * square(v_T0r * v_tau + v_T0v * (v_pr1 + v_pr2)*0.5);
    __mT Htau = svmul_f64_z(pg, v_fac_a, my_square_v(pg, svadd_f64_z(pg, \
                    svmul_f64_z(pg, v_T0r, v_tau), \
                    svmul_f64_z(pg, v_T0v, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pr1, v_pr2))))\
                )\
            );
    // Htau += v_fac_b * square(v_T0t * v_tau + v_T0v * (v_pt1 + v_pt2)*0.5)
    Htau = svadd_f64_z(pg, Htau, svmul_f64_z(pg, v_fac_b, my_square_v(pg, svadd_f64_z(pg, \
               svmul_f64_z(pg, v_T0t, v_tau), \
               svmul_f64_z(pg, v_T0v, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pt1, v_pt2))))\
           )\
           )\
           );
    // Htau += v_fac_c * square(v_T0p * v_tau + v_T0v * (v_pp1 + v_pp2)*0.5)
    Htau = svadd_f64_z(pg, Htau, svmul_f64_z(pg, v_fac_c, my_square_v(pg, svadd_f64_z(pg, \
               svmul_f64_z(pg, v_T0p, v_tau), \
               svmul_f64_z(pg, v_T0v, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pp1, v_pp2))))\
           )\
           )\
           );

    // tmp1 = (v_T0t * v_tau + v_T0v * (v_pt1 + v_pt2)*0.5)
    __mT tmp1 = svadd_f64_z(pg, \
                    svmul_f64_z(pg, v_T0t, v_tau), \
                    svmul_f64_z(pg, v_T0v, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pt1, v_pt2)))\
                );
    // tmp2 = (v_T0p * v_tau + v_T0v * (v_pp1 + v_pp2)*0.5)
    __mT tmp2 = svadd_f64_z(pg, \
                    svmul_f64_z(pg, v_T0p, v_tau), \
                    svmul_f64_z(pg, v_T0v, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pp1, v_pp2)))\
                );
    // tmp3 = -2.0 * v_fac_f * tmp1 * tmp2;
    __mT tmp3 = svmul_f64_z(pg, v_m2, svmul_f64_z(pg, v_fac_f, svmul_f64_z(pg, tmp1, tmp2)));

    // Htau = sqrt(Htau + tmp3);  // becomes nan if Htau + tmp3 goes below 0
    Htau = svsqrt_f64_z(pg, svadd_f64_z(pg, Htau, tmp3));

    // tmp = (sigr*(v_pr2 - v_pr1) + sigt*(v_pt2 - v_pt1) + sigp*(v_pp2 - v_pp1))*0.5
    __mT tmp = svmul_f64_z(pg, v_half, svadd_f64_z(pg, \
                   svadd_f64_z(pg, \
                       svmul_f64_z(pg, sigr, svsub_f64_z(pg, v_pr2, v_pr1)), \
                       svmul_f64_z(pg, sigt, svsub_f64_z(pg, v_pt2, v_pt1))\
                   ), \
                   svmul_f64_z(pg, sigp, svsub_f64_z(pg, v_pp2, v_pp1))\
               )\
               );

    // v_tau += coe * ((v_fun - Htau) + tmp)
    v_tau = svadd_f64_z(pg, v_tau, svmul_f64_z(pg, coe, svadd_f64_z(pg, svsub_f64_z(pg, v_fun, Htau), tmp)));

    // mask = v_change != 1.0
    svbool_t mask = svcmpne_f64(pg, v_change, v_1);
    // v_tau = v_1 if mask is true (!= 1.0)
    v_tau = svsel_f64(mask, v_1, v_tau);
#endif
}
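#include <cmath> // only needed for the illustrative sketch below

// One-lane summary of the update performed above (illustrative only; the name
// update_tau_scalar is ours and the function is not used by the solver). Our
// reading of the T0*/tau terms is a factored traveltime T = T0 * tau: the
// averaged upwind derivatives enter the Hamiltonian Htau, the sig*-weighted
// term adds Lax-Friedrichs style dissipation, and coe acts as the local
// pseudo time step. Lanes with v_change != 1 are reset to tau = 1 afterwards.
inline double update_tau_scalar(double tau,
                                double fac_a, double fac_b, double fac_c, double fac_f,
                                double T0v, double T0p, double T0t, double T0r, double fun,
                                double pp1, double pp2, double pt1, double pt2,
                                double pr1, double pr2,
                                double DP_inv, double DT_inv, double DR_inv){
    const double COEF = 1.0;
    double sigr = COEF * std::sqrt(fac_a) * T0v;
    double sigt = COEF * std::sqrt(fac_b) * T0v;
    double sigp = COEF * std::sqrt(fac_c) * T0v;
    double coe  = 1.0 / (sigr*DR_inv + sigt*DT_inv + sigp*DP_inv);
    if (coe > 1e19) coe = 1e19;                       // guard: coe -> inf as sig* -> 0

    double dr = T0r*tau + T0v*(pr1 + pr2)*0.5;        // dT/dr     with T = T0*tau
    double dt = T0t*tau + T0v*(pt1 + pt2)*0.5;        // dT/dtheta
    double dp = T0p*tau + T0v*(pp1 + pp2)*0.5;        // dT/dphi
    double Htau = std::sqrt(fac_a*dr*dr + fac_b*dt*dt + fac_c*dp*dp - 2.0*fac_f*dt*dp);

    double diss = 0.5*(sigr*(pr2 - pr1) + sigt*(pt2 - pt1) + sigp*(pp2 - pp1));
    return tau + coe*((fun - Htau) + diss);
}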
inline void vect_stencil_1st_3rd_apre_simd_tele(
#if USE_ARM_SVE
                                                svbool_t const& pg,
#endif
                                                __mT& v_tau,
                                                __mT const& v_fac_a, __mT const& v_fac_b, __mT const& v_fac_c, __mT const& v_fac_f,
                                                __mT const& v_fun,   __mT const& v_change,
                                                __mT const& v_pp1, __mT const& v_pp2, __mT const& v_pt1, __mT const& v_pt2, __mT const& v_pr1, __mT const& v_pr2,
                                                __mT const& DP_inv, __mT const& DT_inv, __mT const& DR_inv){

#if USE_AVX512 || USE_AVX
    // sigr = COEF_TELE * sqrt(v_fac_a);
    __mT sigr = _mmT_mul_pT(COEF_TELE,_mmT_sqrt_pT(v_fac_a));
    // sigt = COEF_TELE * sqrt(v_fac_b);
    __mT sigt = _mmT_mul_pT(COEF_TELE,_mmT_sqrt_pT(v_fac_b));
    // sigp = COEF_TELE * sqrt(v_fac_c);
    __mT sigp = _mmT_mul_pT(COEF_TELE,_mmT_sqrt_pT(v_fac_c));

    // coe = 1.0 / (sigr/D + sigt/D + sigp/D);
    __mT coe = _mmT_div_pT(v_1,_mmT_add_pT(_mmT_add_pT(_mmT_mul_pT(sigr,DR_inv),_mmT_mul_pT(sigt,DT_inv)),_mmT_mul_pT(sigp,DP_inv)));
    // coe becomes inf as sig* goes to 0
    // if coe > 1e19, set coe = 1e19
    coe = _mmT_min_pT(coe,coe_max);

    // Htau = v_fac_a * square((v_pr1 + v_pr2)*0.5);
    __mT Htau = _mmT_mul_pT(v_fac_a,my_square_v(_mmT_mul_pT(_mmT_add_pT(v_pr1,v_pr2),v_half)));
    // Htau += v_fac_b * square((v_pt1 + v_pt2)*0.5)
    Htau = _mmT_add_pT(Htau,_mmT_mul_pT(v_fac_b,my_square_v(_mmT_mul_pT(_mmT_add_pT(v_pt1,v_pt2), v_half))));
    // Htau += v_fac_c * square((v_pp1 + v_pp2)*0.5)
    Htau = _mmT_add_pT(Htau,_mmT_mul_pT(v_fac_c,my_square_v(_mmT_mul_pT(_mmT_add_pT(v_pp1,v_pp2), v_half))));

    // tmp1 = (v_pt1 + v_pt2)*0.5
    __mT tmp1 = _mmT_mul_pT(_mmT_add_pT(v_pt1,v_pt2), v_half);
    // tmp2 = (v_pp1 + v_pp2)*0.5
    __mT tmp2 = _mmT_mul_pT(_mmT_add_pT(v_pp1,v_pp2), v_half);
    // tmp3 = -2.0 * v_fac_f * tmp1 * tmp2;
    __mT tmp3 = _mmT_mul_pT(v_m2,_mmT_mul_pT(v_fac_f,_mmT_mul_pT(tmp1,tmp2)));

    // Htau = sqrt(Htau + tmp3);  // becomes nan if Htau + tmp3 goes below 0
    Htau = _mmT_sqrt_pT(_mmT_add_pT(Htau,tmp3));

    // tmp = (sigr*(v_pr2 - v_pr1) + sigt*(v_pt2 - v_pt1) + sigp*(v_pp2 - v_pp1))*0.5
    __mT tmp = _mmT_mul_pT(v_half,_mmT_add_pT(_mmT_add_pT(_mmT_mul_pT(sigr,_mmT_sub_pT(v_pr2,v_pr1)),_mmT_mul_pT(sigt,_mmT_sub_pT(v_pt2,v_pt1))),_mmT_mul_pT(sigp,_mmT_sub_pT(v_pp2,v_pp1))));

    // v_tau += coe * ((v_fun - Htau) + tmp)
    v_tau = _mmT_add_pT(v_tau,_mmT_mul_pT(coe,_mmT_add_pT(_mmT_sub_pT(v_fun,Htau),tmp)));

#if USE_AVX512
    // mask if v_change != 1.0
    __mmaskT mask = _mm512_cmp_pT_mask(v_change,v_1,_CMP_NEQ_OQ);
    // set 1 if mask is true
    v_tau = _mm512_mask_blend_pT(mask,v_tau,v_1);
#elif USE_AVX
    __m256d mask = _mm256_cmp_pT(v_change, v_1,_CMP_NEQ_OQ);
    v_tau = _mm256_blendv_pd(v_tau, v_1,mask);
#endif

#elif USE_ARM_SVE
    __mT COEF_TELE = svdup_f64(1.05);
    __mT v_1       = svdup_f64(1.0);
    __mT v_half    = svdup_f64(0.5);
    __mT v_m2      = svdup_f64(-2.0);
    __mT coe_max   = svdup_f64(1e19);

    // sigr = COEF_TELE * sqrt(v_fac_a);
    __mT sigr = svmul_f64_z(pg,COEF_TELE,svsqrt_f64_z(pg,v_fac_a));
    // sigt = COEF_TELE * sqrt(v_fac_b);
    __mT sigt = svmul_f64_z(pg,COEF_TELE,svsqrt_f64_z(pg,v_fac_b));
    // sigp = COEF_TELE * sqrt(v_fac_c);
    __mT sigp = svmul_f64_z(pg,COEF_TELE,svsqrt_f64_z(pg,v_fac_c));

    // coe = 1.0 / (sigr/D + sigt/D + sigp/D);
    __mT coe = svdiv_f64_z(pg, v_1, svadd_f64_z(pg, svadd_f64_z(pg, svmul_f64_z(pg, sigr, DR_inv), svmul_f64_z(pg, sigt, DT_inv)), svmul_f64_z(pg, sigp, DP_inv)));
    // coe becomes inf as sig* goes to 0
    // if coe > 1e19, set coe = 1e19
    coe = svmin_f64_z(pg, coe, coe_max);

    // Htau = v_fac_a * square((v_pr1 + v_pr2)*0.5);
    __mT Htau = svmul_f64_z(pg, v_fac_a, my_square_v(pg, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pr1, v_pr2))));
    // Htau += v_fac_b * square((v_pt1 + v_pt2)*0.5)
    Htau = svadd_f64_z(pg, Htau, svmul_f64_z(pg, v_fac_b, my_square_v(pg, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pt1, v_pt2)))));
    // Htau += v_fac_c * square((v_pp1 + v_pp2)*0.5)
    Htau = svadd_f64_z(pg, Htau, svmul_f64_z(pg, v_fac_c, my_square_v(pg, svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pp1, v_pp2)))));

    // tmp1 = (v_pt1 + v_pt2)*0.5
    __mT tmp1 = svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pt1, v_pt2));
    // tmp2 = (v_pp1 + v_pp2)*0.5
    __mT tmp2 = svmul_f64_z(pg, v_half, svadd_f64_z(pg, v_pp1, v_pp2));
    // tmp3 = -2.0 * v_fac_f * tmp1 * tmp2;
    __mT tmp3 = svmul_f64_z(pg, v_m2, svmul_f64_z(pg, v_fac_f, svmul_f64_z(pg, tmp1, tmp2)));

    // Htau = sqrt(Htau + tmp3);  // becomes nan if Htau + tmp3 goes below 0
    Htau = svsqrt_f64_z(pg, svadd_f64_z(pg, Htau, tmp3));

    // tmp = (sigr*(v_pr2 - v_pr1) + sigt*(v_pt2 - v_pt1) + sigp*(v_pp2 - v_pp1))*0.5
    __mT tmp = svmul_f64_z(pg, v_half, svadd_f64_z(pg, \
                   svadd_f64_z(pg, \
                       svmul_f64_z(pg, sigr, svsub_f64_z(pg, v_pr2, v_pr1)), \
                       svmul_f64_z(pg, sigt, svsub_f64_z(pg, v_pt2, v_pt1))\
                   ), \
                   svmul_f64_z(pg, sigp, svsub_f64_z(pg, v_pp2, v_pp1))\
               )\
               );

    // v_tau += coe * ((v_fun - Htau) + tmp)
    v_tau = svadd_f64_z(pg, v_tau, svmul_f64_z(pg, coe, svadd_f64_z(pg, svsub_f64_z(pg, v_fun, Htau), tmp)));

    // mask = v_change != 1.0
    svbool_t mask = svcmpne_f64(pg, v_change, v_1);
    // v_tau = v_1 if mask is true (!= 1.0)
    v_tau = svsel_f64(mask, v_1, v_tau);
#endif
}
// function for calculating the values on boundaries
inline void calculate_boundary_nodes_tele_simd(
#if USE_ARM_SVE
                                               svbool_t const& pg,
#endif
                                               __mT const& v_iip, __mT const& v_jjt, __mT const& v_kkr,
                                               __mT&       v_c__,
                                               __mT const& v_p__,    __mT const& v_m__,    __mT const& v__p_,    __mT const& v__m_,    __mT const& v___p,    __mT const& v___m,
                                               __mT const& v_pp____, __mT const& v_mm____, __mT const& v___pp__, __mT const& v___mm__, __mT const& v_____pp, __mT const& v_____mm,
                                               __mT const& v_change,
                                               int const& loc_I, int const& loc_J, int const& loc_K){

#if USE_AVX512
    // mask for bottom boundary (v_kkr == 0 && v_change == 1.0)
    __mmaskT mask_bot = _kand_maskT(\
        _mm512_cmp_pT_mask(v_kkr, v_0, _CMP_EQ_OQ), \
        _mm512_cmp_pT_mask(v_change, v_1, _CMP_EQ_OQ));
    // mask for top boundary (v_kkr == loc_K-1 && v_change == 1.0)
    __mmaskT mask_top = _kand_maskT(\
        _mm512_cmp_pT_mask(v_kkr, _mmT_set1_pT(loc_K-1), _CMP_EQ_OQ), \
        _mm512_cmp_pT_mask(v_change, v_1, _CMP_EQ_OQ));
    // mask for south boundary (v_jjt == 0 && v_change == 1.0)
    __mmaskT mask_south = _kand_maskT(\
        _mm512_cmp_pT_mask(v_jjt, v_0, _CMP_EQ_OQ), \
        _mm512_cmp_pT_mask(v_change, v_1, _CMP_EQ_OQ));
    // mask for north boundary (v_jjt == loc_J-1 && v_change == 1.0)
    __mmaskT mask_north = _kand_maskT(\
        _mm512_cmp_pT_mask(v_jjt, _mmT_set1_pT(loc_J-1), _CMP_EQ_OQ), \
        _mm512_cmp_pT_mask(v_change, v_1, _CMP_EQ_OQ));
    // mask for west boundary (v_iip == 0 && v_change == 1.0)
    __mmaskT mask_west = _kand_maskT(\
        _mm512_cmp_pT_mask(v_iip, v_0, _CMP_EQ_OQ), \
        _mm512_cmp_pT_mask(v_change, v_1, _CMP_EQ_OQ));
    // mask for east boundary (v_iip == loc_I-1 && v_change == 1.0)
    __mmaskT mask_east = _kand_maskT(\
        _mm512_cmp_pT_mask(v_iip, _mmT_set1_pT(loc_I-1), _CMP_EQ_OQ), \
        _mm512_cmp_pT_mask(v_change, v_1, _CMP_EQ_OQ));

    // if mask_bot, v_c__ = max(2*v___p - v_____pp, v_____pp)
    v_c__ = _mm512_mask_blend_pT(mask_bot,   v_c__, _mmT_max_pT(_mmT_sub_pT(_mmT_mul_pT(v_2, v___p), v_____pp), v_____pp));
    // if mask_top, v_c__ = max(2*v___m - v_____mm, v_____mm)
    v_c__ = _mm512_mask_blend_pT(mask_top,   v_c__, _mmT_max_pT(_mmT_sub_pT(_mmT_mul_pT(v_2, v___m), v_____mm), v_____mm));
    // if mask_south, v_c__ = max(2*v__p_ - v___pp__, v___pp__)
    v_c__ = _mm512_mask_blend_pT(mask_south, v_c__, _mmT_max_pT(_mmT_sub_pT(_mmT_mul_pT(v_2, v__p_), v___pp__), v___pp__));
    // if mask_north, v_c__ = max(2*v__m_ - v___mm__, v___mm__)
    v_c__ = _mm512_mask_blend_pT(mask_north, v_c__, _mmT_max_pT(_mmT_sub_pT(_mmT_mul_pT(v_2, v__m_), v___mm__), v___mm__));
    // if mask_west, v_c__ = max(2*v_p__ - v_pp____, v_pp____)
    v_c__ = _mm512_mask_blend_pT(mask_west,  v_c__, _mmT_max_pT(_mmT_sub_pT(_mmT_mul_pT(v_2, v_p__), v_pp____), v_pp____));
    // if mask_east, v_c__ = max(2*v_m__ - v_mm____, v_mm____)
    v_c__ = _mm512_mask_blend_pT(mask_east,  v_c__, _mmT_max_pT(_mmT_sub_pT(_mmT_mul_pT(v_2, v_m__), v_mm____), v_mm____));

#elif USE_AVX
    // mask for bottom boundary (v_kkr == 0 && v_change == 1.0)
    __m256d mask_bot = _mm256_and_pd(\
        _mm256_cmp_pT(v_kkr, v_0, _CMP_EQ_OQ), \
        _mm256_cmp_pT(v_change, v_1, _CMP_EQ_OQ));
    // mask for top boundary (v_kkr == loc_K-1 && v_change == 1.0)
    __m256d mask_top = _mm256_and_pd(\
        _mm256_cmp_pT(v_kkr, _mmT_set1_pT(loc_K-1), _CMP_EQ_OQ), \
        _mm256_cmp_pT(v_change, v_1, _CMP_EQ_OQ));
    // mask for south boundary (v_jjt == 0 && v_change == 1.0)
    __m256d mask_south = _mm256_and_pd(\
        _mm256_cmp_pT(v_jjt, v_0, _CMP_EQ_OQ), \
        _mm256_cmp_pT(v_change, v_1, _CMP_EQ_OQ));
    // mask for north boundary (v_jjt == loc_J-1 && v_change == 1.0)
    __m256d mask_north = _mm256_and_pd(\
        _mm256_cmp_pT(v_jjt, _mmT_set1_pT(loc_J-1), _CMP_EQ_OQ), \
        _mm256_cmp_pT(v_change, v_1, _CMP_EQ_OQ));
    // mask for west boundary (v_iip == 0 && v_change == 1.0)
    __m256d mask_west = _mm256_and_pd(\
        _mm256_cmp_pT(v_iip, v_0, _CMP_EQ_OQ), \
        _mm256_cmp_pT(v_change, v_1, _CMP_EQ_OQ));
    // mask for east boundary (v_iip == loc_I-1 && v_change == 1.0)
    __m256d mask_east = _mm256_and_pd(\
        _mm256_cmp_pT(v_iip, _mmT_set1_pT(loc_I-1), _CMP_EQ_OQ), \
        _mm256_cmp_pT(v_change, v_1, _CMP_EQ_OQ));

    // if mask_bot, v_c__ = max(2*v___p - v_____pp, v_____pp)
    v_c__ = _mm256_blendv_pd(v_c__, _mm256_max_pd(_mm256_sub_pd(_mm256_mul_pd(v_2, v___p), v_____pp), v_____pp), mask_bot);
    // if mask_top, v_c__ = max(2*v___m - v_____mm, v_____mm)
    v_c__ = _mm256_blendv_pd(v_c__, _mm256_max_pd(_mm256_sub_pd(_mm256_mul_pd(v_2, v___m), v_____mm), v_____mm), mask_top);
    // if mask_south, v_c__ = max(2*v__p_ - v___pp__, v___pp__)
    v_c__ = _mm256_blendv_pd(v_c__, _mm256_max_pd(_mm256_sub_pd(_mm256_mul_pd(v_2, v__p_), v___pp__), v___pp__), mask_south);
    // if mask_north, v_c__ = max(2*v__m_ - v___mm__, v___mm__)
    v_c__ = _mm256_blendv_pd(v_c__, _mm256_max_pd(_mm256_sub_pd(_mm256_mul_pd(v_2, v__m_), v___mm__), v___mm__), mask_north);
    // if mask_west, v_c__ = max(2*v_p__ - v_pp____, v_pp____)
    v_c__ = _mm256_blendv_pd(v_c__, _mm256_max_pd(_mm256_sub_pd(_mm256_mul_pd(v_2, v_p__), v_pp____), v_pp____), mask_west);
    // if mask_east, v_c__ = max(2*v_m__ - v_mm____, v_mm____)
    v_c__ = _mm256_blendv_pd(v_c__, _mm256_max_pd(_mm256_sub_pd(_mm256_mul_pd(v_2, v_m__), v_mm____), v_mm____), mask_east);

#elif USE_ARM_SVE
    svfloat64_t v_0 = svdup_f64(0.0);
    svfloat64_t v_1 = svdup_f64(1.0);
    svfloat64_t v_2 = svdup_f64(2.0);
    svfloat64_t v_loc_I_minus_1 = svdup_f64(loc_I-1);
    svfloat64_t v_loc_J_minus_1 = svdup_f64(loc_J-1);
    svfloat64_t v_loc_K_minus_1 = svdup_f64(loc_K-1);

    // mask for bottom boundary
    svbool_t mask_bot = svand_b_z(pg, \
        svcmpeq_f64(pg, v_kkr, v_0), \
        svcmpeq_f64(pg, v_change, v_1));
    // mask for top boundary
    svbool_t mask_top = svand_b_z(pg, \
        svcmpeq_f64(pg, v_kkr, v_loc_K_minus_1), \
        svcmpeq_f64(pg, v_change, v_1));
    // mask for south boundary
    svbool_t mask_south = svand_b_z(pg, \
        svcmpeq_f64(pg, v_jjt, v_0), \
        svcmpeq_f64(pg, v_change, v_1));
    // mask for north boundary
    svbool_t mask_north = svand_b_z(pg, \
        svcmpeq_f64(pg, v_jjt, v_loc_J_minus_1), \
        svcmpeq_f64(pg, v_change, v_1));
    // mask for west boundary
    svbool_t mask_west = svand_b_z(pg, \
        svcmpeq_f64(pg, v_iip, v_0), \
        svcmpeq_f64(pg, v_change, v_1));
    // mask for east boundary
    svbool_t mask_east = svand_b_z(pg, \
        svcmpeq_f64(pg, v_iip, v_loc_I_minus_1), \
        svcmpeq_f64(pg, v_change, v_1));

    // if mask_bot, v_c__ = max(2*v___p - v_____pp, v_____pp)
    v_c__ = svsel_f64(mask_bot, \
        svmax_f64_z(pg, svsub_f64_z(pg, svmul_f64_z(pg, v_2, v___p), v_____pp), v_____pp), \
        v_c__);
    // if mask_top, v_c__ = max(2*v___m - v_____mm, v_____mm)
    v_c__ = svsel_f64(mask_top, \
        svmax_f64_z(pg, svsub_f64_z(pg, svmul_f64_z(pg, v_2, v___m), v_____mm), v_____mm), \
        v_c__);
    // if mask_south, v_c__ = max(2*v__p_ - v___pp__, v___pp__)
    v_c__ = svsel_f64(mask_south, \
        svmax_f64_z(pg, svsub_f64_z(pg, svmul_f64_z(pg, v_2, v__p_), v___pp__), v___pp__), \
        v_c__);
    // if mask_north, v_c__ = max(2*v__m_ - v___mm__, v___mm__)
    v_c__ = svsel_f64(mask_north, \
        svmax_f64_z(pg, svsub_f64_z(pg, svmul_f64_z(pg, v_2, v__m_), v___mm__), v___mm__), \
        v_c__);
    // if mask_west, v_c__ = max(2*v_p__ - v_pp____, v_pp____)
    v_c__ = svsel_f64(mask_west, \
        svmax_f64_z(pg, svsub_f64_z(pg, svmul_f64_z(pg, v_2, v_p__), v_pp____), v_pp____), \
        v_c__);
    // if mask_east, v_c__ = max(2*v_m__ - v_mm____, v_mm____)
    v_c__ = svsel_f64(mask_east, \
        svmax_f64_z(pg, svsub_f64_z(pg, svmul_f64_z(pg, v_2, v_m__), v_mm____), v_mm____), \
        v_c__);
#endif
}
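// Each masked update above fills a boundary node by linear extrapolation from
// the two nearest interior nodes, clamped from below by the second one, e.g.
// for the bottom face: T(0) = max(2*T(1) - T(2), T(2)). A scalar sketch
// follows (extrapolate_boundary_scalar is an illustrative name, not solver API):
inline double extrapolate_boundary_scalar(double t1, double t2){
    // t1: value one node inside the boundary, t2: value two nodes inside
    double t0 = 2.0*t1 - t2;         // linear extrapolation onto the boundary node
    return (t0 > t2) ? t0 : t2;      // keep it no smaller than the inner value
}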
#if USE_AVX512 || USE_AVX

inline __mT load_mem_gen_to_mTd(CUSTOMREAL* a, int* ijk){
    CUSTOMREAL dump_[NSIMD];
    for (int i=0; i