add back conditionaltype (#1655)

This commit is contained in:
Alex Barron 2024-12-06 11:12:01 -08:00 committed by GitHub
parent bc2a29f033
commit 95c4a2e3af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 3 deletions

View File

@ -854,15 +854,17 @@ METAL_FUNC void qvm_impl(
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
constexpr int tn = 32 / pack_factor; constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE; constexpr int block_size = SIMD_SIZE;
const device uint8_t* ws = (const device uint8_t*)w; using W_T =
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
const device W_T* ws = (const device W_T*)w;
typedef float U; typedef float U;
typedef struct { typedef struct {
uint8_t wi[tn * bytes_per_pack]; W_T wi[tn * bytes_per_pack];
} vec_w; } vec_w;
thread vec_w w_local; thread vec_w w_local;

View File

@ -421,3 +421,14 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
return complex64_t( return complex64_t(
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
} }
// std::conditional is not included with Metal
template <bool condition, typename T, typename U>
struct ConditionalType {
using type = U;
};
template <typename T, typename U>
struct ConditionalType<true, T, U> {
using type = T;
};